Compare commits

...

18 Commits

Author SHA1 Message Date
Sas Swart
9764926f92 remove defunct test file 2025-10-30 14:26:47 +00:00
Sas Swart
10d4e42fc1 remove defunct files 2025-10-30 13:37:00 +00:00
Sas Swart
217ddf46c4 fix an incomplete refactor 2025-10-30 13:35:49 +00:00
Sas Swart
0d3d493eae fix an incomplete refactor 2025-10-30 13:28:38 +00:00
Sas Swart
89b060e245 hide functions that do not need to be public 2025-10-30 13:19:00 +00:00
Sas Swart
820d53b66a streamline agentsocket server initialization 2025-10-30 12:55:55 +00:00
Sas Swart
f550028052 Move unit statuses to the appropriate package 2025-10-30 12:23:54 +00:00
Sas Swart
e6873c8d61 rename dependency_tracker.go to manager.go 2025-10-30 12:21:07 +00:00
Sas Swart
8c0bfcb570 Improve agentsocket rpc naming and documentation 2025-10-30 12:17:27 +00:00
Sas Swart
c322b92ab0 remove agent socket auth for now 2025-10-30 12:02:48 +00:00
Sas Swart
216a5ac562 document initSocketServer and tweak its log levels 2025-10-30 11:49:54 +00:00
Sas Swart
86447126d5 make the agent socket path configurable 2025-10-30 11:45:12 +00:00
Sas Swart
55c5b707fb Rename unit.DependencyTracker to unit.Manager 2025-10-30 11:33:20 +00:00
Sas Swart
4616c82f3c switch agent socket to drpc. factor components and add tests 2025-10-30 09:01:17 +00:00
Sas Swart
9ca30e28d6 add a prototype cli command that uses the agent socket 2025-10-28 08:27:25 +00:00
Sas Swart
34c1370090 fix agent socket tests 2025-10-28 06:30:29 +00:00
Sas Swart
851c4f907c add a socket to the agent for local IPC 2025-10-28 06:26:49 +00:00
Sas Swart
e3dfe45f35 LLM generated implementation of unit status change communication 2025-10-27 11:10:22 +00:00
27 changed files with 4725 additions and 0 deletions

View File

@@ -642,6 +642,7 @@ AIBRIDGED_MOCKS := \
GEN_FILES := \
tailnet/proto/tailnet.pb.go \
agent/proto/agent.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
@@ -800,6 +801,14 @@ agent/proto/agent.pb.go: agent/proto/agent.proto
--go-drpc_opt=paths=source_relative \
./agent/proto/agent.proto
agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
--go-drpc_opt=paths=source_relative \
./agent/agentsocket/proto/agentsocket.proto
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
protoc \
--go_out=. \

View File

@@ -40,6 +40,7 @@ import (
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
@@ -91,6 +92,7 @@ type Options struct {
Devcontainers bool
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
Clock quartz.Clock
SocketPath string // Path for the agent socket server
}
type Client interface {
@@ -190,6 +192,7 @@ func New(options Options) Agent {
devcontainers: options.Devcontainers,
containerAPIOptions: options.DevcontainerAPIOptions,
socketPath: options.SocketPath,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@@ -271,6 +274,9 @@ type agent struct {
devcontainers bool
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API
socketPath string
socketServer *agentsocket.Server
}
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -350,9 +356,35 @@ func (a *agent) init() {
s.ExperimentalContainers = a.devcontainers
},
)
a.initSocketServer()
go a.runLoop()
}
// initSocketServer initializes server that allows direct communication with a workspace agent using IPC.
// The server is strictly optional: when no socket path is configured, or when
// creation/startup fails, the agent logs and continues without it rather than
// treating it as fatal.
func (a *agent) initSocketServer() {
	if a.socketPath == "" {
		a.logger.Info(a.hardCtx, "socket server disabled (no path configured)")
		return
	}
	server, err := agentsocket.NewServer(a.socketPath, a.logger.Named("socket"))
	if err != nil {
		a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err))
		return
	}
	err = server.Start()
	if err != nil {
		a.logger.Warn(a.hardCtx, "failed to start socket server", slog.Error(err))
		return
	}
	// Keep a handle so the agent's Close() can stop the server on shutdown.
	a.socketServer = server
	a.logger.Debug(a.hardCtx, "socket server started", slog.F("path", a.socketPath))
}
// runLoop attempts to start the agent in a retry loop.
// Coder may be offline temporarily, a connection issue
// may be happening, but regardless after the intermittent
@@ -1920,6 +1952,13 @@ func (a *agent) Close() error {
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
}
}
if a.socketServer != nil {
if err := a.socketServer.Stop(); err != nil {
a.logger.Error(a.hardCtx, "socket server close", slog.Error(err))
}
}
a.setLifecycle(lifecycleState)
err = a.scriptRunner.Close()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
syntax = "proto3";

option go_package = "github.com/coder/coder/v2/agent/agentsocket/proto";

package coder.agentsocket.v1;

import "google/protobuf/timestamp.proto";

// PingRequest is an empty liveness probe.
message PingRequest {}

// PingResponse echoes a message and the server's current time.
message PingResponse {
  string message = 1;
  google.protobuf.Timestamp timestamp = 2;
}

// SyncStartRequest asks to mark the named unit as started.
message SyncStartRequest {
  string unit = 1;
}

// SyncStartResponse reports the outcome in-band: success=false plus a
// human-readable message, rather than an RPC-level error.
message SyncStartResponse {
  bool success = 1;
  string message = 2;
}

// SyncWantRequest declares that `unit` depends on `depends_on`.
message SyncWantRequest {
  string unit = 1;
  string depends_on = 2;
}

message SyncWantResponse {
  bool success = 1;
  string message = 2;
}

// SyncCompleteRequest asks to mark the named unit as complete.
message SyncCompleteRequest {
  string unit = 1;
}

message SyncCompleteResponse {
  bool success = 1;
  string message = 2;
}

// SyncReadyRequest asks whether the named unit's dependencies are satisfied.
message SyncReadyRequest {
  string unit = 1;
}

message SyncReadyResponse {
  bool success = 1;
  string message = 2;
}

// SyncStatusRequest fetches a unit's status; when recursive is set the
// response also carries a DOT rendering of the dependency graph.
message SyncStatusRequest {
  string unit = 1;
  bool recursive = 2;
}

// DependencyInfo describes one dependency edge from the requested unit's
// point of view.
message DependencyInfo {
  string depends_on = 1;
  string required_status = 2;
  string current_status = 3;
  bool is_satisfied = 4;
}

message SyncStatusResponse {
  bool success = 1;
  string message = 2;
  string unit = 3;
  string status = 4;
  bool is_ready = 5;
  repeated DependencyInfo dependencies = 6;
  // DOT-format graph text; only populated for recursive requests.
  string dot = 7;
}

// AgentSocket provides direct access to the agent over local IPC.
service AgentSocket {
  // Ping the agent to check if it is alive.
  rpc Ping(PingRequest) returns (PingResponse);
  // Report the start of a unit.
  rpc SyncStart(SyncStartRequest) returns (SyncStartResponse);
  // Declare a dependency between units.
  rpc SyncWant(SyncWantRequest) returns (SyncWantResponse);
  // Report the completion of a unit.
  rpc SyncComplete(SyncCompleteRequest) returns (SyncCompleteResponse);
  // Request whether a unit is ready to be started. That is, all dependencies are satisfied.
  rpc SyncReady(SyncReadyRequest) returns (SyncReadyResponse);
  // Get the status of a unit and list its dependencies.
  rpc SyncStatus(SyncStatusRequest) returns (SyncStatusResponse);
}

View File

@@ -0,0 +1,311 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.34
// source: agent/agentsocket/proto/agentsocket.proto
package proto
import (
context "context"
errors "errors"
protojson "google.golang.org/protobuf/encoding/protojson"
proto "google.golang.org/protobuf/proto"
drpc "storj.io/drpc"
drpcerr "storj.io/drpc/drpcerr"
)
type drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto struct{}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Marshal(msg drpc.Message) ([]byte, error) {
return proto.Marshal(msg.(proto.Message))
}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message))
}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Unmarshal(buf []byte, msg drpc.Message) error {
return proto.Unmarshal(buf, msg.(proto.Message))
}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
return protojson.Marshal(msg.(proto.Message))
}
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
return protojson.Unmarshal(buf, msg.(proto.Message))
}
type DRPCAgentSocketClient interface {
DRPCConn() drpc.Conn
Ping(ctx context.Context, in *PingRequest) (*PingResponse, error)
SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error)
SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error)
SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error)
SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error)
SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error)
}
type drpcAgentSocketClient struct {
cc drpc.Conn
}
func NewDRPCAgentSocketClient(cc drpc.Conn) DRPCAgentSocketClient {
return &drpcAgentSocketClient{cc}
}
func (c *drpcAgentSocketClient) DRPCConn() drpc.Conn { return c.cc }
func (c *drpcAgentSocketClient) Ping(ctx context.Context, in *PingRequest) (*PingResponse, error) {
out := new(PingResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error) {
out := new(SyncStartResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error) {
out := new(SyncWantResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error) {
out := new(SyncCompleteResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error) {
out := new(SyncReadyResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentSocketClient) SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error) {
out := new(SyncStatusResponse)
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCAgentSocketServer interface {
Ping(context.Context, *PingRequest) (*PingResponse, error)
SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error)
SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error)
SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error)
SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error)
SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error)
}
type DRPCAgentSocketUnimplementedServer struct{}
func (s *DRPCAgentSocketUnimplementedServer) Ping(context.Context, *PingRequest) (*PingResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentSocketUnimplementedServer) SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCAgentSocketDescription struct{}
func (DRPCAgentSocketDescription) NumMethods() int { return 6 }
func (DRPCAgentSocketDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
case 0:
return "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
Ping(
ctx,
in1.(*PingRequest),
)
}, DRPCAgentSocketServer.Ping, true
case 1:
return "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncStart(
ctx,
in1.(*SyncStartRequest),
)
}, DRPCAgentSocketServer.SyncStart, true
case 2:
return "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncWant(
ctx,
in1.(*SyncWantRequest),
)
}, DRPCAgentSocketServer.SyncWant, true
case 3:
return "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncComplete(
ctx,
in1.(*SyncCompleteRequest),
)
}, DRPCAgentSocketServer.SyncComplete, true
case 4:
return "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncReady(
ctx,
in1.(*SyncReadyRequest),
)
}, DRPCAgentSocketServer.SyncReady, true
case 5:
return "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentSocketServer).
SyncStatus(
ctx,
in1.(*SyncStatusRequest),
)
}, DRPCAgentSocketServer.SyncStatus, true
default:
return "", nil, nil, nil, false
}
}
func DRPCRegisterAgentSocket(mux drpc.Mux, impl DRPCAgentSocketServer) error {
return mux.Register(impl, DRPCAgentSocketDescription{})
}
type DRPCAgentSocket_PingStream interface {
drpc.Stream
SendAndClose(*PingResponse) error
}
type drpcAgentSocket_PingStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_PingStream) SendAndClose(m *PingResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncStartStream interface {
drpc.Stream
SendAndClose(*SyncStartResponse) error
}
type drpcAgentSocket_SyncStartStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncStartStream) SendAndClose(m *SyncStartResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncWantStream interface {
drpc.Stream
SendAndClose(*SyncWantResponse) error
}
type drpcAgentSocket_SyncWantStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncWantStream) SendAndClose(m *SyncWantResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncCompleteStream interface {
drpc.Stream
SendAndClose(*SyncCompleteResponse) error
}
type drpcAgentSocket_SyncCompleteStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncCompleteStream) SendAndClose(m *SyncCompleteResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncReadyStream interface {
drpc.Stream
SendAndClose(*SyncReadyResponse) error
}
type drpcAgentSocket_SyncReadyStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncReadyStream) SendAndClose(m *SyncReadyResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgentSocket_SyncStatusStream interface {
drpc.Stream
SendAndClose(*SyncStatusResponse) error
}
type drpcAgentSocket_SyncStatusStream struct {
drpc.Stream
}
func (x *drpcAgentSocket_SyncStatusStream) SendAndClose(m *SyncStatusResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
return err
}
return x.CloseSend()
}

View File

@@ -0,0 +1,17 @@
package proto

import "github.com/coder/coder/v2/apiversion"

// Version history:
//
// API v1.0:
//   - Initial release
//   - Ping
//   - Sync operations: SyncStart, SyncWant, SyncComplete, SyncReady, SyncStatus
//
// NOTE: an earlier draft of this comment listed "SyncWait"; the RPC actually
// defined by the service is SyncReady.
const (
	CurrentMajor = 1
	CurrentMinor = 0
)

// CurrentVersion is the API version advertised by this package.
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)

185
agent/agentsocket/server.go Normal file
View File

