Compare commits
1 Commits
pubsub-buf
...
fix/ssh-st
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d7e08678eb |
585
cli/ssh.go
585
cli/ssh.go
@@ -60,8 +60,30 @@ var (
|
||||
// gracefulShutdownTimeout is the timeout, per item in the stack of things to close
|
||||
gracefulShutdownTimeout = 2 * time.Second
|
||||
workspaceNameRe = regexp.MustCompile(`[/.]+|--`)
|
||||
// stdioRetryDelays controls retry timing for transient
|
||||
// ProxyCommand failures during tunnel establishment.
|
||||
stdioRetryDelays = []time.Duration{0, 2 * time.Second, 5 * time.Second, 10 * time.Second, 10 * time.Second}
|
||||
)
|
||||
|
||||
// isRetryableError reports whether err is a transient connection
|
||||
// error worth retrying (DNS failures, connection refused, server 5xx).
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
if codersdk.IsConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
return sdkErr.StatusCode() >= 500
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *RootCmd) ssh() *serpent.Command {
|
||||
var (
|
||||
stdio bool
|
||||
@@ -272,121 +294,39 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
parsedEnv = append(parsedEnv, [2]string{k, v})
|
||||
}
|
||||
|
||||
cliConfig := codersdk.SSHConfigResponse{
|
||||
HostnamePrefix: hostPrefix,
|
||||
HostnameSuffix: hostnameSuffix,
|
||||
if stdio {
|
||||
return r.sshHandleStdio(ctx, cancel, inv, client, wsClient, stdioReader, stdioWriter, logger, &wg, sshStdioArgs{
|
||||
appearanceConfig: appearanceConfig,
|
||||
forceNewTunnel: forceNewTunnel,
|
||||
usageApp: usageApp,
|
||||
disableAutostart: disableAutostart,
|
||||
networkInfoDir: networkInfoDir,
|
||||
networkInfoInterval: networkInfoInterval,
|
||||
containerName: containerName,
|
||||
containerUser: containerUser,
|
||||
cliConfig: codersdk.SSHConfigResponse{
|
||||
HostnamePrefix: hostPrefix,
|
||||
HostnameSuffix: hostnameSuffix,
|
||||
},
|
||||
waitEnum: waitEnum,
|
||||
noWait: noWait,
|
||||
})
|
||||
}
|
||||
|
||||
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
|
||||
ctx, inv, client,
|
||||
inv.Args[0], cliConfig, disableAutostart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Select the startup script behavior based on template configuration or flags.
|
||||
var wait bool
|
||||
switch waitEnum {
|
||||
case "yes":
|
||||
wait = true
|
||||
case "no":
|
||||
wait = false
|
||||
case "auto":
|
||||
for _, script := range workspaceAgent.Scripts {
|
||||
if script.StartBlocksLogin {
|
||||
wait = true
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
return xerrors.Errorf("unknown wait value %q", waitEnum)
|
||||
}
|
||||
// The `--no-wait` flag is deprecated, but for now, check it.
|
||||
// Non-stdio (interactive) path, single attempt, no retry.
|
||||
resolvedWait := waitEnum
|
||||
if noWait {
|
||||
wait = false
|
||||
resolvedWait = "no"
|
||||
}
|
||||
|
||||
templateVersion, err := client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
|
||||
resolved, err := sshResolveAgent(ctx, inv, client,
|
||||
codersdk.SSHConfigResponse{HostnamePrefix: hostPrefix, HostnameSuffix: hostnameSuffix},
|
||||
appearanceConfig, disableAutostart, resolvedWait,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var unsupportedWorkspace bool
|
||||
for _, warning := range templateVersion.Warnings {
|
||||
if warning == codersdk.TemplateVersionWarningUnsupportedWorkspaces {
|
||||
unsupportedWorkspace = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if unsupportedWorkspace && isTTYErr(inv) {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "👋 Your workspace uses legacy parameters which are not supported anymore. Contact your administrator for assistance.")
|
||||
}
|
||||
|
||||
updateWorkspaceBanner, outdated := verifyWorkspaceOutdated(client, workspace)
|
||||
if outdated && isTTYErr(inv) {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, updateWorkspaceBanner)
|
||||
}
|
||||
|
||||
// OpenSSH passes stderr directly to the calling TTY.
|
||||
// This is required in "stdio" mode so a connecting indicator can be displayed.
|
||||
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
|
||||
FetchInterval: 0,
|
||||
Fetch: client.WorkspaceAgent,
|
||||
FetchLogs: client.WorkspaceAgentLogsAfter,
|
||||
Wait: wait,
|
||||
DocsURL: appearanceConfig.DocsURL,
|
||||
})
|
||||
if err != nil {
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
return cliui.ErrCanceled
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// If we're in stdio mode, check to see if we can use Coder Connect.
|
||||
// We don't support Coder Connect over non-stdio coder ssh yet.
|
||||
if stdio && !forceNewTunnel {
|
||||
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get agent connection info: %w", err)
|
||||
}
|
||||
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
|
||||
workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix)
|
||||
// Use trailing dot to indicate FQDN and prevent DNS
|
||||
// search domain expansion, which can add 20-30s of
|
||||
// delay on corporate networks with search domains
|
||||
// configured.
|
||||
exists, ccErr := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost+".")
|
||||
if ccErr != nil {
|
||||
logger.Debug(ctx, "failed to check coder connect",
|
||||
slog.F("hostname", coderConnectHost),
|
||||
slog.Error(ccErr),
|
||||
)
|
||||
}
|
||||
if exists {
|
||||
defer cancel()
|
||||
|
||||
if networkInfoDir != "" {
|
||||
if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil {
|
||||
logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
|
||||
defer stopPolling()
|
||||
|
||||
usageAppName := getUsageAppName(usageApp)
|
||||
if usageAppName != "" {
|
||||
closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{
|
||||
AgentID: workspaceAgent.ID,
|
||||
AppName: usageAppName,
|
||||
})
|
||||
defer closeUsage()
|
||||
}
|
||||
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
|
||||
}
|
||||
}
|
||||
workspace := resolved.workspace
|
||||
workspaceAgent := resolved.workspaceAgent
|
||||
|
||||
if r.disableDirect {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
|
||||
@@ -405,30 +345,12 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
}
|
||||
conn.AwaitReachable(ctx)
|
||||
|
||||
if containerName != "" {
|
||||
cts, err := client.WorkspaceAgentListContainers(ctx, workspaceAgent.ID, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list containers: %w", err)
|
||||
}
|
||||
if len(cts.Containers) == 0 {
|
||||
cliui.Info(inv.Stderr, "No containers found!")
|
||||
return nil
|
||||
}
|
||||
var found bool
|
||||
for _, c := range cts.Containers {
|
||||
if c.FriendlyName == containerName || c.ID == containerName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
availableContainers := make([]string, len(cts.Containers))
|
||||
for i, c := range cts.Containers {
|
||||
availableContainers[i] = c.FriendlyName
|
||||
}
|
||||
cliui.Errorf(inv.Stderr, "Container not found: %q\nAvailable containers: %v", containerName, availableContainers)
|
||||
return nil
|
||||
}
|
||||
found, err := verifyContainer(ctx, client, workspaceAgent.ID, containerName, inv.Stderr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
|
||||
@@ -443,36 +365,6 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
defer closeUsage()
|
||||
}
|
||||
|
||||
if stdio {
|
||||
rawSSH, err := conn.SSH(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect SSH: %w", err)
|
||||
}
|
||||
copier := newRawSSHCopier(logger, rawSSH, stdioReader, stdioWriter)
|
||||
if err = stack.push("rawSSHCopier", copier); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errCh <-chan error
|
||||
if networkInfoDir != "" {
|
||||
errCh, err = setStatsCallback(ctx, conn, logger, networkInfoDir, networkInfoInterval)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
watchAndClose(ctx, func() error {
|
||||
stack.close(xerrors.New("watchAndClose"))
|
||||
return nil
|
||||
}, logger, client, workspace, errCh)
|
||||
}()
|
||||
copier.copy(&wg)
|
||||
return nil
|
||||
}
|
||||
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("ssh client: %w", err)
|
||||
@@ -790,6 +682,357 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
return cmd
|
||||
}
|
||||
|
||||
// sshResolvedAgent is the result of resolving and waiting for an
|
||||
// agent. Shared by both the stdio and interactive paths.
|
||||
type sshResolvedAgent struct {
|
||||
workspace codersdk.Workspace
|
||||
workspaceAgent codersdk.WorkspaceAgent
|
||||
}
|
||||
|
||||
// sshResolveAgent resolves the workspace/agent, displays banners,
|
||||
// and waits for the agent to be ready.
|
||||
func sshResolveAgent(
|
||||
ctx context.Context,
|
||||
inv *serpent.Invocation,
|
||||
client *codersdk.Client,
|
||||
cliConfig codersdk.SSHConfigResponse,
|
||||
appearanceConfig codersdk.AppearanceConfig,
|
||||
disableAutostart bool,
|
||||
waitEnum string,
|
||||
) (*sshResolvedAgent, error) {
|
||||
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
|
||||
ctx, inv, client,
|
||||
inv.Args[0], cliConfig, disableAutostart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var wait bool
|
||||
switch waitEnum {
|
||||
case "yes":
|
||||
wait = true
|
||||
case "no":
|
||||
wait = false
|
||||
case "auto":
|
||||
for _, script := range workspaceAgent.Scripts {
|
||||
if script.StartBlocksLogin {
|
||||
wait = true
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, xerrors.Errorf("unknown wait value %q", waitEnum)
|
||||
}
|
||||
|
||||
templateVersion, err := client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var unsupportedWorkspace bool
|
||||
for _, warning := range templateVersion.Warnings {
|
||||
if warning == codersdk.TemplateVersionWarningUnsupportedWorkspaces {
|
||||
unsupportedWorkspace = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if unsupportedWorkspace && isTTYErr(inv) {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "👋 Your workspace uses legacy parameters which are not supported anymore. Contact your administrator for assistance.")
|
||||
}
|
||||
|
||||
updateWorkspaceBanner, outdated := verifyWorkspaceOutdated(client, workspace)
|
||||
if outdated && isTTYErr(inv) {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, updateWorkspaceBanner)
|
||||
}
|
||||
|
||||
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
|
||||
FetchInterval: 0,
|
||||
Fetch: client.WorkspaceAgent,
|
||||
FetchLogs: client.WorkspaceAgentLogsAfter,
|
||||
Wait: wait,
|
||||
DocsURL: appearanceConfig.DocsURL,
|
||||
})
|
||||
if err != nil {
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
return nil, cliui.ErrCanceled
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sshResolvedAgent{
|
||||
workspace: workspace,
|
||||
workspaceAgent: workspaceAgent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sshStdioArgs holds CLI flag values for the stdio path.
|
||||
type sshStdioArgs struct {
|
||||
appearanceConfig codersdk.AppearanceConfig
|
||||
forceNewTunnel bool
|
||||
usageApp string
|
||||
disableAutostart bool
|
||||
networkInfoDir string
|
||||
networkInfoInterval time.Duration
|
||||
containerName string
|
||||
containerUser string
|
||||
cliConfig codersdk.SSHConfigResponse
|
||||
waitEnum string
|
||||
noWait bool
|
||||
}
|
||||
|
||||
// stdioConnResult holds the established connection for stdio proxying.
|
||||
// Either coderConnectConn or agentConn/rawSSH is set, not both.
|
||||
type stdioConnResult struct {
|
||||
workspace codersdk.Workspace
|
||||
workspaceAgent codersdk.WorkspaceAgent
|
||||
coderConnectConn net.Conn // Coder Connect path
|
||||
agentConn workspacesdk.AgentConn // tailnet path
|
||||
rawSSH *gonet.TCPConn // tailnet path
|
||||
}
|
||||
|
||||
// sshEstablishStdioConn resolves the workspace, waits for the agent,
|
||||
// and dials the tunnel. This is the only phase retried on transient
|
||||
// errors.
|
||||
func (r *RootCmd) sshEstablishStdioConn(
|
||||
ctx context.Context,
|
||||
inv *serpent.Invocation,
|
||||
client *codersdk.Client,
|
||||
wsClient *workspacesdk.Client,
|
||||
logger slog.Logger,
|
||||
stack *closerStack,
|
||||
args sshStdioArgs,
|
||||
) (*stdioConnResult, error) {
|
||||
resolvedWait := args.waitEnum
|
||||
if args.noWait {
|
||||
resolvedWait = "no"
|
||||
}
|
||||
resolved, err := sshResolveAgent(ctx, inv, client,
|
||||
args.cliConfig, args.appearanceConfig,
|
||||
args.disableAutostart, resolvedWait,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workspace := resolved.workspace
|
||||
workspaceAgent := resolved.workspaceAgent
|
||||
|
||||
// Check if we can use Coder Connect.
|
||||
if !args.forceNewTunnel {
|
||||
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get agent connection info: %w", err)
|
||||
}
|
||||
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
|
||||
workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix)
|
||||
// Use trailing dot to indicate FQDN and prevent DNS
|
||||
// search domain expansion, which can add 20-30s of
|
||||
// delay on corporate networks with search domains
|
||||
// configured.
|
||||
exists, ccErr := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost+".")
|
||||
if ccErr != nil {
|
||||
logger.Debug(ctx, "failed to check coder connect",
|
||||
slog.F("hostname", coderConnectHost),
|
||||
slog.Error(ccErr),
|
||||
)
|
||||
}
|
||||
if exists {
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
tcpConn, err := dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:22", coderConnectHost))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial coder connect host: %w", err)
|
||||
}
|
||||
if err := stack.push("tcp conn", tcpConn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &stdioConnResult{
|
||||
workspace: workspace,
|
||||
workspaceAgent: workspaceAgent,
|
||||
coderConnectConn: tcpConn,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Regular tailnet path.
|
||||
if r.disableDirect {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
|
||||
}
|
||||
conn, err := wsClient.
|
||||
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
|
||||
Logger: logger,
|
||||
BlockEndpoints: r.disableDirect,
|
||||
EnableTelemetry: !r.disableNetworkTelemetry,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial agent: %w", err)
|
||||
}
|
||||
if err = stack.push("agent conn", conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.AwaitReachable(ctx)
|
||||
|
||||
rawSSH, err := conn.SSH(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("connect SSH: %w", err)
|
||||
}
|
||||
|
||||
return &stdioConnResult{
|
||||
workspace: workspace,
|
||||
workspaceAgent: workspaceAgent,
|
||||
agentConn: conn,
|
||||
rawSSH: rawSSH,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sshHandleStdio retries tunnel establishment, then proxies data
|
||||
// once. After the tunnel is up, errors are final because the remote
|
||||
// SSH server has session state tied to the connection.
|
||||
func (r *RootCmd) sshHandleStdio(
|
||||
ctx context.Context,
|
||||
cancel context.CancelFunc,
|
||||
inv *serpent.Invocation,
|
||||
client *codersdk.Client,
|
||||
wsClient *workspacesdk.Client,
|
||||
stdioReader io.Reader,
|
||||
stdioWriter io.Writer,
|
||||
logger slog.Logger,
|
||||
wg *sync.WaitGroup,
|
||||
args sshStdioArgs,
|
||||
) error {
|
||||
// Fresh stack per attempt so failed connections get cleaned up.
|
||||
var stack *closerStack
|
||||
defer func() {
|
||||
if stack != nil {
|
||||
stack.close(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
var (
|
||||
result *stdioConnResult
|
||||
lastErr error
|
||||
)
|
||||
for attempt, delay := range stdioRetryDelays {
|
||||
if attempt > 0 {
|
||||
stack.close(nil)
|
||||
logger.Warn(ctx, "ssh tunnel establishment failed, retrying",
|
||||
slog.Error(lastErr),
|
||||
slog.F("attempt", attempt+1),
|
||||
slog.F("delay", delay),
|
||||
)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return lastErr
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
stack = newCloserStack(ctx, logger, quartz.NewReal())
|
||||
result, lastErr = r.sshEstablishStdioConn(ctx, inv, client, wsClient, logger, stack, args)
|
||||
if lastErr == nil || !isRetryableError(lastErr) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Tunnel established. Do not retry past this point.
|
||||
|
||||
found, err := verifyContainer(ctx, client, result.workspaceAgent.ID, args.containerName, inv.Stderr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
stopPolling := tryPollWorkspaceAutostop(ctx, client, result.workspace)
|
||||
defer stopPolling()
|
||||
|
||||
usageAppName := getUsageAppName(args.usageApp)
|
||||
if usageAppName != "" {
|
||||
closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, result.workspace.ID, codersdk.PostWorkspaceUsageRequest{
|
||||
AgentID: result.workspaceAgent.ID,
|
||||
AppName: usageAppName,
|
||||
})
|
||||
defer closeUsage()
|
||||
}
|
||||
|
||||
if result.coderConnectConn != nil {
|
||||
defer cancel()
|
||||
|
||||
if args.networkInfoDir != "" {
|
||||
if err := writeCoderConnectNetInfo(ctx, args.networkInfoDir); err != nil {
|
||||
logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
agentssh.Bicopy(ctx, result.coderConnectConn, &StdioRwc{
|
||||
Reader: stdioReader,
|
||||
Writer: stdioWriter,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
copier := newRawSSHCopier(logger, result.rawSSH, stdioReader, stdioWriter)
|
||||
if err := stack.push("rawSSHCopier", copier); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errCh <-chan error
|
||||
if args.networkInfoDir != "" {
|
||||
var err error
|
||||
errCh, err = setStatsCallback(ctx, result.agentConn, logger, args.networkInfoDir, args.networkInfoInterval)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
watchAndClose(ctx, func() error {
|
||||
stack.close(xerrors.New("watchAndClose"))
|
||||
return nil
|
||||
}, logger, client, result.workspace, errCh)
|
||||
}()
|
||||
copier.copy(wg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyContainer checks that the named container exists on the
|
||||
// agent. Returns (true, nil) if found or name is empty, (false,
|
||||
// nil) if not found (prints diagnostic), or (false, err) on API
|
||||
// failure.
|
||||
func verifyContainer(
|
||||
ctx context.Context,
|
||||
client *codersdk.Client,
|
||||
agentID uuid.UUID,
|
||||
containerName string,
|
||||
stderr io.Writer,
|
||||
) (bool, error) {
|
||||
if containerName == "" {
|
||||
return true, nil
|
||||
}
|
||||
cts, err := client.WorkspaceAgentListContainers(ctx, agentID, nil)
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("list containers: %w", err)
|
||||
}
|
||||
if len(cts.Containers) == 0 {
|
||||
cliui.Info(stderr, "No containers found!")
|
||||
return false, nil
|
||||
}
|
||||
for _, c := range cts.Containers {
|
||||
if c.Match(containerName) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
availableContainers := make([]string, len(cts.Containers))
|
||||
for i, c := range cts.Containers {
|
||||
availableContainers[i] = c.FriendlyName
|
||||
}
|
||||
cliui.Errorf(stderr, "Container not found: %q\nAvailable containers: %v", containerName, availableContainers)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it
|
||||
// corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or
|
||||
// vscode-coder--myusername--myworkspace).
|
||||
@@ -1583,24 +1826,6 @@ func testOrDefaultDialer(ctx context.Context) coderConnectDialer {
|
||||
return dialer
|
||||
}
|
||||
|
||||
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial coder connect host: %w", err)
|
||||
}
|
||||
if err := stack.push("tcp conn", conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
agentssh.Bicopy(ctx, conn, &StdioRwc{
|
||||
Reader: stdin,
|
||||
Writer: stdout,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type StdioRwc struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -19,6 +24,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
@@ -254,9 +260,19 @@ func TestCoderConnectStdio(t *testing.T) {
|
||||
|
||||
stdioDone := make(chan struct{})
|
||||
go func() {
|
||||
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
|
||||
assert.NoError(t, err)
|
||||
close(stdioDone)
|
||||
defer close(stdioDone)
|
||||
conn, dialErr := net.Dial("tcp", ln.Addr().String())
|
||||
if !assert.NoError(t, dialErr) {
|
||||
return
|
||||
}
|
||||
pushErr := stack.push("tcp conn", conn)
|
||||
if !assert.NoError(t, pushErr) {
|
||||
return
|
||||
}
|
||||
agentssh.Bicopy(ctx, conn, &StdioRwc{
|
||||
Reader: clientOutput,
|
||||
Writer: serverInput,
|
||||
})
|
||||
}()
|
||||
|
||||
conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{
|
||||
@@ -448,3 +464,213 @@ func Test_getWorkspaceAgent(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "available agents: [clark krypton zod]")
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRetryableError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Nil",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "ContextCanceled",
|
||||
err: context.Canceled,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "ContextDeadlineExceeded",
|
||||
err: context.DeadlineExceeded,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "WrappedContextCanceled",
|
||||
err: xerrors.Errorf("wrapped: %w", context.Canceled),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "DNSError",
|
||||
err: &net.DNSError{
|
||||
Err: "no such host",
|
||||
Name: "example.com",
|
||||
IsNotFound: true,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "OpError",
|
||||
err: &net.OpError{
|
||||
Op: "dial",
|
||||
Net: "tcp",
|
||||
Err: &os.SyscallError{},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "WrappedDNSError",
|
||||
err: xerrors.Errorf("connect failed: %w", &net.DNSError{
|
||||
Err: "no such host",
|
||||
Name: "example.com",
|
||||
}),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SDKError500",
|
||||
err: codersdk.NewTestError(http.StatusInternalServerError, "GET", "/api"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SDKError502",
|
||||
err: codersdk.NewTestError(http.StatusBadGateway, "GET", "/api"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SDKError503",
|
||||
err: codersdk.NewTestError(http.StatusServiceUnavailable, "GET", "/api"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SDKError401",
|
||||
err: codersdk.NewTestError(http.StatusUnauthorized, "GET", "/api"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "SDKError403",
|
||||
err: codersdk.NewTestError(http.StatusForbidden, "GET", "/api"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "SDKError404",
|
||||
err: codersdk.NewTestError(http.StatusNotFound, "GET", "/api"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "GenericError",
|
||||
err: xerrors.New("something went wrong"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, tt.expected, isRetryableError(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyContainer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
agentID := uuid.New()
|
||||
|
||||
newContainerClient := func(t *testing.T, statusCode int, resp codersdk.WorkspaceAgentListContainersResponse) *codersdk.Client {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
serverURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
return codersdk.New(serverURL)
|
||||
}
|
||||
|
||||
t.Run("EmptyName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Empty name short-circuits before making any HTTP calls.
|
||||
serverURL, err := url.Parse(fakeServerURL)
|
||||
require.NoError(t, err)
|
||||
client := codersdk.New(serverURL)
|
||||
|
||||
var stderr bytes.Buffer
|
||||
found, err := verifyContainer(ctx, client, agentID, "", &stderr)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found)
|
||||
})
|
||||
|
||||
t.Run("FoundByFriendlyName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
client := newContainerClient(t, http.StatusOK, codersdk.WorkspaceAgentListContainersResponse{
|
||||
Containers: []codersdk.WorkspaceAgentContainer{
|
||||
{ID: "abc123", FriendlyName: "my-container"},
|
||||
},
|
||||
})
|
||||
|
||||
var stderr bytes.Buffer
|
||||
found, err := verifyContainer(ctx, client, agentID, "my-container", &stderr)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found)
|
||||
})
|
||||
|
||||
t.Run("FoundByID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
client := newContainerClient(t, http.StatusOK, codersdk.WorkspaceAgentListContainersResponse{
|
||||
Containers: []codersdk.WorkspaceAgentContainer{
|
||||
{ID: "abc123", FriendlyName: "my-container"},
|
||||
},
|
||||
})
|
||||
|
||||
var stderr bytes.Buffer
|
||||
found, err := verifyContainer(ctx, client, agentID, "abc123", &stderr)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
client := newContainerClient(t, http.StatusOK, codersdk.WorkspaceAgentListContainersResponse{
|
||||
Containers: []codersdk.WorkspaceAgentContainer{
|
||||
{ID: "abc123", FriendlyName: "other-container"},
|
||||
},
|
||||
})
|
||||
|
||||
var stderr bytes.Buffer
|
||||
found, err := verifyContainer(ctx, client, agentID, "missing", &stderr)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, found)
|
||||
assert.Contains(t, stderr.String(), "Container not found")
|
||||
assert.Contains(t, stderr.String(), "other-container")
|
||||
})
|
||||
|
||||
t.Run("NoContainers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
client := newContainerClient(t, http.StatusOK, codersdk.WorkspaceAgentListContainersResponse{
|
||||
Containers: []codersdk.WorkspaceAgentContainer{},
|
||||
})
|
||||
|
||||
var stderr bytes.Buffer
|
||||
found, err := verifyContainer(ctx, client, agentID, "anything", &stderr)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, found)
|
||||
assert.Contains(t, stderr.String(), "No containers found")
|
||||
})
|
||||
|
||||
t.Run("APIError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
client := newContainerClient(t, http.StatusInternalServerError, codersdk.WorkspaceAgentListContainersResponse{})
|
||||
|
||||
var stderr bytes.Buffer
|
||||
_, err := verifyContainer(ctx, client, agentID, "my-container", &stderr)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "list containers")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user