fix: mock Agent querying OS for listening ports in tests (#20842)

fixes https://github.com/coder/internal/issues/1123

We want to tests that ports are not included after they are no longer used, but this isn't safe on the real OS networking stack because there is no way to guarantee a port _won't_ be used. Instead, we introduce an interface and fake implementation for testing.

On order to leave the filtering logic in the test path, this PR also does some refactoring.

Caching logic is left in the real OS querying implementation and a new test case is added for it in this PR.
This commit is contained in:
Spike Curtis
2025-11-25 14:25:24 +04:00
committed by GitHub
parent 823009d9ea
commit afd40436f0
6 changed files with 219 additions and 291 deletions
+56 -48
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"hash/fnv"
"io"
"maps"
"net"
"net/http"
"net/netip"
@@ -70,16 +71,21 @@ const (
)
type Options struct {
Filesystem afero.Fs
LogDir string
TempDir string
ScriptDataDir string
Client Client
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
IgnorePorts map[int]string
PortCacheDuration time.Duration
Filesystem afero.Fs
LogDir string
TempDir string
ScriptDataDir string
Client Client
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
// IgnorePorts tells the api handler which ports to ignore when
// listing all listening ports. This is helpful to hide ports that
// are used by the agent, that the user does not care about.
IgnorePorts map[int]string
// ListeningPortsGetter is used to get the list of listening ports. Only
// tests should set this. If unset, a default that queries the OS will be used.
ListeningPortsGetter ListeningPortsGetter
SSHMaxTimeout time.Duration
TailnetListenPort uint16
Subsystems []codersdk.AgentSubsystem
@@ -137,9 +143,7 @@ func New(options Options) Agent {
if options.ServiceBannerRefreshInterval == 0 {
options.ServiceBannerRefreshInterval = 2 * time.Minute
}
if options.PortCacheDuration == 0 {
options.PortCacheDuration = 1 * time.Second
}
if options.Clock == nil {
options.Clock = quartz.NewReal()
}
@@ -153,30 +157,38 @@ func New(options Options) Agent {
options.Execer = agentexec.DefaultExecer
}
if options.ListeningPortsGetter == nil {
options.ListeningPortsGetter = &osListeningPortsGetter{
cacheDuration: 1 * time.Second,
}
}
hardCtx, hardCancel := context.WithCancel(context.Background())
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
a := &agent{
clock: options.Clock,
tailnetListenPort: options.TailnetListenPort,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
gracefulCtx: gracefulCtx,
gracefulCancel: gracefulCancel,
hardCtx: hardCtx,
hardCancel: hardCancel,
coordDisconnected: make(chan struct{}),
environmentVariables: options.EnvironmentVariables,
client: options.Client,
filesystem: options.Filesystem,
logDir: options.LogDir,
tempDir: options.TempDir,
scriptDataDir: options.ScriptDataDir,
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
reportConnectionsUpdate: make(chan struct{}, 1),
ignorePorts: options.IgnorePorts,
portCacheDuration: options.PortCacheDuration,
clock: options.Clock,
tailnetListenPort: options.TailnetListenPort,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
gracefulCtx: gracefulCtx,
gracefulCancel: gracefulCancel,
hardCtx: hardCtx,
hardCancel: hardCancel,
coordDisconnected: make(chan struct{}),
environmentVariables: options.EnvironmentVariables,
client: options.Client,
filesystem: options.Filesystem,
logDir: options.LogDir,
tempDir: options.TempDir,
scriptDataDir: options.ScriptDataDir,
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
reportConnectionsUpdate: make(chan struct{}, 1),
listeningPortsHandler: listeningPortsHandler{
getter: options.ListeningPortsGetter,
ignorePorts: maps.Clone(options.IgnorePorts),
},
reportMetadataInterval: options.ReportMetadataInterval,
announcementBannersRefreshInterval: options.ServiceBannerRefreshInterval,
sshMaxTimeout: options.SSHMaxTimeout,
@@ -202,20 +214,16 @@ func New(options Options) Agent {
}
type agent struct {
clock quartz.Clock
logger slog.Logger
client Client
tailnetListenPort uint16
filesystem afero.Fs
logDir string
tempDir string
scriptDataDir string
// ignorePorts tells the api handler which ports to ignore when
// listing all listening ports. This is helpful to hide ports that
// are used by the agent, that the user does not care about.
ignorePorts map[int]string
portCacheDuration time.Duration
subsystems []codersdk.AgentSubsystem
clock quartz.Clock
logger slog.Logger
client Client
tailnetListenPort uint16
filesystem afero.Fs
logDir string
tempDir string
scriptDataDir string
listeningPortsHandler listeningPortsHandler
subsystems []codersdk.AgentSubsystem
reconnectingPTYTimeout time.Duration
reconnectingPTYServer *reconnectingpty.Server
+25 -31
View File
@@ -2,14 +2,13 @@ package agent
import (
"net/http"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
func (a *agent) apiHandler() http.Handler {
@@ -20,23 +19,6 @@ func (a *agent) apiHandler() http.Handler {
})
})
// Make a copy to ensure the map is not modified after the handler is
// created.
cpy := make(map[int]string)
for k, b := range a.ignorePorts {
cpy[k] = b
}
cacheDuration := 1 * time.Second
if a.portCacheDuration > 0 {
cacheDuration = a.portCacheDuration
}
lp := &listeningPortsHandler{
ignorePorts: cpy,
cacheDuration: cacheDuration,
}
if a.devcontainers {
r.Mount("/api/v0/containers", a.containerAPI.Routes())
} else if manifest := a.manifest.Load(); manifest != nil && manifest.ParentID != uuid.Nil {
@@ -57,7 +39,7 @@ func (a *agent) apiHandler() http.Handler {
promHandler := PrometheusMetricsHandler(a.prometheusRegistry, a.logger)
r.Get("/api/v0/listening-ports", lp.handler)
r.Get("/api/v0/listening-ports", a.listeningPortsHandler.handler)
r.Get("/api/v0/netcheck", a.HandleNetcheck)
r.Post("/api/v0/list-directory", a.HandleLS)
r.Get("/api/v0/read-file", a.HandleReadFile)
@@ -72,22 +54,21 @@ func (a *agent) apiHandler() http.Handler {
return r
}
type listeningPortsHandler struct {
ignorePorts map[int]string
cacheDuration time.Duration
type ListeningPortsGetter interface {
GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error)
}
//nolint: unused // used on some but not all platforms
mut sync.Mutex
//nolint: unused // used on some but not all platforms
ports []codersdk.WorkspaceAgentListeningPort
//nolint: unused // used on some but not all platforms
mtime time.Time
type listeningPortsHandler struct {
// In production code, this is set to an osListeningPortsGetter, but it can be overridden for
// testing.
getter ListeningPortsGetter
ignorePorts map[int]string
}
// handler returns a list of listening ports. This is tested by coderd's
// TestWorkspaceAgentListeningPorts test.
func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request) {
ports, err := lp.getListeningPorts()
ports, err := lp.getter.GetListeningPorts()
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
Message: "Could not scan for listening ports.",
@@ -96,7 +77,20 @@ func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request
return
}
filteredPorts := make([]codersdk.WorkspaceAgentListeningPort, 0, len(ports))
for _, port := range ports {
if port.Port < workspacesdk.AgentMinimumListeningPort {
continue
}
// Ignore ports that we've been told to ignore.
if _, ok := lp.ignorePorts[int(port.Port)]; ok {
continue
}
filteredPorts = append(filteredPorts, port)
}
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.WorkspaceAgentListeningPortsResponse{
Ports: ports,
Ports: filteredPorts,
})
}
+10 -8
View File
@@ -3,16 +3,23 @@
package agent
import (
"sync"
"time"
"github.com/cakturk/go-netstat/netstat"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
type osListeningPortsGetter struct {
cacheDuration time.Duration
mut sync.Mutex
ports []codersdk.WorkspaceAgentListeningPort
mtime time.Time
}
func (lp *osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
lp.mut.Lock()
defer lp.mut.Unlock()
@@ -33,12 +40,7 @@ func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentL
seen := make(map[uint16]struct{}, len(tabs))
ports := []codersdk.WorkspaceAgentListeningPort{}
for _, tab := range tabs {
if tab.LocalAddr == nil || tab.LocalAddr.Port < workspacesdk.AgentMinimumListeningPort {
continue
}
// Ignore ports that we've been told to ignore.
if _, ok := lp.ignorePorts[int(tab.LocalAddr.Port)]; ok {
if tab.LocalAddr == nil {
continue
}
+45
View File
@@ -0,0 +1,45 @@
//go:build linux || (windows && amd64)
package agent
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestOSListeningPortsGetter(t *testing.T) {
t.Parallel()
uut := &osListeningPortsGetter{
cacheDuration: 1 * time.Hour,
}
l, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
defer l.Close()
ports, err := uut.GetListeningPorts()
require.NoError(t, err)
found := false
for _, port := range ports {
// #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535)
if port.Port == uint16(l.Addr().(*net.TCPAddr).Port) {
found = true
break
}
}
require.True(t, found)
// check that we cache the ports
err = l.Close()
require.NoError(t, err)
portsNew, err := uut.GetListeningPorts()
require.NoError(t, err)
require.Equal(t, ports, portsNew)
// note that it's unsafe to try to assert that a port does not exist in the response
// because the OS may reallocate the port very quickly.
}
+10 -2
View File
@@ -2,9 +2,17 @@
package agent
import "github.com/coder/coder/v2/codersdk"
import (
"time"
func (*listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
"github.com/coder/coder/v2/codersdk"
)
type osListeningPortsGetter struct {
cacheDuration time.Duration
}
func (*osListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
// Can't scan for ports on non-linux or non-windows_amd64 systems at the
// moment. The UI will not show any "no ports found" message to the user, so
// the user won't suspect a thing.
+73 -202
View File
@@ -5,12 +5,10 @@ import (
"encoding/json"
"fmt"
"maps"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"strconv"
"slices"
"strings"
"sync"
"sync/atomic"
@@ -934,17 +932,45 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
require.False(t, p2p)
}
type fakeListeningPortsGetter struct {
sync.Mutex
ports []codersdk.WorkspaceAgentListeningPort
}
func (g *fakeListeningPortsGetter) GetListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
g.Lock()
defer g.Unlock()
return slices.Clone(g.ports), nil
}
func (g *fakeListeningPortsGetter) setPorts(ports ...codersdk.WorkspaceAgentListeningPort) {
g.Lock()
defer g.Unlock()
g.ports = slices.Clone(ports)
}
func TestWorkspaceAgentListeningPorts(t *testing.T) {
t.Parallel()
setup := func(t *testing.T, apps []*proto.App, dv *codersdk.DeploymentValues) (*codersdk.Client, uint16, uuid.UUID) {
testPort := codersdk.WorkspaceAgentListeningPort{
Network: "tcp",
ProcessName: "test-app",
Port: 44762,
}
filteredPort := codersdk.WorkspaceAgentListeningPort{
Network: "tcp",
ProcessName: "postgres",
Port: 5432,
}
setup := func(t *testing.T, apps []*proto.App, dv *codersdk.DeploymentValues) (*codersdk.Client, uuid.UUID, *fakeListeningPortsGetter) {
t.Helper()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: dv,
})
coderdPort, err := strconv.Atoi(client.URL.Port())
require.NoError(t, err)
fLPG := &fakeListeningPortsGetter{}
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -955,228 +981,73 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) {
return agents
}).Do()
_ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) {
o.PortCacheDuration = time.Millisecond
o.ListeningPortsGetter = fLPG
})
resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID)
resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait()
// #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535)
return client, uint16(coderdPort), resources[0].Agents[0].ID
return client, resources[0].Agents[0].ID, fLPG
}
willFilterPort := func(port int) bool {
if port < workspacesdk.AgentMinimumListeningPort || port > 65535 {
return true
}
if _, ok := workspacesdk.AgentIgnoredListeningPorts[uint16(port)]; ok {
return true
}
return false
}
generateUnfilteredPort := func(t *testing.T) (net.Listener, uint16) {
var (
l net.Listener
port uint16
)
require.Eventually(t, func() bool {
var err error
l, err = net.Listen("tcp", "localhost:0")
if err != nil {
return false
}
tcpAddr, _ := l.Addr().(*net.TCPAddr)
if willFilterPort(tcpAddr.Port) {
_ = l.Close()
return false
}
t.Cleanup(func() {
_ = l.Close()
})
// #nosec G115 - Safe conversion as TCP port numbers are within uint16 range (0-65535)
port = uint16(tcpAddr.Port)
return true
}, testutil.WaitShort, testutil.IntervalFast)
return l, port
}
generateFilteredPort := func(t *testing.T) (net.Listener, uint16) {
var (
l net.Listener
port uint16
)
require.Eventually(t, func() bool {
for ignoredPort := range workspacesdk.AgentIgnoredListeningPorts {
if ignoredPort < 1024 || ignoredPort == 5432 {
continue
}
var err error
l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", ignoredPort))
if err != nil {
continue
}
t.Cleanup(func() {
_ = l.Close()
})
port = ignoredPort
return true
}
return false
}, testutil.WaitShort, testutil.IntervalFast)
return l, port
}
t.Run("LinuxAndWindows", func(t *testing.T) {
t.Parallel()
if runtime.GOOS != "linux" && runtime.GOOS != "windows" {
t.Skip("only runs on linux and windows")
return
}
for _, tc := range []struct {
name string
setDV func(t *testing.T, dv *codersdk.DeploymentValues)
}{
{
name: "Mainline",
setDV: func(*testing.T, *codersdk.DeploymentValues) {},
},
{
name: "BlockDirect",
setDV: func(t *testing.T, dv *codersdk.DeploymentValues) {
err := dv.DERP.Config.BlockDirect.Set("true")
require.NoError(t, err)
require.True(t, dv.DERP.Config.BlockDirect.Value())
},
},
} {
t.Run("OK_"+tc.name, func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
tc.setDV(t, dv)
client, coderdPort, agentID := setup(t, nil, dv)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Generate a random unfiltered port.
l, lPort := generateUnfilteredPort(t)
// List ports and ensure that the port we expect to see is there.
res, err := client.WorkspaceAgentListeningPorts(ctx, agentID)
for _, tc := range []struct {
name string
setDV func(t *testing.T, dv *codersdk.DeploymentValues)
}{
{
name: "Mainline",
setDV: func(*testing.T, *codersdk.DeploymentValues) {},
},
{
name: "BlockDirect",
setDV: func(t *testing.T, dv *codersdk.DeploymentValues) {
err := dv.DERP.Config.BlockDirect.Set("true")
require.NoError(t, err)
expected := map[uint16]bool{
// expect the listener we made
lPort: false,
// expect the coderdtest server
coderdPort: false,
}
for _, port := range res.Ports {
if port.Network == "tcp" {
if val, ok := expected[port.Port]; ok {
if val {
t.Fatalf("expected to find TCP port %d only once in response", port.Port)
}
}
expected[port.Port] = true
}
}
for port, found := range expected {
if !found {
t.Fatalf("expected to find TCP port %d in response", port)
}
}
// Close the listener and check that the port is no longer in the response.
require.NoError(t, l.Close())
t.Log("checking for ports after listener close:")
require.Eventually(t, func() bool {
res, err = client.WorkspaceAgentListeningPorts(ctx, agentID)
if !assert.NoError(t, err) {
return false
}
for _, port := range res.Ports {
if port.Network == "tcp" && port.Port == lPort {
t.Logf("expected to not find TCP port %d in response", lPort)
return false
}
}
return true
}, testutil.WaitLong, testutil.IntervalMedium)
})
}
t.Run("Filter", func(t *testing.T) {
require.True(t, dv.DERP.Config.BlockDirect.Value())
},
},
} {
t.Run("OK_"+tc.name, func(t *testing.T) {
t.Parallel()
// Generate an unfiltered port that we will create an app for and
// should not exist in the response.
_, appLPort := generateUnfilteredPort(t)
app := &proto.App{
Slug: "test-app",
Url: fmt.Sprintf("http://localhost:%d", appLPort),
}
// Generate a filtered port that should not exist in the response.
_, filteredLPort := generateFilteredPort(t)
client, coderdPort, agentID := setup(t, []*proto.App{app}, nil)
dv := coderdtest.DeploymentValues(t)
tc.setDV(t, dv)
client, agentID, fLPG := setup(t, nil, dv)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
fLPG.setPorts(testPort)
// List ports and ensure that the port we expect to see is there.
res, err := client.WorkspaceAgentListeningPorts(ctx, agentID)
require.NoError(t, err)
require.Equal(t, []codersdk.WorkspaceAgentListeningPort{testPort}, res.Ports)
sawCoderdPort := false
for _, port := range res.Ports {
if port.Network == "tcp" {
if port.Port == appLPort {
t.Fatalf("expected to not find TCP port (app port) %d in response", appLPort)
}
if port.Port == filteredLPort {
t.Fatalf("expected to not find TCP port (filtered port) %d in response", filteredLPort)
}
if port.Port == coderdPort {
sawCoderdPort = true
}
}
}
if !sawCoderdPort {
t.Fatalf("expected to find TCP port (coderd port) %d in response", coderdPort)
}
// Remove the port and check that the port is no longer in the response.
fLPG.setPorts()
res, err = client.WorkspaceAgentListeningPorts(ctx, agentID)
require.NoError(t, err)
require.Empty(t, res.Ports)
})
})
}
t.Run("Darwin", func(t *testing.T) {
t.Run("Filter", func(t *testing.T) {
t.Parallel()
if runtime.GOOS != "darwin" {
t.Skip("only runs on darwin")
return
app := &proto.App{
Slug: testPort.ProcessName,
Url: fmt.Sprintf("http://localhost:%d", testPort.Port),
}
client, _, agentID := setup(t, nil, nil)
client, agentID, fLPG := setup(t, []*proto.App{app}, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Create a TCP listener on a random port.
l, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
defer l.Close()
fLPG.setPorts(testPort, filteredPort)
// List ports and ensure that the list is empty because we're on darwin.
res, err := client.WorkspaceAgentListeningPorts(ctx, agentID)
require.NoError(t, err)
require.Len(t, res.Ports, 0)
require.Empty(t, res.Ports)
})
}