@@ -0,0 +1,185 @@
package agentsocket
import (
"context"
"errors"
"net"
"sync"
"time"
"golang.org/x/xerrors"
"github.com/hashicorp/yamux"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentsocket/proto"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/codersdk/drpcsdk"
)
// Server provides access to the DRPCAgentSocketService via a Unix domain socket.
// Do not invoke Server{} directly. Use NewServer() instead.
type Server struct {
	logger slog.Logger
	// path is the socket's filesystem location. If empty, Start falls back
	// to getDefaultSocketPath().
	path string
	// listener is non-nil only while the server is running; it doubles as
	// the "started" flag, guarded by mu.
	listener net.Listener
	mu sync.RWMutex
	// ctx/cancel bound the lifetime of the accept loop and every accepted
	// connection; cancel is invoked by Stop.
	ctx context.Context
	cancel context.CancelFunc
	// wg tracks the accept loop and all per-connection goroutines so Stop
	// can wait for them to drain.
	wg sync.WaitGroup
	drpcServer *drpcserver.Server
	service *DRPCAgentSocketService
}
// NewServer constructs a Server that will listen on the Unix socket at path
// once Start is called. It wires a fresh DRPCAgentSocketService (backed by a
// unit.Manager) into an internal drpc mux and returns an error only if
// service registration fails.
func NewServer(path string, logger slog.Logger) (*Server, error) {
	logger = logger.Named("agentsocket")

	srv := &Server{
		logger: logger,
		path:   path,
		service: &DRPCAgentSocketService{
			logger:      logger,
			unitManager: unit.NewManager[string, string](),
		},
	}

	mux := drpcmux.New()
	if err := proto.DRPCRegisterAgentSocket(mux, srv.service); err != nil {
		return nil, xerrors.Errorf("failed to register drpc service: %w", err)
	}

	srv.drpcServer = drpcserver.NewWithOptions(mux, drpcserver.Options{
		Manager: drpcsdk.DefaultDRPCOptions(nil),
		Log: func(err error) {
			// Context cancellation is routine during shutdown; keep it quiet.
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				return
			}
			logger.Debug(context.Background(), "drpc server error", slog.Error(err))
		},
	})
	return srv, nil
}
// ErrServerAlreadyStarted is returned by Start when the server is already running.
var ErrServerAlreadyStarted = xerrors.New("server already started")

// Start creates the Unix socket (falling back to the default path when none
// was configured) and begins accepting connections in a background goroutine.
// Calling Start on a running server returns ErrServerAlreadyStarted.
func (s *Server) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// A non-nil listener means a prior Start succeeded and Stop has not run.
	if s.listener != nil {
		return ErrServerAlreadyStarted
	}
	// This context is canceled by s.Stop() when the server is stopped.
	// canceling it will close all connections.
	s.ctx, s.cancel = context.WithCancel(context.Background())
	if s.path == "" {
		var err error
		s.path, err = getDefaultSocketPath()
		if err != nil {
			return xerrors.Errorf("get default socket path: %w", err)
		}
	}
	listener, err := createSocket(s.path)
	if err != nil {
		return xerrors.Errorf("create socket: %w", err)
	}
	s.listener = listener
	s.logger.Info(s.ctx, "agent socket server started", slog.F("path", s.path))
	// The accept loop is tracked by wg so Stop can wait for it to exit.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		s.acceptConnections()
	}()
	return nil
}
// Stop shuts the server down: it cancels the server context, closes the
// listener, waits for the accept loop and all in-flight connections to
// finish, and removes the socket file. Stop on a server that is not running
// is a no-op, and a stopped server may be started again.
func (s *Server) Stop() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.listener == nil {
		return nil
	}
	s.logger.Info(s.ctx, "stopping agent socket server")
	s.cancel()
	if err := s.listener.Close(); err != nil {
		s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err))
	}
	// Wait for all connections to finish
	// NOTE(review): wg.Wait() runs while holding s.mu. Connection handlers do
	// not take this lock today, but any future handler that does would
	// deadlock here.
	s.wg.Wait()
	if err := cleanupSocket(s.path); err != nil {
		s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err))
	}
	// Clearing listener allows a subsequent Start to succeed.
	s.listener = nil
	s.logger.Info(s.ctx, "agent socket server stopped")
	return nil
}
// acceptConnections runs until the server context is canceled, handing each
// accepted connection to its own goroutine tracked by s.wg.
func (s *Server) acceptConnections() {
	for s.ctx.Err() == nil {
		conn, err := s.listener.Accept()
		if err != nil {
			// Accept fails once Stop closes the listener; exit quietly in
			// that case and warn on any other accept error.
			if s.ctx.Err() != nil {
				return
			}
			s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err))
			continue
		}
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			s.handleConnection(conn)
		}()
	}
}
// handleConnection serves a single client connection: it wraps conn in a
// yamux session and serves drpc over it until the client disconnects or the
// server context is canceled.
func (s *Server) handleConnection(conn net.Conn) {
	defer conn.Close()
	// Bound only the session setup with a deadline. The previous code left a
	// 30s absolute deadline on the connection for its whole lifetime, which
	// forcibly terminated every session (and any in-flight RPC) 30 seconds
	// after accept.
	if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
		s.logger.Warn(s.ctx, "failed to set connection deadline", slog.Error(err))
	}
	s.logger.Debug(s.ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr()))
	config := yamux.DefaultConfig()
	config.Logger = nil
	session, err := yamux.Server(conn, config)
	if err != nil {
		s.logger.Warn(s.ctx, "failed to create yamux session", slog.Error(err))
		return
	}
	defer session.Close()
	// Session established: clear the deadline so long-lived RPC streams are
	// not cut off. yamux keepalives (on by default) detect dead peers from
	// here on.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		s.logger.Warn(s.ctx, "failed to clear connection deadline", slog.Error(err))
	}
	err = s.drpcServer.Serve(s.ctx, session)
	if err != nil {
		s.logger.Debug(s.ctx, "drpc server finished", slog.Error(err))
	}
}

View File

@@ -0,0 +1,48 @@
package agentsocket_test
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentsocket"
)
// TestServer exercises the agentsocket.Server lifecycle.
func TestServer(t *testing.T) {
	t.Parallel()
	t.Run("StartStop", func(t *testing.T) {
		t.Parallel()
		socketPath := filepath.Join(t.TempDir(), "test.sock")
		logger := slog.Make().Leveled(slog.LevelDebug)
		server, err := agentsocket.NewServer(socketPath, logger)
		require.NoError(t, err)
		require.NoError(t, server.Start())
		require.NoError(t, server.Stop())
	})
	t.Run("AlreadyStarted", func(t *testing.T) {
		t.Parallel()
		socketPath := filepath.Join(t.TempDir(), "test.sock")
		logger := slog.Make().Leveled(slog.LevelDebug)
		server, err := agentsocket.NewServer(socketPath, logger)
		require.NoError(t, err)
		require.NoError(t, server.Start())
		// Stop the server even though this test only checks double-Start:
		// previously it was left running, leaking the accept goroutine and
		// socket past the test.
		defer func() { _ = server.Stop() }()
		require.ErrorIs(t, server.Start(), agentsocket.ErrServerAlreadyStarted)
	})
	t.Run("AutoSocketPath", func(t *testing.T) {
		t.Parallel()
		// NOTE(review): this subtest passes an explicit path, so it is
		// identical to StartStop and does not actually exercise the
		// default-path fallback (NewServer with an empty path). TODO: pass
		// "" here once the default socket location is safe to use in tests.
		socketPath := filepath.Join(t.TempDir(), "test.sock")
		logger := slog.Make().Leveled(slog.LevelDebug)
		server, err := agentsocket.NewServer(socketPath, logger)
		require.NoError(t, err)
		require.NoError(t, server.Start())
		require.NoError(t, server.Stop())
	})
}

View File

@@ -0,0 +1,262 @@
package agentsocket
import (
"context"
"errors"
"sync"
"time"
"google.golang.org/protobuf/types/known/timestamppb"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentsocket/proto"
"github.com/coder/coder/v2/agent/unit"
)
// Compile-time assertion that the service satisfies the generated interface.
var _ proto.DRPCAgentSocketServer = (*DRPCAgentSocketService)(nil)

// DRPCAgentSocketService implements the AgentSocket drpc service, tracking
// unit lifecycle state and dependencies through a unit.Manager keyed by unit
// name. All failures are reported in-band via Success/Message fields.
type DRPCAgentSocketService struct {
	// NOTE(review): mu is not used by any method in this file, so compound
	// sequences like Register -> IsReady -> UpdateStatus in the handlers are
	// not atomic with respect to concurrent calls. Either take this lock in
	// the handlers or remove the field.
	mu sync.RWMutex
	unitManager *unit.Manager[string, string]
	logger slog.Logger
}
// Ping replies with "pong" and the server's current time, proving the socket
// round-trip works.
func (*DRPCAgentSocketService) Ping(_ context.Context, _ *proto.PingRequest) (*proto.PingResponse, error) {
	resp := &proto.PingResponse{Message: "pong"}
	resp.Timestamp = timestamppb.New(time.Now())
	return resp, nil
}
// SyncStart marks a unit as started, registering it on first use. Starting
// fails (Success=false) when the unit's declared dependencies are not yet
// satisfied. Errors are reported in-band via Success/Message rather than as
// drpc errors.
func (s *DRPCAgentSocketService) SyncStart(_ context.Context, req *proto.SyncStartRequest) (*proto.SyncStartResponse, error) {
	if s.unitManager == nil {
		return &proto.SyncStartResponse{
			Success: false,
			// Was "dependency tracker not available" — a leftover from before
			// the DependencyTracker->Manager rename; now consistent with the
			// other handlers.
			Message: "unit manager not available",
		}, nil
	}
	if req.Unit == "" {
		return &proto.SyncStartResponse{
			Success: false,
			Message: "Unit name is required",
		}, nil
	}
	// Serialize the register/ready-check/update sequence so two concurrent
	// SyncStart calls cannot interleave between the readiness check and the
	// status update. (s.mu was previously declared but never used.)
	s.mu.Lock()
	defer s.mu.Unlock()
	if err := s.unitManager.Register(req.Unit); err != nil {
		// If already registered, that's okay - we can still update status
		if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
			return &proto.SyncStartResponse{
				Success: false,
				Message: "Failed to register unit: " + err.Error(),
			}, nil
		}
	}
	isReady, err := s.unitManager.IsReady(req.Unit)
	if err != nil {
		return &proto.SyncStartResponse{
			Success: false,
			Message: "Failed to check readiness: " + err.Error(),
		}, nil
	}
	if !isReady {
		return &proto.SyncStartResponse{
			Success: false,
			Message: "Unit is not ready",
		}, nil
	}
	if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusStarted); err != nil {
		return &proto.SyncStartResponse{
			Success: false,
			Message: "Failed to update status: " + err.Error(),
		}, nil
	}
	return &proto.SyncStartResponse{
		Success: true,
		Message: "Unit " + req.Unit + " started successfully",
	}, nil
}
// SyncWant records that req.Unit depends on req.DependsOn reaching the
// completed status, registering both units on first use.
func (s *DRPCAgentSocketService) SyncWant(_ context.Context, req *proto.SyncWantRequest) (*proto.SyncWantResponse, error) {
	// fail wraps an in-band failure; callers never see a drpc-level error.
	fail := func(msg string) (*proto.SyncWantResponse, error) {
		return &proto.SyncWantResponse{Success: false, Message: msg}, nil
	}
	if s.unitManager == nil {
		return fail("unit manager not available")
	}
	if req.Unit == "" || req.DependsOn == "" {
		return fail("unit and depends_on are required")
	}
	// Registering an already-known unit is fine; only other errors abort.
	if err := s.unitManager.Register(req.Unit); err != nil && !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
		return fail("failed to register unit: " + err.Error())
	}
	if err := s.unitManager.Register(req.DependsOn); err != nil && !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
		return fail("failed to register dependency unit: " + err.Error())
	}
	if err := s.unitManager.AddDependency(req.Unit, req.DependsOn, unit.StatusComplete); err != nil {
		return fail("failed to add dependency: " + err.Error())
	}
	return &proto.SyncWantResponse{
		Success: true,
		Message: "Unit " + req.Unit + " now depends on " + req.DependsOn,
	}, nil
}
// SyncComplete marks a unit's status as complete, which can unblock units
// that declared a dependency on it via SyncWant.
func (s *DRPCAgentSocketService) SyncComplete(_ context.Context, req *proto.SyncCompleteRequest) (*proto.SyncCompleteResponse, error) {
	switch {
	case s.unitManager == nil:
		return &proto.SyncCompleteResponse{
			Success: false,
			Message: "unit manager not available",
		}, nil
	case req.Unit == "":
		return &proto.SyncCompleteResponse{
			Success: false,
			Message: "unit name is required",
		}, nil
	}
	if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusComplete); err != nil {
		return &proto.SyncCompleteResponse{
			Success: false,
			Message: "failed to update status: " + err.Error(),
		}, nil
	}
	return &proto.SyncCompleteResponse{
		Success: true,
		Message: "unit " + req.Unit + " completed successfully",
	}, nil
}
// SyncReady reports whether all of a unit's declared dependencies are
// satisfied. A not-ready unit yields Success=false with the canonical
// dependencies-not-satisfied message; other failures carry their underlying
// error text.
func (s *DRPCAgentSocketService) SyncReady(_ context.Context, req *proto.SyncReadyRequest) (*proto.SyncReadyResponse, error) {
	if s.unitManager == nil {
		return &proto.SyncReadyResponse{Success: false, Message: "unit manager not available"}, nil
	}
	if req.Unit == "" {
		return &proto.SyncReadyResponse{Success: false, Message: "unit name is required"}, nil
	}
	ready, err := s.unitManager.IsReady(req.Unit)
	switch {
	case err != nil:
		return &proto.SyncReadyResponse{Success: false, Message: "failed to check readiness: " + err.Error()}, nil
	case !ready:
		return &proto.SyncReadyResponse{Success: false, Message: unit.ErrDependenciesNotSatisfied.Error()}, nil
	default:
		return &proto.SyncReadyResponse{Success: true, Message: "unit " + req.Unit + " dependencies are satisfied"}, nil
	}
}
// SyncStatus returns the current status of a unit, whether it is ready, its
// direct dependencies, and (when req.Recursive is set) a DOT rendering of
// the full dependency graph.
func (s *DRPCAgentSocketService) SyncStatus(_ context.Context, req *proto.SyncStatusRequest) (*proto.SyncStatusResponse, error) {
	if s.unitManager == nil {
		return &proto.SyncStatusResponse{
			Success: false,
			Message: "unit manager not available",
		}, nil
	}
	if req.Unit == "" {
		return &proto.SyncStatusResponse{
			Success: false,
			Message: "unit name is required",
		}, nil
	}
	status, err := s.unitManager.GetStatus(req.Unit)
	if err != nil {
		return &proto.SyncStatusResponse{
			Success: false,
			Message: "failed to get unit status: " + err.Error(),
		}, nil
	}
	isReady, err := s.unitManager.IsReady(req.Unit)
	if err != nil {
		return &proto.SyncStatusResponse{
			Success: false,
			Message: "failed to check readiness: " + err.Error(),
		}, nil
	}
	dependencies, err := s.unitManager.GetAllDependencies(req.Unit)
	if err != nil {
		return &proto.SyncStatusResponse{
			Success: false,
			Message: "failed to get dependencies: " + err.Error(),
		}, nil
	}
	// Convert the manager's dependency records into the wire format.
	var depInfos []*proto.DependencyInfo
	for _, dep := range dependencies {
		depInfos = append(depInfos, &proto.DependencyInfo{
			DependsOn:      dep.DependsOn,
			RequiredStatus: dep.RequiredStatus,
			CurrentStatus:  dep.CurrentStatus,
			IsSatisfied:    dep.IsSatisfied,
		})
	}
	// Recursive requests additionally include the whole graph as DOT text.
	var dotStr string
	if req.Recursive {
		dotStr, err = s.unitManager.ExportDOT("dependency_graph")
		if err != nil {
			return &proto.SyncStatusResponse{
				Success: false,
				Message: "failed to export DOT: " + err.Error(),
			}, nil
		}
	}
	return &proto.SyncStatusResponse{
		Success:      true,
		Message:      "unit status retrieved successfully",
		Unit:         req.Unit,
		Status:       status,
		IsReady:      isReady,
		Dependencies: depInfos,
		Dot:          dotStr,
	}, nil
}

