Compare commits

...

1 Commit

Author SHA1 Message Date
Ehab Younes
d7e08678eb fix(cli/ssh): add stdio retry with proper tunnel/SSH layering
Add retry logic for transient connection failures in stdio mode
(ProxyCommand). The retry loop covers only tunnel establishment
(sshEstablishStdioConn); data transfer runs once, outside the
loop, in sshHandleStdio.

Extracts sshResolveAgent (shared workspace/agent resolution for
both stdio and interactive paths) and verifyContainer (shared
container validation using the existing Match method).
2026-04-06 16:44:11 +03:00
2 changed files with 634 additions and 183 deletions

View File

@@ -60,8 +60,30 @@ var (
// gracefulShutdownTimeout is the timeout, per item in the stack of things to close
gracefulShutdownTimeout = 2 * time.Second
workspaceNameRe = regexp.MustCompile(`[/.]+|--`)
// stdioRetryDelays controls retry timing for transient
// ProxyCommand failures during tunnel establishment.
stdioRetryDelays = []time.Duration{0, 2 * time.Second, 5 * time.Second, 10 * time.Second, 10 * time.Second}
)
// isRetryableError reports whether err is a transient connection
// error worth retrying (DNS failures, connection refused, server 5xx).
func isRetryableError(err error) bool {
	switch {
	case err == nil:
		return false
	case xerrors.Is(err, context.Canceled), xerrors.Is(err, context.DeadlineExceeded):
		// Cancellation is deliberate; never retry it, even when it
		// arrives wrapped inside another error.
		return false
	case codersdk.IsConnectionError(err):
		return true
	}
	// API errors are transient only when the server is at fault (5xx);
	// 4xx responses indicate a permanent client-side problem.
	var sdkErr *codersdk.Error
	return xerrors.As(err, &sdkErr) && sdkErr.StatusCode() >= 500
}
func (r *RootCmd) ssh() *serpent.Command {
var (
stdio bool
@@ -272,121 +294,39 @@ func (r *RootCmd) ssh() *serpent.Command {
parsedEnv = append(parsedEnv, [2]string{k, v})
}
cliConfig := codersdk.SSHConfigResponse{
HostnamePrefix: hostPrefix,
HostnameSuffix: hostnameSuffix,
if stdio {
return r.sshHandleStdio(ctx, cancel, inv, client, wsClient, stdioReader, stdioWriter, logger, &wg, sshStdioArgs{
appearanceConfig: appearanceConfig,
forceNewTunnel: forceNewTunnel,
usageApp: usageApp,
disableAutostart: disableAutostart,
networkInfoDir: networkInfoDir,
networkInfoInterval: networkInfoInterval,
containerName: containerName,
containerUser: containerUser,
cliConfig: codersdk.SSHConfigResponse{
HostnamePrefix: hostPrefix,
HostnameSuffix: hostnameSuffix,
},
waitEnum: waitEnum,
noWait: noWait,
})
}
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
ctx, inv, client,
inv.Args[0], cliConfig, disableAutostart)
if err != nil {
return err
}
// Select the startup script behavior based on template configuration or flags.
var wait bool
switch waitEnum {
case "yes":
wait = true
case "no":
wait = false
case "auto":
for _, script := range workspaceAgent.Scripts {
if script.StartBlocksLogin {
wait = true
break
}
}
default:
return xerrors.Errorf("unknown wait value %q", waitEnum)
}
// The `--no-wait` flag is deprecated, but for now, check it.
// Non-stdio (interactive) path, single attempt, no retry.
resolvedWait := waitEnum
if noWait {
wait = false
resolvedWait = "no"
}
templateVersion, err := client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
resolved, err := sshResolveAgent(ctx, inv, client,
codersdk.SSHConfigResponse{HostnamePrefix: hostPrefix, HostnameSuffix: hostnameSuffix},
appearanceConfig, disableAutostart, resolvedWait,
)
if err != nil {
return err
}
var unsupportedWorkspace bool
for _, warning := range templateVersion.Warnings {
if warning == codersdk.TemplateVersionWarningUnsupportedWorkspaces {
unsupportedWorkspace = true
break
}
}
if unsupportedWorkspace && isTTYErr(inv) {
_, _ = fmt.Fprintln(inv.Stderr, "👋 Your workspace uses legacy parameters which are not supported anymore. Contact your administrator for assistance.")
}
updateWorkspaceBanner, outdated := verifyWorkspaceOutdated(client, workspace)
if outdated && isTTYErr(inv) {
_, _ = fmt.Fprintln(inv.Stderr, updateWorkspaceBanner)
}
// OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
FetchInterval: 0,
Fetch: client.WorkspaceAgent,
FetchLogs: client.WorkspaceAgentLogsAfter,
Wait: wait,
DocsURL: appearanceConfig.DocsURL,
})
if err != nil {
if xerrors.Is(err, context.Canceled) {
return cliui.ErrCanceled
}
return err
}
// If we're in stdio mode, check to see if we can use Coder Connect.
// We don't support Coder Connect over non-stdio coder ssh yet.
if stdio && !forceNewTunnel {
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
if err != nil {
return xerrors.Errorf("get agent connection info: %w", err)
}
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix)
// Use trailing dot to indicate FQDN and prevent DNS
// search domain expansion, which can add 20-30s of
// delay on corporate networks with search domains
// configured.
exists, ccErr := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost+".")
if ccErr != nil {
logger.Debug(ctx, "failed to check coder connect",
slog.F("hostname", coderConnectHost),
slog.Error(ccErr),
)
}
if exists {
defer cancel()
if networkInfoDir != "" {
if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil {
logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err))
}
}
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
defer stopPolling()
usageAppName := getUsageAppName(usageApp)
if usageAppName != "" {
closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{
AgentID: workspaceAgent.ID,
AppName: usageAppName,
})
defer closeUsage()
}
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
}
}
workspace := resolved.workspace
workspaceAgent := resolved.workspaceAgent
if r.disableDirect {
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
@@ -405,30 +345,12 @@ func (r *RootCmd) ssh() *serpent.Command {
}
conn.AwaitReachable(ctx)
if containerName != "" {
cts, err := client.WorkspaceAgentListContainers(ctx, workspaceAgent.ID, nil)
if err != nil {
return xerrors.Errorf("list containers: %w", err)
}
if len(cts.Containers) == 0 {
cliui.Info(inv.Stderr, "No containers found!")
return nil
}
var found bool
for _, c := range cts.Containers {
if c.FriendlyName == containerName || c.ID == containerName {
found = true
break
}
}
if !found {
availableContainers := make([]string, len(cts.Containers))
for i, c := range cts.Containers {
availableContainers[i] = c.FriendlyName
}
cliui.Errorf(inv.Stderr, "Container not found: %q\nAvailable containers: %v", containerName, availableContainers)
return nil
}
found, err := verifyContainer(ctx, client, workspaceAgent.ID, containerName, inv.Stderr)
if err != nil {
return err
}
if !found {
return nil
}
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
@@ -443,36 +365,6 @@ func (r *RootCmd) ssh() *serpent.Command {
defer closeUsage()
}
if stdio {
rawSSH, err := conn.SSH(ctx)
if err != nil {
return xerrors.Errorf("connect SSH: %w", err)
}
copier := newRawSSHCopier(logger, rawSSH, stdioReader, stdioWriter)
if err = stack.push("rawSSHCopier", copier); err != nil {
return err
}
var errCh <-chan error
if networkInfoDir != "" {
errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval)
if err != nil {
return err
}
}
wg.Add(1)
go func() {
defer wg.Done()
watchAndClose(ctx, func() error {
stack.close(xerrors.New("watchAndClose"))
return nil
}, logger, client, workspace, errCh)
}()
copier.copy(&wg)
return nil
}
sshClient, err := conn.SSHClient(ctx)
if err != nil {
return xerrors.Errorf("ssh client: %w", err)
@@ -790,6 +682,357 @@ func (r *RootCmd) ssh() *serpent.Command {
return cmd
}
// sshResolvedAgent is the result of resolving and waiting for an
// agent. Shared by both the stdio and interactive paths.
type sshResolvedAgent struct {
	// workspace is the resolved workspace owning the agent.
	workspace codersdk.Workspace
	// workspaceAgent is the agent within workspace that was resolved
	// from the hostname argument and waited on until ready.
	workspaceAgent codersdk.WorkspaceAgent
}
// sshResolveAgent resolves the workspace/agent, displays banners,
// and waits for the agent to be ready.
func sshResolveAgent(
	ctx context.Context,
	inv *serpent.Invocation,
	client *codersdk.Client,
	cliConfig codersdk.SSHConfigResponse,
	appearanceConfig codersdk.AppearanceConfig,
	disableAutostart bool,
	waitEnum string,
) (*sshResolvedAgent, error) {
	ws, agent, err := findWorkspaceAndAgentByHostname(
		ctx, inv, client,
		inv.Args[0], cliConfig, disableAutostart)
	if err != nil {
		return nil, err
	}

	// Translate the wait flag into a concrete decision. In "auto"
	// mode we only wait when a startup script blocks login.
	wait := false
	switch waitEnum {
	case "yes":
		wait = true
	case "no":
		// wait stays false.
	case "auto":
		for _, script := range agent.Scripts {
			if script.StartBlocksLogin {
				wait = true
				break
			}
		}
	default:
		return nil, xerrors.Errorf("unknown wait value %q", waitEnum)
	}

	tv, err := client.TemplateVersion(ctx, ws.LatestBuild.TemplateVersionID)
	if err != nil {
		return nil, err
	}
	unsupported := false
	for _, warning := range tv.Warnings {
		if warning == codersdk.TemplateVersionWarningUnsupportedWorkspaces {
			unsupported = true
			break
		}
	}
	if unsupported && isTTYErr(inv) {
		_, _ = fmt.Fprintln(inv.Stderr, "👋 Your workspace uses legacy parameters which are not supported anymore. Contact your administrator for assistance.")
	}

	if banner, outdated := verifyWorkspaceOutdated(client, ws); outdated && isTTYErr(inv) {
		_, _ = fmt.Fprintln(inv.Stderr, banner)
	}

	// Block until the agent reports ready (respecting wait). Progress
	// goes to stderr so it is visible even in ProxyCommand mode.
	err = cliui.Agent(ctx, inv.Stderr, agent.ID, cliui.AgentOptions{
		FetchInterval: 0,
		Fetch:         client.WorkspaceAgent,
		FetchLogs:     client.WorkspaceAgentLogsAfter,
		Wait:          wait,
		DocsURL:       appearanceConfig.DocsURL,
	})
	if err != nil {
		if xerrors.Is(err, context.Canceled) {
			return nil, cliui.ErrCanceled
		}
		return nil, err
	}

	return &sshResolvedAgent{workspace: ws, workspaceAgent: agent}, nil
}
// sshStdioArgs holds CLI flag values for the stdio path.
type sshStdioArgs struct {
	// appearanceConfig supplies the docs URL shown while waiting on the agent.
	appearanceConfig codersdk.AppearanceConfig
	// forceNewTunnel, when true, skips the Coder Connect fast path.
	forceNewTunnel bool
	// usageApp names the app to report workspace usage for (via getUsageAppName).
	usageApp string
	// disableAutostart is forwarded to workspace/agent resolution.
	disableAutostart bool
	// networkInfoDir, when non-empty, enables writing network stats files.
	networkInfoDir string
	// networkInfoInterval controls how often network stats are collected.
	networkInfoInterval time.Duration
	// containerName optionally targets a container on the agent; verified
	// via verifyContainer after the tunnel is up.
	containerName string
	// containerUser is the user within the container; its consumption is
	// not visible in this diff — confirm against the full file.
	containerUser string
	// cliConfig carries hostname prefix/suffix used when parsing the
	// hostname argument.
	cliConfig codersdk.SSHConfigResponse
	// waitEnum is the raw --wait flag value ("yes"/"no"/"auto").
	waitEnum string
	// noWait is the deprecated --no-wait flag; it overrides waitEnum to "no".
	noWait bool
}
// stdioConnResult holds the established connection for stdio proxying.
// Either coderConnectConn or agentConn/rawSSH is set, not both.
type stdioConnResult struct {
	// workspace is the resolved workspace; used after establishment for
	// autostop polling and usage reporting.
	workspace codersdk.Workspace
	// workspaceAgent is the agent the tunnel terminates at.
	workspaceAgent codersdk.WorkspaceAgent
	coderConnectConn net.Conn // Coder Connect path
	agentConn workspacesdk.AgentConn // tailnet path
	rawSSH *gonet.TCPConn // tailnet path
}
// sshEstablishStdioConn resolves the workspace, waits for the agent,
// and dials the tunnel. This is the only phase retried on transient
// errors.
//
// Successful connections are pushed onto stack so the caller owns
// their cleanup; on error the caller is expected to close stack
// before retrying with a fresh one.
func (r *RootCmd) sshEstablishStdioConn(
	ctx context.Context,
	inv *serpent.Invocation,
	client *codersdk.Client,
	wsClient *workspacesdk.Client,
	logger slog.Logger,
	stack *closerStack,
	args sshStdioArgs,
) (*stdioConnResult, error) {
	// The deprecated --no-wait flag overrides --wait.
	resolvedWait := args.waitEnum
	if args.noWait {
		resolvedWait = "no"
	}
	resolved, err := sshResolveAgent(ctx, inv, client,
		args.cliConfig, args.appearanceConfig,
		args.disableAutostart, resolvedWait,
	)
	if err != nil {
		return nil, err
	}
	workspace := resolved.workspace
	workspaceAgent := resolved.workspaceAgent
	// Check if we can use Coder Connect.
	if !args.forceNewTunnel {
		connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
		if err != nil {
			return nil, xerrors.Errorf("get agent connection info: %w", err)
		}
		coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
			workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix)
		// Use trailing dot to indicate FQDN and prevent DNS
		// search domain expansion, which can add 20-30s of
		// delay on corporate networks with search domains
		// configured.
		exists, ccErr := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost+".")
		if ccErr != nil {
			// Best-effort probe: a failed check is only logged and we
			// fall through to the regular tailnet path below.
			logger.Debug(ctx, "failed to check coder connect",
				slog.F("hostname", coderConnectHost),
				slog.Error(ccErr),
			)
		}
		if exists {
			dialer := testOrDefaultDialer(ctx)
			tcpConn, err := dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:22", coderConnectHost))
			if err != nil {
				return nil, xerrors.Errorf("dial coder connect host: %w", err)
			}
			if err := stack.push("tcp conn", tcpConn); err != nil {
				return nil, err
			}
			return &stdioConnResult{
				workspace:        workspace,
				workspaceAgent:   workspaceAgent,
				coderConnectConn: tcpConn,
			}, nil
		}
	}
	// Regular tailnet path.
	if r.disableDirect {
		_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
	}
	conn, err := wsClient.
		DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
			Logger:          logger,
			BlockEndpoints:  r.disableDirect,
			EnableTelemetry: !r.disableNetworkTelemetry,
		})
	if err != nil {
		return nil, xerrors.Errorf("dial agent: %w", err)
	}
	if err = stack.push("agent conn", conn); err != nil {
		return nil, err
	}
	conn.AwaitReachable(ctx)
	rawSSH, err := conn.SSH(ctx)
	if err != nil {
		return nil, xerrors.Errorf("connect SSH: %w", err)
	}
	return &stdioConnResult{
		workspace:      workspace,
		workspaceAgent: workspaceAgent,
		agentConn:      conn,
		rawSSH:         rawSSH,
	}, nil
}
// sshHandleStdio retries tunnel establishment, then proxies data
// once. After the tunnel is up, errors are final because the remote
// SSH server has session state tied to the connection.
func (r *RootCmd) sshHandleStdio(
	ctx context.Context,
	cancel context.CancelFunc,
	inv *serpent.Invocation,
	client *codersdk.Client,
	wsClient *workspacesdk.Client,
	stdioReader io.Reader,
	stdioWriter io.Writer,
	logger slog.Logger,
	wg *sync.WaitGroup,
	args sshStdioArgs,
) error {
	// Fresh stack per attempt so failed connections get cleaned up.
	var stack *closerStack
	defer func() {
		if stack != nil {
			stack.close(nil)
		}
	}()
	var (
		result  *stdioConnResult
		lastErr error
	)
	for attempt, delay := range stdioRetryDelays {
		if attempt > 0 {
			// Tear down the previous attempt's connections before
			// waiting and redialing.
			stack.close(nil)
			logger.Warn(ctx, "ssh tunnel establishment failed, retrying",
				slog.Error(lastErr),
				slog.F("attempt", attempt+1),
				slog.F("delay", delay),
			)
			select {
			case <-ctx.Done():
				// Cancelled while backing off: surface the last
				// establishment error rather than ctx.Err().
				return lastErr
			case <-time.After(delay):
			}
		}
		stack = newCloserStack(ctx, logger, quartz.NewReal())
		result, lastErr = r.sshEstablishStdioConn(ctx, inv, client, wsClient, logger, stack, args)
		if lastErr == nil || !isRetryableError(lastErr) {
			break
		}
	}
	if lastErr != nil {
		return lastErr
	}
	// Tunnel established. Do not retry past this point.
	found, err := verifyContainer(ctx, client, result.workspaceAgent.ID, args.containerName, inv.Stderr)
	if err != nil {
		return err
	}
	if !found {
		// verifyContainer already printed a diagnostic; exit cleanly.
		return nil
	}
	stopPolling := tryPollWorkspaceAutostop(ctx, client, result.workspace)
	defer stopPolling()
	usageAppName := getUsageAppName(args.usageApp)
	if usageAppName != "" {
		closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, result.workspace.ID, codersdk.PostWorkspaceUsageRequest{
			AgentID: result.workspaceAgent.ID,
			AppName: usageAppName,
		})
		defer closeUsage()
	}
	if result.coderConnectConn != nil {
		// Coder Connect path: bidirectional copy between stdio and the
		// already-dialed TCP connection.
		defer cancel()
		if args.networkInfoDir != "" {
			if err := writeCoderConnectNetInfo(ctx, args.networkInfoDir); err != nil {
				logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err))
			}
		}
		agentssh.Bicopy(ctx, result.coderConnectConn, &StdioRwc{
			Reader: stdioReader,
			Writer: stdioWriter,
		})
		return nil
	}
	// Tailnet path: copy raw SSH bytes, watching the workspace so the
	// stack is torn down when it stops.
	copier := newRawSSHCopier(logger, result.rawSSH, stdioReader, stdioWriter)
	if err := stack.push("rawSSHCopier", copier); err != nil {
		return err
	}
	var errCh <-chan error
	if args.networkInfoDir != "" {
		var err error
		errCh, err = setStatsCallback(ctx, result.agentConn, logger, args.networkInfoDir, args.networkInfoInterval)
		if err != nil {
			return err
		}
	}
	wg.Add(1)
	go func() {
		defer wg.Done()
		watchAndClose(ctx, func() error {
			stack.close(xerrors.New("watchAndClose"))
			return nil
		}, logger, client, result.workspace, errCh)
	}()
	copier.copy(wg)
	return nil
}
// verifyContainer checks that the named container exists on the
// agent. Returns (true, nil) if found or name is empty, (false,
// nil) if not found (prints diagnostic), or (false, err) on API
// failure.
func verifyContainer(
	ctx context.Context,
	client *codersdk.Client,
	agentID uuid.UUID,
	containerName string,
	stderr io.Writer,
) (bool, error) {
	// No container requested: nothing to verify.
	if containerName == "" {
		return true, nil
	}
	listing, err := client.WorkspaceAgentListContainers(ctx, agentID, nil)
	if err != nil {
		return false, xerrors.Errorf("list containers: %w", err)
	}
	if len(listing.Containers) == 0 {
		cliui.Info(stderr, "No containers found!")
		return false, nil
	}
	// Single pass: return on the first match, otherwise accumulate
	// friendly names for the not-found diagnostic.
	names := make([]string, 0, len(listing.Containers))
	for _, ct := range listing.Containers {
		if ct.Match(containerName) {
			return true, nil
		}
		names = append(names, ct.FriendlyName)
	}
	cliui.Errorf(stderr, "Container not found: %q\nAvailable containers: %v", containerName, names)
	return false, nil
}
// findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it
// corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or
// vscode-coder--myusername--myworkspace).
@@ -1583,24 +1826,6 @@ func testOrDefaultDialer(ctx context.Context) coderConnectDialer {
return dialer
}
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
dialer := testOrDefaultDialer(ctx)
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return xerrors.Errorf("dial coder connect host: %w", err)
}
if err := stack.push("tcp conn", conn); err != nil {
return err
}
agentssh.Bicopy(ctx, conn, &StdioRwc{
Reader: stdin,
Writer: stdout,
})
return nil
}
type StdioRwc struct {
io.Reader
io.Writer

View File

@@ -1,11 +1,16 @@
package cli
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"sync"
"testing"
"time"
@@ -19,6 +24,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
@@ -254,9 +260,19 @@ func TestCoderConnectStdio(t *testing.T) {
stdioDone := make(chan struct{})
go func() {
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
assert.NoError(t, err)
close(stdioDone)
defer close(stdioDone)
conn, dialErr := net.Dial("tcp", ln.Addr().String())
if !assert.NoError(t, dialErr) {
return
}
pushErr := stack.push("tcp conn", conn)
if !assert.NoError(t, pushErr) {
return
}
agentssh.Bicopy(ctx, conn, &StdioRwc{
Reader: clientOutput,
Writer: serverInput,
})
}()
conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{
@@ -448,3 +464,213 @@ func Test_getWorkspaceAgent(t *testing.T) {
assert.Contains(t, err.Error(), "available agents: [clark krypton zod]")
})
}
// TestIsRetryableError exercises the retry classifier: context
// cancellation and client-side (4xx) errors are permanent, while
// network failures and server 5xx responses are transient.
func TestIsRetryableError(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		err  error
		want bool
	}{
		{name: "Nil", err: nil, want: false},
		{name: "ContextCanceled", err: context.Canceled, want: false},
		{name: "ContextDeadlineExceeded", err: context.DeadlineExceeded, want: false},
		{name: "WrappedContextCanceled", err: xerrors.Errorf("wrapped: %w", context.Canceled), want: false},
		{
			name: "DNSError",
			err: &net.DNSError{
				Err:        "no such host",
				Name:       "example.com",
				IsNotFound: true,
			},
			want: true,
		},
		{
			name: "OpError",
			err: &net.OpError{
				Op:  "dial",
				Net: "tcp",
				Err: &os.SyscallError{},
			},
			want: true,
		},
		{
			name: "WrappedDNSError",
			err: xerrors.Errorf("connect failed: %w", &net.DNSError{
				Err:  "no such host",
				Name: "example.com",
			}),
			want: true,
		},
		{name: "SDKError500", err: codersdk.NewTestError(http.StatusInternalServerError, "GET", "/api"), want: true},
		{name: "SDKError502", err: codersdk.NewTestError(http.StatusBadGateway, "GET", "/api"), want: true},
		{name: "SDKError503", err: codersdk.NewTestError(http.StatusServiceUnavailable, "GET", "/api"), want: true},
		{name: "SDKError401", err: codersdk.NewTestError(http.StatusUnauthorized, "GET", "/api"), want: false},
		{name: "SDKError403", err: codersdk.NewTestError(http.StatusForbidden, "GET", "/api"), want: false},
		{name: "SDKError404", err: codersdk.NewTestError(http.StatusNotFound, "GET", "/api"), want: false},
		{name: "GenericError", err: xerrors.New("something went wrong"), want: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, tc.want, isRetryableError(tc.err))
		})
	}
}
// TestVerifyContainer covers verifyContainer against a stubbed
// container-listing endpoint: the empty-name short-circuit, matches by
// friendly name and by ID, the not-found and no-containers diagnostics,
// and API failure propagation.
func TestVerifyContainer(t *testing.T) {
	t.Parallel()
	agentID := uuid.New()
	// newContainerClient returns a codersdk.Client backed by a test
	// server that always responds with the given status and body.
	newContainerClient := func(t *testing.T, statusCode int, resp codersdk.WorkspaceAgentListContainersResponse) *codersdk.Client {
		t.Helper()
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(statusCode)
			_ = json.NewEncoder(w).Encode(resp)
		}))
		t.Cleanup(srv.Close)
		serverURL, err := url.Parse(srv.URL)
		require.NoError(t, err)
		return codersdk.New(serverURL)
	}
	t.Run("EmptyName", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		// Empty name short-circuits before making any HTTP calls.
		serverURL, err := url.Parse(fakeServerURL)
		require.NoError(t, err)
		client := codersdk.New(serverURL)
		var stderr bytes.Buffer
		found, err := verifyContainer(ctx, client, agentID, "", &stderr)
		require.NoError(t, err)
		assert.True(t, found)
	})
	t.Run("FoundByFriendlyName", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		client := newContainerClient(t, http.StatusOK, codersdk.WorkspaceAgentListContainersResponse{
			Containers: []codersdk.WorkspaceAgentContainer{
				{ID: "abc123", FriendlyName: "my-container"},
			},
		})
		var stderr bytes.Buffer
		found, err := verifyContainer(ctx, client, agentID, "my-container", &stderr)
		require.NoError(t, err)
		assert.True(t, found)
	})
	t.Run("FoundByID", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		client := newContainerClient(t, http.StatusOK, codersdk.WorkspaceAgentListContainersResponse{
			Containers: []codersdk.WorkspaceAgentContainer{
				{ID: "abc123", FriendlyName: "my-container"},
			},
		})
		var stderr bytes.Buffer
		// Matching by container ID rather than friendly name.
		found, err := verifyContainer(ctx, client, agentID, "abc123", &stderr)
		require.NoError(t, err)
		assert.True(t, found)
	})
	t.Run("NotFound", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		client := newContainerClient(t, http.StatusOK, codersdk.WorkspaceAgentListContainersResponse{
			Containers: []codersdk.WorkspaceAgentContainer{
				{ID: "abc123", FriendlyName: "other-container"},
			},
		})
		var stderr bytes.Buffer
		found, err := verifyContainer(ctx, client, agentID, "missing", &stderr)
		require.NoError(t, err)
		assert.False(t, found)
		// The diagnostic lists the available container names.
		assert.Contains(t, stderr.String(), "Container not found")
		assert.Contains(t, stderr.String(), "other-container")
	})
	t.Run("NoContainers", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		client := newContainerClient(t, http.StatusOK, codersdk.WorkspaceAgentListContainersResponse{
			Containers: []codersdk.WorkspaceAgentContainer{},
		})
		var stderr bytes.Buffer
		found, err := verifyContainer(ctx, client, agentID, "anything", &stderr)
		require.NoError(t, err)
		assert.False(t, found)
		assert.Contains(t, stderr.String(), "No containers found")
	})
	t.Run("APIError", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		client := newContainerClient(t, http.StatusInternalServerError, codersdk.WorkspaceAgentListContainersResponse{})
		var stderr bytes.Buffer
		// A 5xx from the listing endpoint surfaces as a wrapped error.
		_, err := verifyContainer(ctx, client, agentID, "my-container", &stderr)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "list containers")
	})
}