feat: support disabling reverse/local port forwarding in agent SSH server (#24026)
The agent SSH server unconditionally allows all four SSH forwarding paths (TCP local, TCP reverse, Unix local, Unix reverse). This is a sandbox escape vector when workspaces are used for AI agent containment — a reverse tunnel lets anything inside the workspace reach the user's local machine, bypassing network isolation. This adds two new agent CLI flags / environment variables: - `--block-reverse-port-forwarding` / `CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING` — blocks both TCP (`ssh -R`) and Unix socket reverse forwarding - `--block-local-port-forwarding` / `CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING` — blocks both TCP (`ssh -L`) and Unix socket local forwarding Template admins can set these via the `env` block on the container/VM resource that runs the agent (e.g. `docker_container`, `kubernetes_pod`), or via `coder_env` resources tied to the agent. Fixes https://github.com/coder/coder/issues/22275 <details> <summary>Implementation notes</summary> Follows the existing `BlockFileTransfer` pattern: 1. `agent/agentssh/agentssh.go` — New `BlockReversePortForwarding` and `BlockLocalPortForwarding` fields on `Config`. TCP callbacks check these before allowing forwarding. The `direct-streamlocal@openssh.com` channel handler is wrapped to reject Unix local forwards. 2. `agent/agentssh/forward.go` — `forwardedUnixHandler` gains a `blockReversePortForwarding` field to reject `streamlocal-forward@openssh.com` requests. 3. `agent/agent.go` — New fields on `Options` and `agent` struct, plumbed to SSH config. 4. `cli/agent.go` — New serpent flags with env vars. 5. Tests cover all four blocked paths: TCP local, TCP reverse, Unix local, Unix reverse. </details> > 🤖 Generated by Coder Agents
This commit is contained in:
+14
-6
@@ -102,6 +102,8 @@ type Options struct {
|
||||
ReportMetadataInterval time.Duration
|
||||
ServiceBannerRefreshInterval time.Duration
|
||||
BlockFileTransfer bool
|
||||
BlockReversePortForwarding bool
|
||||
BlockLocalPortForwarding bool
|
||||
Execer agentexec.Execer
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
@@ -214,6 +216,8 @@ func New(options Options) Agent {
|
||||
subsystems: options.Subsystems,
|
||||
logSender: agentsdk.NewLogSender(options.Logger),
|
||||
blockFileTransfer: options.BlockFileTransfer,
|
||||
blockReversePortForwarding: options.BlockReversePortForwarding,
|
||||
blockLocalPortForwarding: options.BlockLocalPortForwarding,
|
||||
|
||||
prometheusRegistry: prometheusRegistry,
|
||||
metrics: newAgentMetrics(prometheusRegistry),
|
||||
@@ -280,6 +284,8 @@ type agent struct {
|
||||
sshServer *agentssh.Server
|
||||
sshMaxTimeout time.Duration
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
|
||||
lifecycleUpdate chan struct{}
|
||||
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
|
||||
@@ -331,12 +337,14 @@ func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
func (a *agent) init() {
|
||||
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
|
||||
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
BlockReversePortForwarding: a.blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: a.blockLocalPortForwarding,
|
||||
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
|
||||
var connectionType proto.Connection_Type
|
||||
switch magicType {
|
||||
|
||||
@@ -986,6 +986,161 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_TCPLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_TCPRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
_, err = sshClient.ListenTCP(addr)
|
||||
require.ErrorContains(t, err, "tcpip-forward request denied by peer")
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
l, err := net.Listen("unix", remoteSocketPath)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("unix", remoteSocketPath)
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_UnixRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.ListenUnix(remoteSocketPath)
|
||||
require.ErrorContains(t, err, "streamlocal-forward@openssh.com request denied by peer")
|
||||
}
|
||||
|
||||
// TestAgent_LocalBlockedDoesNotAffectReverse verifies that blocking
|
||||
// local port forwarding does not prevent reverse port forwarding from
|
||||
// working. A field-name transposition at any plumbing hop would cause
|
||||
// both directions to be blocked when only one flag is set.
|
||||
func TestAgent_LocalBlockedDoesNotAffectReverse(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Reverse forwarding must still work.
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
var ll net.Listener
|
||||
for {
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
ll, err = sshClient.ListenTCP(addr)
|
||||
if err != nil {
|
||||
t.Logf("error remote forwarding: %s", err.Error())
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out getting random listener")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
_ = ll.Close()
|
||||
}
|
||||
|
||||
// TestAgent_ReverseBlockedDoesNotAffectLocal verifies that blocking
|
||||
// reverse port forwarding does not prevent local port forwarding from
|
||||
// working.
|
||||
func TestAgent_ReverseBlockedDoesNotAffectLocal(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
go echoOnce(t, rl)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Local forwarding must still work.
|
||||
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
|
||||
@@ -117,6 +117,10 @@ type Config struct {
|
||||
X11MaxPort *int
|
||||
// BlockFileTransfer restricts use of file transfer applications.
|
||||
BlockFileTransfer bool
|
||||
// BlockReversePortForwarding disables reverse port forwarding (ssh -R).
|
||||
BlockReversePortForwarding bool
|
||||
// BlockLocalPortForwarding disables local port forwarding (ssh -L).
|
||||
BlockLocalPortForwarding bool
|
||||
// ReportConnection.
|
||||
ReportConnection reportConnectionFunc
|
||||
// Experimental: allow connecting to running containers via Docker exec.
|
||||
@@ -190,7 +194,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
}
|
||||
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := newForwardedUnixHandler(logger)
|
||||
unixForwardHandler := newForwardedUnixHandler(logger, config.BlockReversePortForwarding)
|
||||
|
||||
metrics := newSSHServerMetrics(prometheusRegistry)
|
||||
s := &Server{
|
||||
@@ -229,8 +233,15 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains)
|
||||
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
|
||||
},
|
||||
"direct-streamlocal@openssh.com": directStreamLocalHandler,
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-streamlocal@openssh.com": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "unix local port forward blocked")
|
||||
_ = newChan.Reject(gossh.Prohibited, "local port forwarding is disabled")
|
||||
return
|
||||
}
|
||||
directStreamLocalHandler(srv, conn, newChan, ctx)
|
||||
},
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
},
|
||||
ConnectionFailedCallback: func(conn net.Conn, err error) {
|
||||
s.logger.Warn(ctx, "ssh connection failed",
|
||||
@@ -250,6 +261,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
// be set before we start listening.
|
||||
HostSigners: []ssh.Signer{},
|
||||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "local port forward blocked",
|
||||
slog.F("destination_host", destinationHost),
|
||||
slog.F("destination_port", destinationPort))
|
||||
return false
|
||||
}
|
||||
// Allow local port forwarding all!
|
||||
s.logger.Debug(ctx, "local port forward",
|
||||
slog.F("destination_host", destinationHost),
|
||||
@@ -260,6 +277,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
return true
|
||||
},
|
||||
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
if s.config.BlockReversePortForwarding {
|
||||
s.logger.Warn(ctx, "reverse port forward blocked",
|
||||
slog.F("bind_host", bindHost),
|
||||
slog.F("bind_port", bindPort))
|
||||
return false
|
||||
}
|
||||
// Allow reverse port forwarding all!
|
||||
s.logger.Debug(ctx, "reverse port forward",
|
||||
slog.F("bind_host", bindHost),
|
||||
|
||||
@@ -35,8 +35,9 @@ type forwardedStreamLocalPayload struct {
|
||||
// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding.
|
||||
type forwardedUnixHandler struct {
|
||||
sync.Mutex
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
blockReversePortForwarding bool
|
||||
}
|
||||
|
||||
type forwardKey struct {
|
||||
@@ -44,10 +45,11 @@ type forwardKey struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
|
||||
func newForwardedUnixHandler(log slog.Logger, blockReversePortForwarding bool) *forwardedUnixHandler {
|
||||
return &forwardedUnixHandler{
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
blockReversePortForwarding: blockReversePortForwarding,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +64,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
||||
|
||||
switch req.Type {
|
||||
case "streamlocal-forward@openssh.com":
|
||||
if h.blockReversePortForwarding {
|
||||
log.Warn(ctx, "unix reverse port forward blocked")
|
||||
return false, nil
|
||||
}
|
||||
var reqPayload streamLocalForwardPayload
|
||||
err := gossh.Unmarshal(req.Payload, &reqPayload)
|
||||
if err != nil {
|
||||
|
||||
+22
-4
@@ -53,6 +53,8 @@ func workspaceAgent() *serpent.Command {
|
||||
slogJSONPath string
|
||||
slogStackdriverPath string
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
agentHeaderCommand string
|
||||
agentHeader []string
|
||||
devcontainers bool
|
||||
@@ -319,10 +321,12 @@ func workspaceAgent() *serpent.Command {
|
||||
SSHMaxTimeout: sshMaxTimeout,
|
||||
Subsystems: subsystems,
|
||||
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
BlockReversePortForwarding: blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: blockLocalPortForwarding,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
DevcontainerAPIOptions: []agentcontainers.Option{
|
||||
agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()),
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
@@ -493,6 +497,20 @@ func workspaceAgent() *serpent.Command {
|
||||
Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")),
|
||||
Value: serpent.BoolOf(&blockFileTransfer),
|
||||
},
|
||||
{
|
||||
Flag: "block-reverse-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING",
|
||||
Description: "Block reverse port forwarding through the SSH server (ssh -R).",
|
||||
Value: serpent.BoolOf(&blockReversePortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "block-local-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING",
|
||||
Description: "Block local port forwarding through the SSH server (ssh -L).",
|
||||
Value: serpent.BoolOf(&blockLocalPortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "devcontainers-enable",
|
||||
Default: "true",
|
||||
|
||||
+6
@@ -39,6 +39,12 @@ OPTIONS:
|
||||
--block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false)
|
||||
Block file transfer using known applications: nc,rsync,scp,sftp.
|
||||
|
||||
--block-local-port-forwarding bool, $CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING (default: false)
|
||||
Block local port forwarding through the SSH server (ssh -L).
|
||||
|
||||
--block-reverse-port-forwarding bool, $CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING (default: false)
|
||||
Block reverse port forwarding through the SSH server (ssh -R).
|
||||
|
||||
--boundary-log-proxy-socket-path string, $CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH (default: /tmp/boundary-audit.sock)
|
||||
The path for the boundary log proxy server Unix socket. Boundary
|
||||
should write audit logs to this socket.
|
||||
|
||||
Reference in New Issue
Block a user