View File

@@ -0,0 +1,311 @@
package agentsocket_test
import (
"context"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
func TestDRPCAgentSocketService(t *testing.T) {
t.Parallel()
t.Run("Ping", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
response, err := client.Ping(context.Background())
require.NoError(t, err)
require.Equal(t, "pong", response.Message)
})
t.Run("SyncStart", func(t *testing.T) {
t.Parallel()
t.Run("NewUnit", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
err = client.SyncStart(context.Background(), "test-unit")
require.NoError(t, err)
status, err := client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
})
t.Run("UnitAlreadyStarted", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
server, err := agentsocket.NewServer(
socketPath,
slog.Make().Leveled(slog.LevelDebug),
)
require.NoError(t, err)
err = server.Start()
require.NoError(t, err)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
Path: socketPath,
})
require.NoError(t, err)
defer client.Close()
err = client.SyncStart(context.Background(), "test-unit")
require.NoError(t, err)
// First Start
status, err := client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
status, err = client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
// Second Start
err = client.SyncStart(context.Background(), "test-unit")
require.ErrorContains(t, err, unit.ErrSameStatusAlreadySet.Error())
status, err = client.SyncStatus(context.Background(), "test-unit", false)
require.NoError(t, err)
require.Equal(t, "started", status.Status)
})
// UnitAlreadyCompleted: a unit that has completed may be started again,
// moving it back from "completed" to "started".
t.Run("UnitAlreadyCompleted", func(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "test.sock")
	server, err := agentsocket.NewServer(
		socketPath,
		slog.Make().Leveled(slog.LevelDebug),
	)
	require.NoError(t, err)
	err = server.Start()
	require.NoError(t, err)
	defer server.Stop()
	client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
		Path: socketPath,
	})
	require.NoError(t, err)
	defer client.Close()
	// First start
	err = client.SyncStart(context.Background(), "test-unit")
	require.NoError(t, err)
	status, err := client.SyncStatus(context.Background(), "test-unit", false)
	require.NoError(t, err)
	require.Equal(t, "started", status.Status)
	// Complete the unit
	err = client.SyncComplete(context.Background(), "test-unit")
	require.NoError(t, err)
	status, err = client.SyncStatus(context.Background(), "test-unit", false)
	require.NoError(t, err)
	require.Equal(t, "completed", status.Status)
	// Second start: allowed because the current status differs ("completed").
	err = client.SyncStart(context.Background(), "test-unit")
	require.NoError(t, err)
	status, err = client.SyncStatus(context.Background(), "test-unit", false)
	require.NoError(t, err)
	require.Equal(t, "started", status.Status)
})
// UnitNotReady: a unit with an unmet dependency must be refused a start,
// and the refused unit must remain without a status.
t.Run("UnitNotReady", func(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "test.sock")
	server, err := agentsocket.NewServer(
		socketPath,
		slog.Make().Leveled(slog.LevelDebug),
	)
	require.NoError(t, err)
	err = server.Start()
	require.NoError(t, err)
	defer server.Stop()
	client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
		Path: socketPath,
	})
	require.NoError(t, err)
	defer client.Close()
	// BUG FIX: the SyncWant error was previously discarded, so the
	// require.NoError below re-checked the stale err from NewSocketClient.
	err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
	require.NoError(t, err)
	// The dependency has never completed, so starting must be rejected.
	err = client.SyncStart(context.Background(), "test-unit")
	require.ErrorContains(t, err, "Unit is not ready")
	// The rejected start must leave the unit with no status at all.
	status, err := client.SyncStatus(context.Background(), "test-unit", false)
	require.NoError(t, err)
	require.Equal(t, "", status.Status)
})
})
t.Run("SyncWant", func(t *testing.T) {
t.Parallel()
// NewUnits: SyncWant on names that were never registered auto-registers
// both units and records the dependency with a required status of
// "completed".
t.Run("NewUnits", func(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "test.sock")
	server, err := agentsocket.NewServer(
		socketPath,
		slog.Make().Leveled(slog.LevelDebug),
	)
	require.NoError(t, err)
	err = server.Start()
	require.NoError(t, err)
	defer server.Stop()
	client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
		Path: socketPath,
	})
	require.NoError(t, err)
	defer client.Close()
	// If units are not registered, they are registered automatically
	err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
	require.NoError(t, err)
	// The recorded dependency is visible through the unit's status.
	status, err := client.SyncStatus(context.Background(), "test-unit", false)
	require.NoError(t, err)
	require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
	require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
})
// DependencyAlreadyRegistered: a dependency may be declared after the
// dependency unit has already started; the relationship is still recorded.
t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "test.sock")
	server, err := agentsocket.NewServer(
		socketPath,
		slog.Make().Leveled(slog.LevelDebug),
	)
	require.NoError(t, err)
	err = server.Start()
	require.NoError(t, err)
	defer server.Stop()
	client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
		Path: socketPath,
	})
	require.NoError(t, err)
	defer client.Close()
	// Start the dependency unit
	err = client.SyncStart(context.Background(), "dependency-unit")
	require.NoError(t, err)
	status, err := client.SyncStatus(context.Background(), "dependency-unit", false)
	require.NoError(t, err)
	require.Equal(t, "started", status.Status)
	// Add the dependency after the dependency unit has already started
	err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
	// Dependencies can be added even if the dependency unit has already started
	require.NoError(t, err)
	// The dependency is now reflected in the test unit's status
	status, err = client.SyncStatus(context.Background(), "test-unit", false)
	require.NoError(t, err)
	require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
	require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
})
// DependencyAddedAfterDependentStarted: a dependency may also be declared
// after the DEPENDENT unit has started; it only gates future starts and
// does not retroactively change the current status.
t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "test.sock")
	server, err := agentsocket.NewServer(
		socketPath,
		slog.Make().Leveled(slog.LevelDebug),
	)
	require.NoError(t, err)
	err = server.Start()
	require.NoError(t, err)
	defer server.Stop()
	client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
		Path: socketPath,
	})
	require.NoError(t, err)
	defer client.Close()
	// Start the dependent unit
	err = client.SyncStart(context.Background(), "test-unit")
	require.NoError(t, err)
	status, err := client.SyncStatus(context.Background(), "test-unit", false)
	require.NoError(t, err)
	require.Equal(t, "started", status.Status)
	// Add the dependency after the dependency unit has already started
	err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
	// Dependencies can be added even if the dependent unit has already started.
	// The dependency applies the next time a unit is started. The current status is not updated.
	// This is to allow flexible dependency management. It does mean that users of this API should
	// take care to add dependencies before they start their dependent units.
	require.NoError(t, err)
	// The dependency is now reflected in the test unit's status
	status, err = client.SyncStatus(context.Background(), "test-unit", false)
	require.NoError(t, err)
	require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
	require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
})
})
}

View File

@@ -0,0 +1,76 @@
//go:build !windows
package agentsocket
import (
"fmt"
"net"
"os"
"path/filepath"
"golang.org/x/xerrors"
)
// createSocket binds a Unix domain socket listener at path.
// It refuses paths that already host a live listener, clears any stale
// socket file, ensures the parent directory exists, and restricts the
// new socket to the owning user (0600).
func createSocket(path string) (net.Listener, error) {
	if !isSocketAvailable(path) {
		return nil, xerrors.Errorf("socket path %s is not available", path)
	}
	// Clear out a stale socket file left behind by a previous run.
	if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
		return nil, xerrors.Errorf("remove existing socket: %w", err)
	}
	// Make sure the directory holding the socket exists, owner-only.
	if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
		return nil, xerrors.Errorf("create socket directory: %w", err)
	}
	l, err := net.Listen("unix", path)
	if err != nil {
		return nil, xerrors.Errorf("listen on unix socket: %w", err)
	}
	// Lock the socket down to the current user only.
	if chmodErr := os.Chmod(path, 0o600); chmodErr != nil {
		_ = l.Close()
		return nil, xerrors.Errorf("set socket permissions: %w", chmodErr)
	}
	return l, nil
}
// getDefaultSocketPath picks the default agent socket location on
// Unix-like systems: $XDG_RUNTIME_DIR/coder-agent.sock when the runtime
// directory is set, otherwise a per-user file under /tmp keyed by UID.
func getDefaultSocketPath() (string, error) {
	runtimeDir := os.Getenv("XDG_RUNTIME_DIR")
	if runtimeDir != "" {
		return filepath.Join(runtimeDir, "coder-agent.sock"), nil
	}
	// No runtime dir: fall back to a UID-scoped name in /tmp.
	name := fmt.Sprintf("coder-agent-%d.sock", os.Getuid())
	return filepath.Join("/tmp", name), nil
}
// CleanupSocket removes the socket file
func cleanupSocket(path string) error {
return os.Remove(path)
}
// isSocketAvailable reports whether path can be (re)used for a new
// socket. A path counts as available when nothing exists there, or when
// the existing file does not accept Unix-socket connections (a stale
// socket or a non-socket file rather than a live listener).
func isSocketAvailable(path string) bool {
	// Nothing at the path at all: free to use.
	if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
		return true
	}
	// Something exists; probe whether a listener is actually behind it.
	conn, dialErr := net.Dial("unix", path)
	if dialErr != nil {
		// No one answered, so whatever is there is stale and replaceable.
		return true
	}
	_ = conn.Close()
	// A live listener already owns this path.
	return false
}

View File

@@ -0,0 +1,111 @@
//go:build windows
package agentsocket
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"time"
"cdr.dev/slog"
)
// createSocket creates a Unix domain socket listener on Windows
// Falls back to named pipe if Unix sockets are not supported
func CreateSocket(path string) (net.Listener, error) {
// Try Unix domain socket first (Windows 10 build 17063+)
listener, err := net.Listen("unix", path)
if err == nil {
return listener, nil
}
// Fall back to named pipe
pipePath := `\\.\pipe\coder-agent`
listener, err = net.Listen("tcp", pipePath)
if err != nil {
return nil, err
}
return listener, nil
}
// getDefaultSocketPath returns the default socket path for Windows
func GetDefaultSocketPath() (string, error) {
// Try to use a temporary directory
tempDir := os.TempDir()
if tempDir == "" {
tempDir = "C:\\temp"
}
// Create a user-specific subdirectory
uid := os.Getuid()
userDir := filepath.Join(tempDir, "coder-agent", strconv.Itoa(uid))
if err := os.MkdirAll(userDir, 0o700); err != nil {
return "", fmt.Errorf("create user directory: %w", err)
}
return filepath.Join(userDir, "agent.sock"), nil
}
// cleanupSocket removes the socket file
func CleanupSocket(path string) error {
return os.Remove(path)
}
// isSocketAvailable checks if a socket path is available for use
func IsSocketAvailable(path string, logger slog.Logger) bool {
logger.Debug(context.Background(), "Checking socket availability on Windows", slog.F("path", path))
// Check if file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
logger.Debug(context.Background(), "Socket file does not exist, path is available", slog.F("path", path))
return true
}
logger.Debug(context.Background(), "Socket file exists, checking if it's listening", slog.F("path", path))
// Try to connect to see if it's actually listening
conn, err := net.Dial("unix", path)
if err != nil {
// If we can't connect, the socket is not in use
logger.Debug(context.Background(), "Cannot connect to socket, path is available", slog.F("path", path), slog.Error(err))
return true
}
_ = conn.Close()
logger.Debug(context.Background(), "Socket is listening, path is not available", slog.F("path", path))
return false
}
// getSocketInfo returns information about the socket file
func GetSocketInfo(path string) (*SocketInfo, error) {
stat, err := os.Stat(path)
if err != nil {
return nil, err
}
// On Windows, we'll use a simplified approach for now
// In a real implementation, you'd get the security descriptor
return &SocketInfo{
Path: path,
UID: 0, // Simplified for now
GID: 0, // Simplified for now
Mode: stat.Mode(),
ModTime: stat.ModTime(),
Owner: "unknown",
Group: "unknown",
}, nil
}
// SocketInfo contains information about a socket file
type SocketInfo struct {
Path string
UID int
GID int
Mode os.FileMode
ModTime time.Time
Owner string // Windows SID string
Group string // Windows SID string
}

307
agent/unit/manager.go Normal file
View File

@@ -0,0 +1,307 @@
package unit
import (
"sync"
"golang.org/x/xerrors"
)
// Sentinel errors returned by Manager operations. Each is a fixed value,
// so callers may compare with errors.Is or direct equality.
var (
	// ErrConsumerNotFound is returned when a consumer ID is not registered.
	ErrConsumerNotFound = xerrors.New("consumer not found")
	// ErrConsumerAlreadyRegistered is returned when a consumer ID is already registered.
	ErrConsumerAlreadyRegistered = xerrors.New("consumer already registered")
	// ErrCannotUpdateOtherConsumer is returned when attempting to update another consumer's status.
	ErrCannotUpdateOtherConsumer = xerrors.New("cannot update other consumer's status")
	// ErrDependenciesNotSatisfied is returned when a consumer's dependencies are not satisfied.
	ErrDependenciesNotSatisfied = xerrors.New("unit dependencies not satisfied")
	// ErrSameStatusAlreadySet is returned when attempting to set the same status as the current status.
	ErrSameStatusAlreadySet = xerrors.New("same status already set")
)

// Well-known unit statuses used for dependency tracking.
const (
	StatusStarted  = "started"
	StatusComplete = "completed"
)
// dependencyVertex represents a vertex in the dependency graph that is
// associated with a consumer. A single instance per consumer is stored in
// Manager.consumerVertices so graph operations always see the same pointer.
type dependencyVertex[ConsumerID comparable] struct {
	// ID is the consumer this vertex stands for.
	ID ConsumerID
}
// Dependency represents a dependency relationship between consumers:
// Consumer is not ready until DependsOn reaches RequiredStatus.
type Dependency[StatusType, ConsumerID comparable] struct {
	// Consumer is the dependent side of the edge.
	Consumer ConsumerID
	// DependsOn is the consumer that must reach RequiredStatus.
	DependsOn ConsumerID
	// RequiredStatus is the status DependsOn must reach.
	RequiredStatus StatusType
	// CurrentStatus is DependsOn's last reported status (zero value when
	// it has never reported one).
	CurrentStatus StatusType
	// IsSatisfied records whether CurrentStatus == RequiredStatus.
	IsSatisfied bool
}
// Manager provides reactive dependency tracking over a Graph.
// It manages consumer registration, dependency relationships, and status updates
// with automatic recalculation of readiness when dependencies are satisfied.
// Methods guard shared state with mu; GetGraph and ExportDOT access the
// graph without taking the lock.
type Manager[StatusType, ConsumerID comparable] struct {
	// mu guards all fields below.
	mu sync.RWMutex
	// The underlying graph that stores dependency relationships
	graph *Graph[StatusType, *dependencyVertex[ConsumerID]]
	// Track current status of each consumer
	consumerStatus map[ConsumerID]StatusType
	// Track readiness state (cached to avoid repeated graph traversal)
	consumerReadiness map[ConsumerID]bool
	// Track which consumers are registered
	registeredConsumers map[ConsumerID]bool
	// Store vertex instances for each consumer to ensure consistent references
	consumerVertices map[ConsumerID]*dependencyVertex[ConsumerID]
}
// NewManager constructs an empty Manager with no registered consumers.
func NewManager[StatusType, ConsumerID comparable]() *Manager[StatusType, ConsumerID] {
	m := &Manager[StatusType, ConsumerID]{}
	m.graph = &Graph[StatusType, *dependencyVertex[ConsumerID]]{}
	m.consumerStatus = map[ConsumerID]StatusType{}
	m.consumerReadiness = map[ConsumerID]bool{}
	m.registeredConsumers = map[ConsumerID]bool{}
	m.consumerVertices = map[ConsumerID]*dependencyVertex[ConsumerID]{}
	return m
}
// Register adds a new consumer as a vertex in the dependency graph.
// A freshly registered consumer has no dependencies and therefore starts
// out ready. Registering the same ID twice returns
// ErrConsumerAlreadyRegistered.
func (dt *Manager[StatusType, ConsumerID]) Register(id ConsumerID) error {
	dt.mu.Lock()
	defer dt.mu.Unlock()
	if _, dup := dt.registeredConsumers[id]; dup {
		return ErrConsumerAlreadyRegistered
	}
	// One vertex per consumer, reused for every graph operation.
	dt.consumerVertices[id] = &dependencyVertex[ConsumerID]{ID: id}
	dt.registeredConsumers[id] = true
	// No dependencies yet, so the consumer is immediately ready.
	dt.consumerReadiness[id] = true
	return nil
}
// AddDependency records that consumer depends on dependsOn reaching
// requiredStatus. Both IDs must already be registered. The consumer's
// cached readiness is recalculated immediately, so adding a dependency
// may flip an idle consumer from ready to not-ready.
func (dt *Manager[StatusType, ConsumerID]) AddDependency(consumer ConsumerID, dependsOn ConsumerID, requiredStatus StatusType) error {
	dt.mu.Lock()
	defer dt.mu.Unlock()
	if !dt.registeredConsumers[consumer] {
		return xerrors.Errorf("consumer %v is not registered", consumer)
	}
	if !dt.registeredConsumers[dependsOn] {
		// Name the dependency side explicitly so the two failure modes are
		// distinguishable in logs (previously both said "consumer").
		return xerrors.Errorf("dependency %v is not registered", dependsOn)
	}
	// Get the stored vertices for both consumers.
	consumerVertex := dt.consumerVertices[consumer]
	dependsOnVertex := dt.consumerVertices[dependsOn]
	// The edge goes from consumer to dependsOn, annotated with the status
	// the dependency must reach.
	if err := dt.graph.AddEdge(consumerVertex, dependsOnVertex, requiredStatus); err != nil {
		return xerrors.Errorf("add dependency: %w", err)
	}
	// The consumer gained a dependency; its cached readiness may be stale.
	dt.recalculateReadinessUnsafe(consumer)
	return nil
}
// UpdateStatus records a new status for consumer and refreshes the cached
// readiness of every consumer that depends on it. Re-setting the status
// the consumer already holds returns ErrSameStatusAlreadySet.
func (dt *Manager[StatusType, ConsumerID]) UpdateStatus(consumer ConsumerID, newStatus StatusType) error {
	dt.mu.Lock()
	defer dt.mu.Unlock()
	if !dt.registeredConsumers[consumer] {
		return ErrConsumerNotFound
	}
	// Reject no-op updates before touching any state.
	if dt.consumerStatus[consumer] == newStatus {
		return ErrSameStatusAlreadySet
	}
	dt.consumerStatus[consumer] = newStatus
	// Every consumer pointing at this one (reverse edges) may have become
	// ready or unready; recompute each dependent's cached readiness.
	vertex := dt.consumerVertices[consumer]
	for _, edge := range dt.graph.GetReverseAdjacentVertices(vertex) {
		dt.recalculateReadinessUnsafe(edge.From.ID)
	}
	return nil
}
// IsReady reports whether every dependency of consumer is currently
// satisfied, using the cached readiness state (no graph traversal).
func (dt *Manager[StatusType, ConsumerID]) IsReady(consumer ConsumerID) (bool, error) {
	dt.mu.RLock()
	defer dt.mu.RUnlock()
	if registered := dt.registeredConsumers[consumer]; !registered {
		return false, ErrConsumerNotFound
	}
	ready := dt.consumerReadiness[consumer]
	return ready, nil
}
// GetUnmetDependencies returns every dependency of consumer whose
// required status has not yet been reached. A dependency whose target has
// never reported a status is always unmet, and its CurrentStatus is the
// zero value of StatusType.
func (dt *Manager[StatusType, ConsumerID]) GetUnmetDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
	dt.mu.RLock()
	defer dt.mu.RUnlock()
	if !dt.registeredConsumers[consumer] {
		return nil, ErrConsumerNotFound
	}
	var unmet []Dependency[StatusType, ConsumerID]
	vertex := dt.consumerVertices[consumer]
	for _, edge := range dt.graph.GetForwardAdjacentVertices(vertex) {
		target := edge.To.ID
		required := edge.Edge
		// A missing map entry yields the zero StatusType, which is exactly
		// what we report for a never-updated dependency.
		current, known := dt.consumerStatus[target]
		if known && current == required {
			continue // satisfied — not part of the result
		}
		unmet = append(unmet, Dependency[StatusType, ConsumerID]{
			Consumer:       consumer,
			DependsOn:      target,
			RequiredStatus: required,
			CurrentStatus:  current,
			IsSatisfied:    false,
		})
	}
	return unmet, nil
}
// recalculateReadinessUnsafe refreshes the cached readiness flag for
// consumer by walking its outgoing dependency edges. A consumer with no
// edges is trivially ready. Callers must hold the write lock.
func (dt *Manager[StatusType, ConsumerID]) recalculateReadinessUnsafe(consumer ConsumerID) {
	vertex := dt.consumerVertices[consumer]
	ready := true
	for _, edge := range dt.graph.GetForwardAdjacentVertices(vertex) {
		current, known := dt.consumerStatus[edge.To.ID]
		// An unreported status or a mismatch means this dependency is unmet.
		if !known || current != edge.Edge {
			ready = false
			break
		}
	}
	dt.consumerReadiness[consumer] = ready
}
// GetGraph returns the underlying graph for visualization and debugging.
// This should be used carefully as it exposes the internal graph structure.
// NOTE(review): unlike the other accessors this does not take dt.mu, so
// use of the returned graph is unsynchronized with concurrent Manager
// mutations — confirm callers treat it as read-only debugging output.
func (dt *Manager[StatusType, ConsumerID]) GetGraph() *Graph[StatusType, *dependencyVertex[ConsumerID]] {
	return dt.graph
}
// GetStatus returns the most recently recorded status for consumer, or
// the zero StatusType when the consumer has never reported one. An
// unregistered consumer yields ErrConsumerNotFound.
func (dt *Manager[StatusType, ConsumerID]) GetStatus(consumer ConsumerID) (StatusType, error) {
	dt.mu.RLock()
	defer dt.mu.RUnlock()
	var zero StatusType
	if !dt.registeredConsumers[consumer] {
		return zero, ErrConsumerNotFound
	}
	// A missing map entry already returns the zero value, matching the
	// "no status reported yet" contract.
	return dt.consumerStatus[consumer], nil
}
// GetAllDependencies returns every dependency of consumer, satisfied and
// unsatisfied alike. Dependencies whose target has never reported a
// status carry the zero StatusType as CurrentStatus and are unsatisfied.
func (dt *Manager[StatusType, ConsumerID]) GetAllDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
	dt.mu.RLock()
	defer dt.mu.RUnlock()
	if !dt.registeredConsumers[consumer] {
		return nil, ErrConsumerNotFound
	}
	var all []Dependency[StatusType, ConsumerID]
	vertex := dt.consumerVertices[consumer]
	for _, edge := range dt.graph.GetForwardAdjacentVertices(vertex) {
		target := edge.To.ID
		required := edge.Edge
		// Missing entries read as the zero StatusType, which is what we
		// report for a never-updated dependency.
		current, known := dt.consumerStatus[target]
		all = append(all, Dependency[StatusType, ConsumerID]{
			Consumer:       consumer,
			DependsOn:      target,
			RequiredStatus: required,
			CurrentStatus:  current,
			// Only a known, matching status counts as satisfied.
			IsSatisfied: known && current == required,
		})
	}
	return all, nil
}
// ExportDOT exports the dependency graph to DOT format for visualization.
// NOTE(review): delegates straight to the graph without taking dt.mu, so
// output produced during concurrent mutation may be inconsistent — confirm
// this is only used for debugging snapshots.
func (dt *Manager[StatusType, ConsumerID]) ExportDOT(name string) (string, error) {
	return dt.graph.ToDOT(name)
}

691
agent/unit/manager_test.go Normal file
View File

@@ -0,0 +1,691 @@
package unit_test
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/unit"
)
// testStatus is the status type used by the Manager tests.
type testStatus string

// Statuses exercised by the tests.
const (
	statusStarted   testStatus = "started"
	statusRunning   testStatus = "running"
	statusCompleted testStatus = "completed"
)

// testConsumerID identifies a fake consumer/service in the tests.
type testConsumerID string

// Consumer IDs shared across the tests.
const (
	consumerA testConsumerID = "serviceA"
	consumerB testConsumerID = "serviceB"
	consumerC testConsumerID = "serviceC"
	consumerD testConsumerID = "serviceD"
)
// TestDependencyTracker_Register covers consumer registration: a new
// consumer is ready immediately, duplicates are rejected, and multiple
// registrations are independent.
func TestDependencyTracker_Register(t *testing.T) {
	t.Parallel()
	t.Run("RegisterNewConsumer", func(t *testing.T) {
		t.Parallel()
		// FIX: previously this parallel subtest captured a tracker shared
		// from the parent test while its siblings built their own; every
		// subtest now owns an isolated Manager.
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		// Consumer should be ready initially (no dependencies)
		ready, err := tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.True(t, ready)
	})
	t.Run("RegisterDuplicateConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		// Registering the same ID a second time must fail.
		err = tracker.Register(consumerA)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "already registered")
	})
	t.Run("RegisterMultipleConsumers", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		consumers := []testConsumerID{consumerA, consumerB, consumerC}
		for _, consumer := range consumers {
			err := tracker.Register(consumer)
			require.NoError(t, err)
		}
		// All should be ready initially
		for _, consumer := range consumers {
			ready, err := tracker.IsReady(consumer)
			require.NoError(t, err)
			assert.True(t, ready)
		}
	})
}
// TestDependencyTracker_AddDependency covers adding dependency edges:
// both endpoints must be registered, and adding an edge makes the
// dependent consumer not-ready until the dependency is satisfied.
func TestDependencyTracker_AddDependency(t *testing.T) {
	t.Parallel()
	t.Run("AddDependencyBetweenRegisteredConsumers", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		// A should no longer be ready (depends on B)
		ready, err := tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)
		// B should still be ready (no dependencies)
		ready, err = tracker.IsReady(consumerB)
		require.NoError(t, err)
		assert.True(t, ready)
	})
	t.Run("AddDependencyWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		// Try to add dependency to unregistered consumer
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not registered")
	})
	t.Run("AddDependencyFromUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerB)
		require.NoError(t, err)
		// Try to add dependency from unregistered consumer
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not registered")
	})
}
// TestDependencyTracker_UpdateStatus covers status updates: readiness is
// recalculated reactively for dependents, unregistered consumers are
// rejected, and readiness propagates through a linear dependency chain.
func TestDependencyTracker_UpdateStatus(t *testing.T) {
	t.Parallel()
	t.Run("UpdateStatusTriggersReadinessRecalculation", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		// Initially A is not ready
		ready, err := tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)
		// Update B to "running" - A should become ready
		err = tracker.UpdateStatus(consumerB, statusRunning)
		require.NoError(t, err)
		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.True(t, ready)
	})
	t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		// Updating a consumer that was never registered must fail with the
		// sentinel error.
		err := tracker.UpdateStatus(consumerA, statusRunning)
		require.Error(t, err)
		assert.Equal(t, unit.ErrConsumerNotFound, err)
	})
	t.Run("LinearChainDependencies", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		// Register all consumers
		consumers := []testConsumerID{consumerA, consumerB, consumerC}
		for _, consumer := range consumers {
			err := tracker.Register(consumer)
			require.NoError(t, err)
		}
		// Create chain: A depends on B being "started", B depends on C being "completed"
		err := tracker.AddDependency(consumerA, consumerB, statusStarted)
		require.NoError(t, err)
		err = tracker.AddDependency(consumerB, consumerC, statusCompleted)
		require.NoError(t, err)
		// Initially only C is ready
		ready, err := tracker.IsReady(consumerC)
		require.NoError(t, err)
		assert.True(t, ready)
		ready, err = tracker.IsReady(consumerB)
		require.NoError(t, err)
		assert.False(t, ready)
		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)
		// Update C to "completed" - B should become ready
		err = tracker.UpdateStatus(consumerC, statusCompleted)
		require.NoError(t, err)
		ready, err = tracker.IsReady(consumerB)
		require.NoError(t, err)
		assert.True(t, ready)
		// A is still gated on B's status, not C's.
		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)
		// Update B to "started" - A should become ready
		err = tracker.UpdateStatus(consumerB, statusStarted)
		require.NoError(t, err)
		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.True(t, ready)
	})
}
// TestDependencyTracker_GetUnmetDependencies covers the unmet-dependency
// report: empty for dependency-free consumers, populated while a
// dependency is unsatisfied, empty once satisfied, and an error for
// unregistered consumers.
func TestDependencyTracker_GetUnmetDependencies(t *testing.T) {
	t.Parallel()
	t.Run("GetUnmetDependenciesForConsumerWithNoDependencies", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.NoError(t, err)
		assert.Empty(t, unmet)
	})
	t.Run("GetUnmetDependenciesForConsumerWithUnsatisfiedDependencies", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		// B has never reported a status, so the dependency is unmet.
		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.NoError(t, err)
		require.Len(t, unmet, 1)
		assert.Equal(t, consumerA, unmet[0].Consumer)
		assert.Equal(t, consumerB, unmet[0].DependsOn)
		assert.Equal(t, statusRunning, unmet[0].RequiredStatus)
		assert.False(t, unmet[0].IsSatisfied)
	})
	t.Run("GetUnmetDependenciesForConsumerWithSatisfiedDependencies", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		// Update B to "running"
		err = tracker.UpdateStatus(consumerB, statusRunning)
		require.NoError(t, err)
		// The satisfied dependency must no longer be reported.
		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.NoError(t, err)
		assert.Empty(t, unmet)
	})
	t.Run("GetUnmetDependenciesForUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.Error(t, err)
		assert.Equal(t, unit.ErrConsumerNotFound, err)
		assert.Nil(t, unmet)
	})
}
// TestDependencyTracker_ConcurrentOperations exercises the Manager under
// concurrent writers and readers (run with -race).
func TestDependencyTracker_ConcurrentOperations(t *testing.T) {
	t.Parallel()
	t.Run("ConcurrentStatusUpdates", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		// Register consumers
		consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
		for _, consumer := range consumers {
			err := tracker.Register(consumer)
			require.NoError(t, err)
		}
		// Create dependencies: A depends on B, B depends on C, C depends on D
		err := tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		err = tracker.AddDependency(consumerB, consumerC, statusStarted)
		require.NoError(t, err)
		err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
		require.NoError(t, err)
		var wg sync.WaitGroup
		const numGoroutines = 10
		// FIX: all goroutines apply identical statuses, so every goroutine
		// after the first to land a given update legitimately receives
		// ErrSameStatusAlreadySet from UpdateStatus. Previously that error
		// was treated as a failure, making the test fail deterministically.
		update := func(c testConsumerID, s testStatus) error {
			if err := tracker.UpdateStatus(c, s); err != nil && err != unit.ErrSameStatusAlreadySet {
				return err
			}
			return nil
		}
		errs := make([]error, numGoroutines)
		for i := 0; i < numGoroutines; i++ {
			wg.Add(1)
			go func(goroutineID int) {
				defer wg.Done()
				// D completed -> C ready; C started -> B ready; B running -> A ready.
				steps := []struct {
					c testConsumerID
					s testStatus
				}{
					{consumerD, statusCompleted},
					{consumerC, statusStarted},
					{consumerB, statusRunning},
				}
				for _, step := range steps {
					if err := update(step.c, step.s); err != nil {
						errs[goroutineID] = err
						return
					}
				}
			}(i)
		}
		wg.Wait()
		// Check for any unexpected errors in goroutines.
		for i, err := range errs {
			require.NoError(t, err, "goroutine %d had error", i)
		}
		// All consumers should be ready after the updates.
		for _, consumer := range consumers {
			ready, err := tracker.IsReady(consumer)
			require.NoError(t, err)
			assert.True(t, ready)
		}
	})
	t.Run("ConcurrentReadinessChecks", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		// Register consumers
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		var wg sync.WaitGroup
		const numGoroutines = 20
		for i := 0; i < numGoroutines; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				for j := 0; j < 10; j++ {
					// FIX: require.* calls t.FailNow, which must only run on
					// the test goroutine; use assert.* (t.Errorf) here.
					// A's readiness flips while B is being updated, so only
					// the error is checked, not the value.
					_, err := tracker.IsReady(consumerA)
					assert.NoError(t, err)
					// B has no dependencies and must always be ready.
					ready, err := tracker.IsReady(consumerB)
					assert.NoError(t, err)
					assert.True(t, ready)
				}
			}()
		}
		// Flip B to "running" while the readiness checks are in flight.
		err = tracker.UpdateStatus(consumerB, statusRunning)
		require.NoError(t, err)
		wg.Wait()
	})
}
func TestDependencyTracker_MultipleDependencies(t *testing.T) {
t.Parallel()
t.Run("ConsumerWithMultipleDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// A depends on B being "running" AND C being "started"
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
require.NoError(t, err)
// A should not be ready (depends on both B and C)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "running" - A should still not be ready (needs C too)
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "started" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusStarted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("ComplexDependencyChain", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// Create complex dependency graph:
// A depends on B being "running" AND C being "started"
// B depends on D being "completed"
// C depends on D being "completed"
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
require.NoError(t, err)
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
require.NoError(t, err)
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
require.NoError(t, err)
// Initially only D is ready
ready, err := tracker.IsReady(consumerD)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.False(t, ready)
ready, err = tracker.IsReady(consumerC)
require.NoError(t, err)
assert.False(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update D to "completed" - B and C should become ready
err = tracker.UpdateStatus(consumerD, statusCompleted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerC)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "running" - A should still not be ready (needs C)
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "started" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusStarted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("DifferentStatusTypes", func(t *testing.T) {
t.Parallel()
tracker := unit.NewManager[testStatus, testConsumerID]()
// Register consumers
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
err = tracker.Register(consumerC)
require.NoError(t, err)
// A depends on B being "running" AND C being "completed"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusCompleted)
require.NoError(t, err)
// Update B to "running" but not C - A should not be ready
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "completed" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusCompleted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
}
// TestDependencyTracker_ErrorCases verifies that the unit manager rejects
// operations on unregistered consumers and detects dependency cycles.
func TestDependencyTracker_ErrorCases(t *testing.T) {
	t.Parallel()
	t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.UpdateStatus(consumerA, statusRunning)
		require.Error(t, err)
		// ErrorIs instead of Equal so the assertion keeps passing if the
		// manager ever wraps the sentinel error.
		assert.ErrorIs(t, err, unit.ErrConsumerNotFound)
	})
	t.Run("IsReadyWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		ready, err := tracker.IsReady(consumerA)
		require.Error(t, err)
		assert.ErrorIs(t, err, unit.ErrConsumerNotFound)
		assert.False(t, ready)
	})
	t.Run("GetUnmetDependenciesWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.Error(t, err)
		assert.ErrorIs(t, err, unit.ErrConsumerNotFound)
		assert.Nil(t, unmet)
	})
	t.Run("AddDependencyWithUnregisteredConsumers", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		// Try to add dependency with unregistered consumers.
		err := tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not registered")
	})
	t.Run("CyclicDependencyDetection", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		// Register consumers.
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)
		// A depends on B.
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)
		// Try to make B depend on A (creates cycle).
		err = tracker.AddDependency(consumerB, consumerA, statusStarted)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "would create a cycle")
	})
}
// TestDependencyTracker_ToDOT exercises DOT export of the dependency graph
// for both a trivial and a diamond-shaped graph.
func TestDependencyTracker_ToDOT(t *testing.T) {
	t.Parallel()
	t.Run("ExportSimpleGraph", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		// Two consumers with a single edge between them.
		for _, c := range []testConsumerID{consumerA, consumerB} {
			require.NoError(t, tracker.Register(c))
		}
		require.NoError(t, tracker.AddDependency(consumerA, consumerB, statusRunning))
		dot, err := tracker.ExportDOT("test")
		require.NoError(t, err)
		assert.NotEmpty(t, dot)
		assert.Contains(t, dot, "digraph")
	})
	t.Run("ExportComplexGraph", func(t *testing.T) {
		t.Parallel()
		tracker := unit.NewManager[testStatus, testConsumerID]()
		for _, c := range []testConsumerID{consumerA, consumerB, consumerC, consumerD} {
			require.NoError(t, tracker.Register(c))
		}
		// Diamond: A depends on B and C, which both depend on D.
		edges := []struct {
			from, to testConsumerID
			status   testStatus
		}{
			{consumerA, consumerB, statusRunning},
			{consumerA, consumerC, statusStarted},
			{consumerB, consumerD, statusCompleted},
			{consumerC, consumerD, statusCompleted},
		}
		for _, e := range edges {
			require.NoError(t, tracker.AddDependency(e.from, e.to, e.status))
		}
		dot, err := tracker.ExportDOT("complex")
		require.NoError(t, err)
		assert.NotEmpty(t, dot)
		assert.Contains(t, dot, "digraph")
	})
}

View File

@@ -56,6 +56,7 @@ func workspaceAgent() *serpent.Command {
devcontainers bool
devcontainerProjectDiscovery bool
devcontainerDiscoveryAutostart bool
socketPath string
)
agentAuth := &AgentAuth{}
cmd := &serpent.Command{
@@ -297,6 +298,7 @@ func workspaceAgent() *serpent.Command {
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart),
},
SocketPath: socketPath,
})
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
@@ -449,6 +451,12 @@ func workspaceAgent() *serpent.Command {
Description: "Allow the agent to autostart devcontainer projects it discovers based on their configuration.",
Value: serpent.BoolOf(&devcontainerDiscoveryAutostart),
},
{
Flag: "socket-path",
Env: "CODER_AGENT_SOCKET_PATH",
Description: "Specify the path for the agent socket.",
Value: serpent.StringOf(&socketPath),
},
}
agentAuth.AttachOptions(cmd, false)
return cmd

View File

@@ -144,6 +144,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
r.mcpCommand(),
r.promptExample(),
r.rptyCommand(),
r.syncCommand(),
r.tasksCommand(),
r.boundary(),
}

25
cli/sync.go Normal file
View File

@@ -0,0 +1,25 @@
package cli
import (
"github.com/coder/serpent"
)
// syncCommand groups the subcommands that talk to the local agent socket.
func (r *RootCmd) syncCommand() *serpent.Command {
	children := []*serpent.Command{
		r.syncPing(),
		r.syncStart(),
		r.syncWant(),
		r.syncComplete(),
		r.syncWait(),
		r.syncStatus(),
	}
	return &serpent.Command{
		Use:      "sync",
		Short:    "Synchronize with the local agent socket",
		Long:     "Commands for interacting with the local Coder agent via socket communication.",
		Children: children,
		Handler: func(i *serpent.Invocation) error {
			// Invoked without a subcommand: show usage.
			return i.Command.HelpHandler(i)
		},
	}
}

50
cli/sync_complete.go Normal file
View File

@@ -0,0 +1,50 @@
package cli
import (
"context"
"fmt"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
// syncComplete returns the "sync complete" subcommand, which marks a unit as
// complete in the agent's dependency graph via the local agent socket.
func (r *RootCmd) syncComplete() *serpent.Command {
	return &serpent.Command{
		Use:   "complete <unit>",
		Short: "Mark a unit as complete in the dependency graph",
		Long:  "Set a unit's status to complete in the dependency graph.",
		Handler: func(i *serpent.Invocation) error {
			// Use the invocation context so cancellation (e.g. Ctrl-C)
			// propagates to the RPC instead of a detached background context.
			var ctx context.Context = i.Context()
			if len(i.Args) != 1 {
				return xerrors.New("exactly one unit name is required")
			}
			unit := i.Args[0]
			// Write progress to the invocation's stdout so output respects
			// redirection and is capturable in tests.
			fmt.Fprintf(i.Stdout, "Completing unit '%s'...\n", unit)
			// Connect to agent socket.
			// NOTE(review): socket path is hard-coded; should honor the
			// agent's --socket-path / CODER_AGENT_SOCKET_PATH setting.
			client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
				Path: "/tmp/coder.sock",
			})
			if err != nil {
				return xerrors.Errorf("connect to agent socket: %w", err)
			}
			defer client.Close()
			// Mark the unit complete.
			if err := client.SyncComplete(ctx, unit); err != nil {
				return xerrors.Errorf("complete unit failed: %w", err)
			}
			fmt.Fprintf(i.Stdout, "Unit '%s' completed successfully\n", unit)
			return nil
		},
	}
}

53
cli/sync_ping.go Normal file
View File

@@ -0,0 +1,53 @@
package cli
import (
"context"
"fmt"
"time"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
// syncPing returns the "sync ping" subcommand, which checks connectivity to
// the local agent socket and reports the round-trip time.
func (r *RootCmd) syncPing() *serpent.Command {
	return &serpent.Command{
		Use:   "ping",
		Short: "Ping the local agent socket",
		Long:  "Test connectivity to the local Coder agent via socket communication.",
		Handler: func(i *serpent.Invocation) error {
			// Use the invocation context so cancellation propagates to the
			// RPC instead of a detached background context.
			var ctx context.Context = i.Context()
			fmt.Fprintln(i.Stdout, "Pinging agent socket...")
			// NOTE(review): socket path is hard-coded; should honor the
			// agent's --socket-path / CODER_AGENT_SOCKET_PATH setting.
			client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
				Path: "/tmp/coder.sock",
			})
			if err != nil {
				return xerrors.Errorf("connect to agent socket: %w", err)
			}
			defer client.Close()
			// Measure round-trip time of a single ping.
			start := time.Now()
			resp, err := client.Ping(ctx)
			duration := time.Since(start)
			if err != nil {
				return xerrors.Errorf("ping failed: %w", err)
			}
			// Report results on the invocation's stdout.
			fmt.Fprintf(i.Stdout, "Response: %s\n", resp.Message)
			fmt.Fprintf(i.Stdout, "Timestamp: %s\n", resp.Timestamp.Format(time.RFC3339))
			fmt.Fprintf(i.Stdout, "Round-trip time: %s\n", duration.Round(time.Microsecond))
			fmt.Fprintln(i.Stdout, "Status: healthy")
			return nil
		},
	}
}

122
cli/sync_start.go Normal file
View File

@@ -0,0 +1,122 @@
package cli
import (
"context"
"fmt"
"time"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
const (
	// SyncPollInterval is the interval between dependency checks for sync start
	// NOTE(review): duplicates PollInterval in sync_wait.go; consider sharing
	// a single constant.
	SyncPollInterval = 1 * time.Second
)
// syncStart returns the "sync start" subcommand. It registers a unit in the
// dependency graph and marks it started, first waiting (by polling at
// SyncPollInterval) until all of the unit's dependencies are satisfied.
func (r *RootCmd) syncStart() *serpent.Command {
	var timeout time.Duration
	cmd := &serpent.Command{
		Use:   "start <unit>",
		Short: "Start a unit in the dependency graph",
		Long:  "Register a unit in the dependency graph and set its status to started. Waits for all dependencies to be satisfied before marking as started.",
		Handler: func(i *serpent.Invocation) error {
			// Use the invocation context so Ctrl-C cancels the wait loop.
			ctx := i.Context()
			if len(i.Args) != 1 {
				return xerrors.New("exactly one unit name is required")
			}
			unitName := i.Args[0]
			// Bound the wait when --timeout is given.
			if timeout > 0 {
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(ctx, timeout)
				defer cancel()
			}
			fmt.Fprintf(i.Stdout, "Starting unit '%s'...\n", unitName)
			// NOTE(review): socket path is hard-coded; should honor the
			// agent's --socket-path / CODER_AGENT_SOCKET_PATH setting.
			client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
				Path: "/tmp/coder.sock",
			})
			if err != nil {
				return xerrors.Errorf("connect to agent socket: %w", err)
			}
			defer client.Close()
			// First check: dependencies may already be satisfied (or absent).
			err = client.SyncReady(ctx, unitName)
			switch {
			case err == nil:
				// No dependencies or already satisfied.
				fmt.Fprintf(i.Stdout, "Dependencies satisfied, marking unit '%s' as started\n", unitName)
			case xerrors.Is(err, unit.ErrDependenciesNotSatisfied):
				// Dependencies exist but are unmet: poll until they are.
				fmt.Fprintf(i.Stdout, "Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
				ticker := time.NewTicker(SyncPollInterval)
				defer ticker.Stop()
			pollLoop:
				for {
					select {
					case <-ctx.Done():
						// Distinguish a timeout from other cancellation.
						// errors.Is-style matching instead of == so wrapped
						// context errors are still recognized.
						if xerrors.Is(ctx.Err(), context.DeadlineExceeded) {
							return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
						}
						return ctx.Err()
					case <-ticker.C:
						err := client.SyncReady(ctx, unitName)
						if err == nil {
							fmt.Fprintf(i.Stdout, "Dependencies satisfied, marking unit '%s' as started\n", unitName)
							break pollLoop
						}
						if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
							// Still waiting, continue polling.
							continue
						}
						return xerrors.Errorf("error checking dependencies: %w", err)
					}
				}
			default:
				return xerrors.Errorf("error checking dependencies: %w", err)
			}
			// All dependencies satisfied: mark the unit started.
			if err := client.SyncStart(ctx, unitName); err != nil {
				return xerrors.Errorf("start unit failed: %w", err)
			}
			fmt.Fprintf(i.Stdout, "Unit '%s' started successfully\n", unitName)
			return nil
		},
	}
	cmd.Options = append(cmd.Options, serpent.Option{
		Flag:        "timeout",
		Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
		Value:       serpent.DurationOf(&timeout),
	})
	return cmd
}

134
cli/sync_status.go Normal file
View File

@@ -0,0 +1,134 @@
package cli
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
// outputFormat selects how `sync status` renders its result.
type outputFormat string
const (
	outputFormatHuman outputFormat = "human" // human-readable table (default)
	outputFormatJSON outputFormat = "json" // machine-readable JSON
	outputFormatDOT outputFormat = "dot" // Graphviz DOT (requires --recursive)
)
// syncStatus returns the "sync status" subcommand, which displays a unit's
// current status and dependency information in human, JSON, or DOT form.
func (r *RootCmd) syncStatus() *serpent.Command {
	var (
		output    string
		recursive bool
	)
	cmd := &serpent.Command{
		Use:   "status <unit>",
		Short: "Show the status of a unit and its dependencies",
		Long:  "Display the current status of a unit and information about its dependencies. Supports multiple output formats.",
		Handler: func(i *serpent.Invocation) error {
			// Use the invocation context so cancellation propagates to the
			// RPC instead of a detached background context.
			var ctx context.Context = i.Context()
			if len(i.Args) != 1 {
				return xerrors.New("exactly one unit name is required")
			}
			unit := i.Args[0]
			// NOTE(review): socket path is hard-coded; should honor the
			// agent's --socket-path / CODER_AGENT_SOCKET_PATH setting.
			client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
				Path: "/tmp/coder.sock",
			})
			if err != nil {
				return xerrors.Errorf("connect to agent socket: %w", err)
			}
			defer client.Close()
			// Fetch status (and the DOT graph when --recursive).
			statusResp, err := client.SyncStatus(ctx, unit, recursive)
			if err != nil {
				return xerrors.Errorf("get status failed: %w", err)
			}
			// Render in the requested format; anything unrecognized (including
			// the empty default) falls through to the human-readable table.
			switch outputFormat(output) {
			case outputFormatJSON:
				return outputJSON(statusResp)
			case outputFormatDOT:
				return outputDOT(statusResp)
			default: // outputFormatHuman
				return outputHuman(statusResp)
			}
		},
	}
	cmd.Options = append(cmd.Options,
		serpent.Option{
			Flag:          "output",
			FlagShorthand: "o",
			Description:   "Output format: human, json, or dot.",
			Value:         serpent.EnumOf(&output, "human", "json", "dot"),
		},
		serpent.Option{
			Flag:          "recursive",
			FlagShorthand: "r",
			Description:   "Show transitive dependencies and include DOT graph.",
			Value:         serpent.BoolOf(&recursive),
		},
	)
	return cmd
}
// outputJSON writes the status response to stdout as indented JSON.
func outputJSON(statusResp *agentsdk.SyncStatusResponse) error {
	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", " ")
	return enc.Encode(statusResp)
}
// outputDOT prints the server-provided Graphviz DOT graph. The graph is only
// populated when the status request was made with --recursive.
func outputDOT(statusResp *agentsdk.SyncStatusResponse) error {
	if len(statusResp.DOT) == 0 {
		return xerrors.New("DOT output requires --recursive flag")
	}
	fmt.Println(statusResp.DOT)
	return nil
}
// outputHuman renders the status response as a human-readable table.
func outputHuman(statusResp *agentsdk.SyncStatusResponse) error {
	// Header: the unit itself.
	fmt.Printf("Unit: %s\n", statusResp.Unit)
	fmt.Printf("Status: %s\n", statusResp.Status)
	fmt.Printf("Ready: %t\n", statusResp.IsReady)
	fmt.Println()
	deps := statusResp.Dependencies
	if len(deps) == 0 {
		fmt.Println("No dependencies")
		return nil
	}
	// Dependency table with a fixed-width layout.
	rule := strings.Repeat("-", 80)
	fmt.Println("Dependencies:")
	fmt.Println(rule)
	fmt.Printf("%-20s %-15s %-15s %-10s\n", "Depends On", "Required", "Current", "Satisfied")
	fmt.Println(rule)
	for _, dep := range deps {
		mark := "✓"
		if !dep.IsSatisfied {
			mark = "✗"
		}
		fmt.Printf("%-20s %-15s %-15s %-10s\n",
			dep.DependsOn,
			dep.RequiredStatus,
			dep.CurrentStatus,
			mark,
		)
	}
	return nil
}

359
cli/sync_test.go Normal file
View File

@@ -0,0 +1,359 @@
package cli_test
import (
	"errors"
	"fmt"
	"net"
	"os"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/coder/coder/v2/agent/unit"
	"github.com/coder/coder/v2/cli/clitest"
)
// mockAgentSocketServer simulates the agent socket server for testing
type mockAgentSocketServer struct {
	listener net.Listener
	// handlers maps a method name (e.g. "sync.wait") to a function that
	// receives the unit name and returns a response message or an error.
	handlers map[string]func(string) (string, error)
}
// newMockAgentSocketServer starts a mock agent socket server listening on
// socketPath and returns it with default handlers installed for the
// "sync.wait" and "sync.start" methods.
func newMockAgentSocketServer(t *testing.T, socketPath string) *mockAgentSocketServer {
	ln, err := net.Listen("unix", socketPath)
	require.NoError(t, err)
	srv := &mockAgentSocketServer{
		listener: ln,
		handlers: map[string]func(string) (string, error){
			// Default: always report dependencies unsatisfied so callers poll.
			"sync.wait": func(string) (string, error) {
				return "", unit.ErrDependenciesNotSatisfied
			},
			// Default: report a successful start.
			"sync.start": func(unitName string) (string, error) {
				return "Unit " + unitName + " started successfully", nil
			},
		},
	}
	go srv.serve(t)
	return srv
}
// serve accepts connections until the listener is closed, handling each one
// on its own goroutine.
func (s *mockAgentSocketServer) serve(t *testing.T) {
	for {
		conn, acceptErr := s.listener.Accept()
		if acceptErr == nil {
			go s.handleConnection(t, conn)
			continue
		}
		// A closed listener is the normal shutdown path; anything else is
		// worth surfacing in the test log.
		if !errors.Is(acceptErr, net.ErrClosed) {
			t.Logf("Accept error: %v", acceptErr)
		}
		return
	}
}
// handleConnection reads a single request, dispatches it to the matching
// handler, and writes one JSON-RPC-like response before closing the
// connection.
func (s *mockAgentSocketServer) handleConnection(t *testing.T, conn net.Conn) {
	defer conn.Close()
	// Simulate a minimal request/response exchange: one read, one write.
	req := make([]byte, 1024)
	n, err := conn.Read(req)
	if err != nil {
		t.Logf("Read error: %v", err)
		return
	}
	payload := string(req[:n])
	// Crude method detection: scan the raw payload for a known method name.
	method := ""
	switch {
	case strings.Contains(payload, "sync.wait"):
		method = "sync.wait"
	case strings.Contains(payload, "sync.start"):
		method = "sync.start"
	}
	handler, ok := s.handlers[method]
	if !ok {
		_, _ = conn.Write([]byte(`{"error": {"code": -32601, "message": "Method not found"}}`))
		return
	}
	// Unit-name extraction is likewise simplified: tests only use "test-unit".
	unitName := "test-unit"
	if strings.Contains(payload, "test-unit") {
		unitName = "test-unit"
	}
	msg, handlerErr := handler(unitName)
	if handlerErr != nil {
		_, _ = conn.Write([]byte(fmt.Sprintf(`{"error": {"code": -32603, "message": %q}}`, handlerErr.Error())))
		return
	}
	_, _ = conn.Write([]byte(fmt.Sprintf(`{"result": {"success": true, "message": %q}}`, msg)))
}
// setHandler overrides the handler for the given method name.
// NOTE: not synchronized with connection goroutines; tests must install
// handlers before issuing requests.
func (s *mockAgentSocketServer) setHandler(method string, handler func(string) (string, error)) {
	s.handlers[method] = handler
}
// close shuts down the listener, which causes serve to return.
func (s *mockAgentSocketServer) close() {
	_ = s.listener.Close()
}
// TestSyncStartTimeout verifies that `sync start --timeout` aborts with a
// timeout error when the unit's dependencies are never satisfied.
// NOTE(review): the invocation appends an "--agent-socket" flag that the sync
// commands do not appear to define; confirm the flag exists on the CLI.
func TestSyncStartTimeout(t *testing.T) {
	t.Parallel()
	// Create a unique temporary socket file
	socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
	// Remove existing socket if it exists
	_ = os.Remove(socketPath)
	defer func() { _ = os.Remove(socketPath) }()
	// Start mock server; its default sync.wait handler always reports
	// dependencies unsatisfied, so the CLI polls until the timeout fires.
	server := newMockAgentSocketServer(t, socketPath)
	defer server.close()
	// Test with a short timeout
	inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", "100ms")
	// Override the socket path for this test
	inv.Args = append(inv.Args, "--agent-socket", socketPath)
	start := time.Now()
	err := inv.Run()
	duration := time.Since(start)
	// Should timeout after approximately 100ms
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
	// Should timeout within a reasonable range (100ms + some buffer for test execution)
	assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
	assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
}
// TestSyncWaitTimeout verifies that `sync wait --timeout` aborts with a
// timeout error when the unit's dependencies are never satisfied.
// NOTE(review): the invocation appends an "--agent-socket" flag that the sync
// commands do not appear to define; confirm the flag exists on the CLI.
func TestSyncWaitTimeout(t *testing.T) {
	t.Parallel()
	// Create a unique temporary socket file
	socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
	// Remove existing socket if it exists
	_ = os.Remove(socketPath)
	defer func() { _ = os.Remove(socketPath) }()
	// Start mock server; its default sync.wait handler always reports
	// dependencies unsatisfied, so the CLI polls until the timeout fires.
	server := newMockAgentSocketServer(t, socketPath)
	defer server.close()
	// Test with a short timeout
	inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", "100ms")
	// Override the socket path for this test
	inv.Args = append(inv.Args, "--agent-socket", socketPath)
	start := time.Now()
	err := inv.Run()
	duration := time.Since(start)
	// Should timeout after approximately 100ms
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
	// Should timeout within a reasonable range (100ms + some buffer for test execution)
	assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
	assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
}
// TestSyncStartNoTimeout verifies that `sync start` without --timeout keeps
// polling until dependencies become satisfied, then succeeds.
func TestSyncStartNoTimeout(t *testing.T) {
	t.Parallel()
	// Create a unique temporary socket file.
	socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
	// Remove existing socket if it exists.
	_ = os.Remove(socketPath)
	defer func() { _ = os.Remove(socketPath) }()
	// Start mock server.
	server := newMockAgentSocketServer(t, socketPath)
	defer server.close()
	// The handler runs on per-connection goroutines while this goroutine
	// reads the counter in the assertions below, so the counter must be
	// atomic to avoid a data race under -race.
	var callCount atomic.Int32
	server.setHandler("sync.wait", func(unitName string) (string, error) {
		if callCount.Add(1) >= 3 {
			// After 3 calls, dependencies are satisfied.
			return "Dependencies satisfied", nil
		}
		return "", unit.ErrDependenciesNotSatisfied
	})
	// Test without timeout - should eventually succeed.
	inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit")
	// Override the socket path for this test.
	// NOTE(review): confirm the "--agent-socket" flag exists on the CLI.
	inv.Args = append(inv.Args, "--agent-socket", socketPath)
	start := time.Now()
	err := inv.Run()
	duration := time.Since(start)
	// Should succeed after a few polling cycles.
	assert.NoError(t, err)
	// Should take at least 2 seconds (2 polling cycles at 1s interval).
	assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
	assert.True(t, callCount.Load() >= 3, "Should have made at least 3 calls, got %d", callCount.Load())
}
// TestSyncWaitNoTimeout verifies that `sync wait` without --timeout keeps
// polling until dependencies become satisfied, then succeeds.
func TestSyncWaitNoTimeout(t *testing.T) {
	t.Parallel()
	// Create a unique temporary socket file.
	socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
	// Remove existing socket if it exists.
	_ = os.Remove(socketPath)
	defer func() { _ = os.Remove(socketPath) }()
	// Start mock server.
	server := newMockAgentSocketServer(t, socketPath)
	defer server.close()
	// The handler runs on per-connection goroutines while this goroutine
	// reads the counter in the assertions below, so the counter must be
	// atomic to avoid a data race under -race.
	var callCount atomic.Int32
	server.setHandler("sync.wait", func(unitName string) (string, error) {
		if callCount.Add(1) >= 3 {
			// After 3 calls, dependencies are satisfied.
			return "Dependencies satisfied", nil
		}
		return "", unit.ErrDependenciesNotSatisfied
	})
	// Test without timeout - should eventually succeed.
	inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit")
	// Override the socket path for this test.
	// NOTE(review): confirm the "--agent-socket" flag exists on the CLI.
	inv.Args = append(inv.Args, "--agent-socket", socketPath)
	start := time.Now()
	err := inv.Run()
	duration := time.Since(start)
	// Should succeed after a few polling cycles.
	assert.NoError(t, err)
	// Should take at least 2 seconds (2 polling cycles at 1s interval).
	assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
	assert.True(t, callCount.Load() >= 3, "Should have made at least 3 calls, got %d", callCount.Load())
}
// TestSyncStartTimeoutWithDifferentValues checks the timeout behavior of
// `sync start` across several --timeout values.
func TestSyncStartTimeoutWithDifferentValues(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		name string
		timeout string
		expected time.Duration
	}{
		{"50ms", "50ms", 50 * time.Millisecond},
		{"200ms", "200ms", 200 * time.Millisecond},
		{"500ms", "500ms", 500 * time.Millisecond},
		{"1s", "1s", 1 * time.Second},
	}
	for _, tc := range testCases {
		// NOTE(review): tc is captured by a parallel subtest closure; safe
		// with Go 1.22+ per-iteration loop variables — confirm the module's
		// Go version, otherwise shadow with tc := tc.
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			// Create a unique temporary socket file
			socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
			// Remove existing socket if it exists
			_ = os.Remove(socketPath)
			defer func() { _ = os.Remove(socketPath) }()
			// Start mock server; default sync.wait handler never satisfies
			// dependencies, so each run must end in a timeout.
			server := newMockAgentSocketServer(t, socketPath)
			defer server.close()
			// Test with specified timeout
			inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", tc.timeout)
			// Override the socket path for this test
			inv.Args = append(inv.Args, "--agent-socket", socketPath)
			start := time.Now()
			err := inv.Run()
			duration := time.Since(start)
			// Should timeout after approximately the specified duration
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
			// Should timeout within a reasonable range
			assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
			assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
		})
	}
}
// TestSyncWaitTimeoutWithDifferentValues checks the timeout behavior of
// `sync wait` across several --timeout values.
func TestSyncWaitTimeoutWithDifferentValues(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		name string
		timeout string
		expected time.Duration
	}{
		{"50ms", "50ms", 50 * time.Millisecond},
		{"200ms", "200ms", 200 * time.Millisecond},
		{"500ms", "500ms", 500 * time.Millisecond},
		{"1s", "1s", 1 * time.Second},
	}
	for _, tc := range testCases {
		// NOTE(review): tc is captured by a parallel subtest closure; safe
		// with Go 1.22+ per-iteration loop variables — confirm the module's
		// Go version, otherwise shadow with tc := tc.
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			// Create a unique temporary socket file
			socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
			// Remove existing socket if it exists
			_ = os.Remove(socketPath)
			defer func() { _ = os.Remove(socketPath) }()
			// Start mock server; default sync.wait handler never satisfies
			// dependencies, so each run must end in a timeout.
			server := newMockAgentSocketServer(t, socketPath)
			defer server.close()
			// Test with specified timeout
			inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", tc.timeout)
			// Override the socket path for this test
			inv.Args = append(inv.Args, "--agent-socket", socketPath)
			start := time.Now()
			err := inv.Run()
			duration := time.Since(start)
			// Should timeout after approximately the specified duration
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
			// Should timeout within a reasonable range
			assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
			assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
		})
	}
}

95
cli/sync_wait.go Normal file
View File

@@ -0,0 +1,95 @@
package cli
import (
"context"
"fmt"
"time"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
const (
	// PollInterval is the interval between dependency checks
	// NOTE(review): duplicates SyncPollInterval in sync_start.go; consider
	// sharing a single constant.
	PollInterval = 1 * time.Second
)
// syncWait returns the "sync wait" subcommand, which polls (at PollInterval)
// until all of a unit's dependencies are satisfied, the invocation is
// canceled, or the optional --timeout elapses.
func (r *RootCmd) syncWait() *serpent.Command {
	var timeout time.Duration
	cmd := &serpent.Command{
		Use:   "wait <unit>",
		Short: "Wait for a unit's dependencies to be satisfied",
		Long:  "Poll until all dependencies for a unit are met. Exits when dependencies are satisfied or timeout is reached.",
		Handler: func(i *serpent.Invocation) error {
			// Use the invocation context so Ctrl-C cancels the wait loop.
			ctx := i.Context()
			if len(i.Args) != 1 {
				return xerrors.New("exactly one unit name is required")
			}
			unitName := i.Args[0]
			// Bound the wait when --timeout is given.
			if timeout > 0 {
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(ctx, timeout)
				defer cancel()
			}
			fmt.Fprintf(i.Stdout, "Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
			// NOTE(review): socket path is hard-coded; should honor the
			// agent's --socket-path / CODER_AGENT_SOCKET_PATH setting.
			client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
				Path: "/tmp/coder.sock",
			})
			if err != nil {
				return xerrors.Errorf("connect to agent socket: %w", err)
			}
			defer client.Close()
			// Poll until dependencies are satisfied.
			ticker := time.NewTicker(PollInterval)
			defer ticker.Stop()
			for {
				select {
				case <-ctx.Done():
					// errors.Is-style matching instead of == so wrapped
					// context errors are still recognized as a timeout.
					if xerrors.Is(ctx.Err(), context.DeadlineExceeded) {
						return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
					}
					return ctx.Err()
				case <-ticker.C:
					err := client.SyncReady(ctx, unitName)
					if err == nil {
						// Dependencies are satisfied.
						fmt.Fprintf(i.Stdout, "Dependencies for unit '%s' are now satisfied\n", unitName)
						return nil
					}
					if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
						// Still waiting, continue polling.
						continue
					}
					return xerrors.Errorf("error checking dependencies: %w", err)
				}
			}
		},
	}
	cmd.Options = append(cmd.Options, serpent.Option{
		Flag:        "timeout",
		Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
		Value:       serpent.DurationOf(&timeout),
	})
	return cmd
}

51
cli/sync_want.go Normal file
View File

@@ -0,0 +1,51 @@
package cli
import (
"context"
"fmt"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
// syncWant returns the "sync want" subcommand, which declares that one unit
// depends on another unit reaching complete status.
func (r *RootCmd) syncWant() *serpent.Command {
	return &serpent.Command{
		Use:   "want <unit> <depends-on>",
		Short: "Declare a dependency between units",
		Long:  "Declare that a unit depends on another unit reaching complete status.",
		Handler: func(i *serpent.Invocation) error {
			// Use the invocation context so cancellation propagates to the
			// RPC instead of a detached background context.
			var ctx context.Context = i.Context()
			if len(i.Args) != 2 {
				return xerrors.New("exactly two arguments are required: unit and depends-on")
			}
			unit := i.Args[0]
			dependsOn := i.Args[1]
			fmt.Fprintf(i.Stdout, "Declaring dependency: '%s' depends on '%s'...\n", unit, dependsOn)
			// NOTE(review): socket path is hard-coded; should honor the
			// agent's --socket-path / CODER_AGENT_SOCKET_PATH setting.
			client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
				Path: "/tmp/coder.sock",
			})
			if err != nil {
				return xerrors.Errorf("connect to agent socket: %w", err)
			}
			defer client.Close()
			// Declare the dependency.
			if err := client.SyncWant(ctx, unit, dependsOn); err != nil {
				return xerrors.Errorf("declare dependency failed: %w", err)
			}
			fmt.Fprintf(i.Stdout, "Dependency declared: '%s' now depends on '%s'\n", unit, dependsOn)
			return nil
		},
	}
}

8
cli/sync_want_test.go Normal file
View File

@@ -0,0 +1,8 @@
package cli_test
import (
"testing"
)
// TestSyncWant is a placeholder for coverage of the `sync want` command.
// An empty test body silently passes; skip explicitly so the missing
// coverage is visible in test output instead of masquerading as a pass.
func TestSyncWant(t *testing.T) {
	t.Parallel()
	t.Skip("TODO: implement tests for the `sync want` command")
}

View File

@@ -67,6 +67,9 @@ OPTIONS:
--script-data-dir string, $CODER_AGENT_SCRIPT_DATA_DIR (default: /tmp)
Specify the location for storing script data.
--socket-path string, $CODER_AGENT_SOCKET_PATH
Specify the path for the agent socket.
--ssh-max-timeout duration, $CODER_AGENT_SSH_MAX_TIMEOUT (default: 72h)
Specify the max timeout for a SSH connection, it is advisable to set
it to a minimum of 60s, but no more than 72h.

View File

@@ -0,0 +1,256 @@
package agentsdk
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"time"
"golang.org/x/xerrors"
"github.com/hashicorp/yamux"
"storj.io/drpc"
"github.com/coder/coder/v2/agent/agentsocket/proto"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/codersdk/drpcsdk"
)
// SocketClient provides a client for communicating with the agent socket
type SocketClient struct {
	// client is the generated dRPC client bound to conn.
	client proto.DRPCAgentSocketClient
	// conn is the multiplexed dRPC connection over the unix socket; closing
	// it tears the whole client down.
	conn drpc.Conn
}
// SocketConfig holds configuration for the socket client
type SocketConfig struct {
	// Path is the unix socket to dial. When empty, NewSocketClient attempts
	// auto-discovery via discoverSocketPath.
	Path string // Socket path (optional, will auto-discover if not set)
}
// NewSocketClient creates a new socket client
// It dials the unix socket (auto-discovering the path when config.Path is
// empty), layers a yamux session for stream multiplexing on top of the raw
// connection, and wraps that in a dRPC connection/client.
func NewSocketClient(config SocketConfig) (*SocketClient, error) {
	path := config.Path
	if path == "" {
		// No explicit path configured: fall back to discovery.
		var err error
		path, err = discoverSocketPath()
		if err != nil {
			return nil, xerrors.Errorf("discover socket path: %w", err)
		}
	}
	conn, err := net.Dial("unix", path)
	if err != nil {
		return nil, xerrors.Errorf("connect to socket: %w", err)
	}
	// Create yamux session for multiplexing
	configYamux := yamux.DefaultConfig()
	configYamux.Logger = nil // Disable yamux logging
	// NOTE(review): yamux's DefaultConfig configures logging via LogOutput;
	// setting Logger to nil may not actually silence it — confirm.
	session, err := yamux.Client(conn, configYamux)
	if err != nil {
		// The session never took ownership; close the raw connection here.
		conn.Close()
		return nil, xerrors.Errorf("create yamux client: %w", err)
	}
	// Create drpc connection using the multiplexed connection
	drpcConn := drpcsdk.MultiplexedConn(session)
	// Create drpc client
	client := proto.NewDRPCAgentSocketClient(drpcConn)
	return &SocketClient{
		client: client,
		conn:   drpcConn,
	}, nil
}
// Close closes the socket connection
// Closing the dRPC connection also tears down the underlying yamux session
// and unix socket.
func (c *SocketClient) Close() error {
	return c.conn.Close()
}
// Ping sends a ping request to the agent and returns its reply message along
// with the server-side timestamp.
func (c *SocketClient) Ping(ctx context.Context) (*PingResponse, error) {
	reply, err := c.client.Ping(ctx, &proto.PingRequest{})
	if err != nil {
		return nil, err
	}
	out := &PingResponse{
		Message:   reply.Message,
		Timestamp: reply.Timestamp.AsTime(),
	}
	return out, nil
}
// SyncStart starts a unit in the dependency graph
func (c *SocketClient) SyncStart(ctx context.Context, unit string) error {
	req := &proto.SyncStartRequest{Unit: unit}
	resp, err := c.client.SyncStart(ctx, req)
	switch {
	case err != nil:
		return err
	case !resp.Success:
		// The server reported a failure in-band.
		return xerrors.Errorf("sync start failed: %s", resp.Message)
	default:
		return nil
	}
}
// SyncWant declares a dependency between units: unitName depends on
// dependsOn.
//
// The first parameter is named unitName (not unit) to avoid shadowing the
// imported unit package, matching the convention in SyncReady.
func (c *SocketClient) SyncWant(ctx context.Context, unitName, dependsOn string) error {
	resp, err := c.client.SyncWant(ctx, &proto.SyncWantRequest{
		Unit:      unitName,
		DependsOn: dependsOn,
	})
	if err != nil {
		return err
	}
	// Application-level failure is reported in the response body.
	if !resp.Success {
		return xerrors.Errorf("sync want failed: %s", resp.Message)
	}
	return nil
}
// SyncComplete marks a unit as complete in the dependency graph.
//
// The parameter is named unitName (not unit) to avoid shadowing the
// imported unit package, matching the convention in SyncReady.
func (c *SocketClient) SyncComplete(ctx context.Context, unitName string) error {
	resp, err := c.client.SyncComplete(ctx, &proto.SyncCompleteRequest{
		Unit: unitName,
	})
	if err != nil {
		return err
	}
	// Application-level failure is reported in the response body.
	if !resp.Success {
		return xerrors.Errorf("sync complete failed: %s", resp.Message)
	}
	return nil
}
// SyncReady requests whether a unit is ready to be started. That is, all dependencies are satisfied.
// Returns unit.ErrDependenciesNotSatisfied when the server reports that
// exact condition, so callers can match it with errors.Is.
func (c *SocketClient) SyncReady(ctx context.Context, unitName string) error {
	req := &proto.SyncReadyRequest{Unit: unitName}
	resp, err := c.client.SyncReady(ctx, req)
	if err != nil {
		return err
	}
	if resp.Success {
		return nil
	}
	// Map the well-known "dependencies not satisfied" message back to the
	// sentinel error from the unit package.
	if resp.Message == unit.ErrDependenciesNotSatisfied.Error() {
		return unit.ErrDependenciesNotSatisfied
	}
	return xerrors.Errorf("sync ready failed: %s", resp.Message)
}
// SyncStatus gets the status of a unit and, when recursive is true, its
// transitive dependencies.
//
// The parameter is named unitName (not unit) to avoid shadowing the
// imported unit package, matching the convention in SyncReady.
func (c *SocketClient) SyncStatus(ctx context.Context, unitName string, recursive bool) (*SyncStatusResponse, error) {
	resp, err := c.client.SyncStatus(ctx, &proto.SyncStatusRequest{
		Unit:      unitName,
		Recursive: recursive,
	})
	if err != nil {
		return nil, err
	}
	if !resp.Success {
		return nil, xerrors.Errorf("sync status failed: %s", resp.Message)
	}
	// Convert the wire-format dependencies into the local response type.
	dependencies := make([]DependencyInfo, 0, len(resp.Dependencies))
	for _, dep := range resp.Dependencies {
		dependencies = append(dependencies, DependencyInfo{
			DependsOn:      dep.DependsOn,
			RequiredStatus: dep.RequiredStatus,
			CurrentStatus:  dep.CurrentStatus,
			IsSatisfied:    dep.IsSatisfied,
		})
	}
	return &SyncStatusResponse{
		Success:      resp.Success,
		Message:      resp.Message,
		Unit:         resp.Unit,
		Status:       resp.Status,
		IsReady:      resp.IsReady,
		Dependencies: dependencies,
		DOT:          resp.Dot,
	}, nil
}
// discoverSocketPath discovers the agent socket path.
//
// Order of precedence:
//  1. The CODER_AGENT_SOCKET_PATH environment variable (returned without
//     checking existence, since the caller explicitly requested it).
//  2. The first existing file among a set of well-known locations.
func discoverSocketPath() (string, error) {
	// Check environment variable first.
	if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
		return path, nil
	}
	var paths []string
	// XDG runtime directory. Only considered when XDG_RUNTIME_DIR is set:
	// previously an unset variable produced the relative path
	// "coder-agent.sock" (filepath.Join drops the empty component), which
	// was then stat'd against the current working directory.
	if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
		paths = append(paths, filepath.Join(runtimeDir, "coder-agent.sock"))
	}
	paths = append(paths,
		// User-specific temp directory.
		filepath.Join(os.TempDir(), fmt.Sprintf("coder-agent-%d.sock", os.Getuid())),
		// Fallback temp directory.
		filepath.Join(os.TempDir(), "coder-agent.sock"),
	)
	for _, path := range paths {
		if _, err := os.Stat(path); err == nil {
			return path, nil
		}
	}
	return "", xerrors.New("agent socket not found")
}
// Response types for backward compatibility
// These mirror the proto response messages with plain Go/JSON types so
// callers do not need to depend on the generated proto package.

// PingResponse is the result of SocketClient.Ping.
type PingResponse struct {
	Message   string    `json:"message"`
	Timestamp time.Time `json:"timestamp"`
}

// HealthResponse describes the agent's health.
// NOTE(review): no method in this chunk returns it — presumably used by
// callers elsewhere; confirm before removing.
type HealthResponse struct {
	Status    string    `json:"status"`
	Timestamp time.Time `json:"timestamp"`
	Uptime    string    `json:"uptime"`
}

// AgentInfo describes the running agent.
// NOTE(review): no method in this chunk returns it — presumably used by
// callers elsewhere; confirm before removing.
type AgentInfo struct {
	ID        string    `json:"id"`
	Version   string    `json:"version"`
	Status    string    `json:"status"`
	StartedAt time.Time `json:"started_at"`
	Uptime    string    `json:"uptime"`
}

// SyncStatusResponse is the result of SocketClient.SyncStatus.
type SyncStatusResponse struct {
	Success bool   `json:"success"`
	Message string `json:"message"`
	Unit    string `json:"unit"`
	Status  string `json:"status"`
	IsReady bool   `json:"is_ready"`
	// Dependencies lists this unit's direct (or, when requested
	// recursively, transitive) dependencies.
	Dependencies []DependencyInfo `json:"dependencies"`
	// DOT is a Graphviz DOT rendering of the dependency graph, when the
	// server provided one.
	DOT string `json:"dot,omitempty"`
}

// DependencyInfo describes one edge in the unit dependency graph.
type DependencyInfo struct {
	DependsOn      string `json:"depends_on"`
	RequiredStatus string `json:"required_status"`
	CurrentStatus  string `json:"current_status"`
	IsSatisfied    bool   `json:"is_satisfied"`
}