Compare commits
18 Commits
fix/useref
...
jjs/intern
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9764926f92 | ||
|
|
10d4e42fc1 | ||
|
|
217ddf46c4 | ||
|
|
0d3d493eae | ||
|
|
89b060e245 | ||
|
|
820d53b66a | ||
|
|
f550028052 | ||
|
|
e6873c8d61 | ||
|
|
8c0bfcb570 | ||
|
|
c322b92ab0 | ||
|
|
216a5ac562 | ||
|
|
86447126d5 | ||
|
|
55c5b707fb | ||
|
|
4616c82f3c | ||
|
|
9ca30e28d6 | ||
|
|
34c1370090 | ||
|
|
851c4f907c | ||
|
|
e3dfe45f35 |
9
Makefile
9
Makefile
@@ -642,6 +642,7 @@ AIBRIDGED_MOCKS := \
|
||||
GEN_FILES := \
|
||||
tailnet/proto/tailnet.pb.go \
|
||||
agent/proto/agent.pb.go \
|
||||
agent/agentsocket/proto/agentsocket.pb.go \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
@@ -800,6 +801,14 @@ agent/proto/agent.pb.go: agent/proto/agent.proto
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./agent/proto/agent.proto
|
||||
|
||||
agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./agent/agentsocket/proto/agentsocket.proto
|
||||
|
||||
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
|
||||
@@ -91,6 +92,7 @@ type Options struct {
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
Clock quartz.Clock
|
||||
SocketPath string // Path for the agent socket server
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@@ -190,6 +192,7 @@ func New(options Options) Agent {
|
||||
|
||||
devcontainers: options.Devcontainers,
|
||||
containerAPIOptions: options.DevcontainerAPIOptions,
|
||||
socketPath: options.SocketPath,
|
||||
}
|
||||
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
||||
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
||||
@@ -271,6 +274,9 @@ type agent struct {
|
||||
devcontainers bool
|
||||
containerAPIOptions []agentcontainers.Option
|
||||
containerAPI *agentcontainers.API
|
||||
|
||||
socketPath string
|
||||
socketServer *agentsocket.Server
|
||||
}
|
||||
|
||||
func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
@@ -350,9 +356,35 @@ func (a *agent) init() {
|
||||
s.ExperimentalContainers = a.devcontainers
|
||||
},
|
||||
)
|
||||
|
||||
a.initSocketServer()
|
||||
|
||||
go a.runLoop()
|
||||
}
|
||||
|
||||
// initSocketServer initializes server that allows direct communication with a workspace agent using IPC.
|
||||
func (a *agent) initSocketServer() {
|
||||
if a.socketPath == "" {
|
||||
a.logger.Info(a.hardCtx, "socket server disabled (no path configured)")
|
||||
return
|
||||
}
|
||||
|
||||
server, err := agentsocket.NewServer(a.socketPath, a.logger.Named("socket"))
|
||||
if err != nil {
|
||||
a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
a.logger.Warn(a.hardCtx, "failed to start socket server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
a.socketServer = server
|
||||
a.logger.Debug(a.hardCtx, "socket server started", slog.F("path", a.socketPath))
|
||||
}
|
||||
|
||||
// runLoop attempts to start the agent in a retry loop.
|
||||
// Coder may be offline temporarily, a connection issue
|
||||
// may be happening, but regardless after the intermittent
|
||||
@@ -1920,6 +1952,13 @@ func (a *agent) Close() error {
|
||||
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
|
||||
}
|
||||
}
|
||||
|
||||
if a.socketServer != nil {
|
||||
if err := a.socketServer.Stop(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "socket server close", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
a.setLifecycle(lifecycleState)
|
||||
|
||||
err = a.scriptRunner.Close()
|
||||
|
||||
1105
agent/agentsocket/proto/agentsocket.pb.go
Normal file
1105
agent/agentsocket/proto/agentsocket.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
88
agent/agentsocket/proto/agentsocket.proto
Normal file
88
agent/agentsocket/proto/agentsocket.proto
Normal file
@@ -0,0 +1,88 @@
|
||||
syntax = "proto3";
|
||||
option go_package = "github.com/coder/coder/v2/agent/agentsocket/proto";
|
||||
|
||||
package coder.agentsocket.v1;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
message PingRequest {}
|
||||
|
||||
message PingResponse {
|
||||
string message = 1;
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
}
|
||||
|
||||
message SyncStartRequest {
|
||||
string unit = 1;
|
||||
}
|
||||
|
||||
message SyncStartResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncWantRequest {
|
||||
string unit = 1;
|
||||
string depends_on = 2;
|
||||
}
|
||||
|
||||
message SyncWantResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncCompleteRequest {
|
||||
string unit = 1;
|
||||
}
|
||||
|
||||
message SyncCompleteResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncReadyRequest {
|
||||
string unit = 1;
|
||||
}
|
||||
|
||||
message SyncReadyResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncStatusRequest {
|
||||
string unit = 1;
|
||||
bool recursive = 2;
|
||||
}
|
||||
|
||||
message DependencyInfo {
|
||||
string depends_on = 1;
|
||||
string required_status = 2;
|
||||
string current_status = 3;
|
||||
bool is_satisfied = 4;
|
||||
}
|
||||
|
||||
message SyncStatusResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
string unit = 3;
|
||||
string status = 4;
|
||||
bool is_ready = 5;
|
||||
repeated DependencyInfo dependencies = 6;
|
||||
string dot = 7;
|
||||
}
|
||||
|
||||
// AgentSocket provides direct access to the agent over local IPC.
|
||||
service AgentSocket {
|
||||
// Ping the agent to check if it is alive.
|
||||
rpc Ping(PingRequest) returns (PingResponse);
|
||||
// Report the start of a unit.
|
||||
rpc SyncStart(SyncStartRequest) returns (SyncStartResponse);
|
||||
// Declare a dependency between units.
|
||||
rpc SyncWant(SyncWantRequest) returns (SyncWantResponse);
|
||||
// Report the completion of a unit.
|
||||
rpc SyncComplete(SyncCompleteRequest) returns (SyncCompleteResponse);
|
||||
// Request whether a unit is ready to be started. That is, all dependencies are satisfied.
|
||||
rpc SyncReady(SyncReadyRequest) returns (SyncReadyResponse);
|
||||
// Get the status of a unit and list its dependencies.
|
||||
rpc SyncStatus(SyncStatusRequest) returns (SyncStatusResponse);
|
||||
}
|
||||
311
agent/agentsocket/proto/agentsocket_drpc.pb.go
Normal file
311
agent/agentsocket/proto/agentsocket_drpc.pb.go
Normal file
@@ -0,0 +1,311 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.34
|
||||
// source: agent/agentsocket/proto/agentsocket.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
errors "errors"
|
||||
protojson "google.golang.org/protobuf/encoding/protojson"
|
||||
proto "google.golang.org/protobuf/proto"
|
||||
drpc "storj.io/drpc"
|
||||
drpcerr "storj.io/drpc/drpcerr"
|
||||
)
|
||||
|
||||
type drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto struct{}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Marshal(msg drpc.Message) ([]byte, error) {
|
||||
return proto.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
|
||||
return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Unmarshal(buf []byte, msg drpc.Message) error {
|
||||
return proto.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
|
||||
return protojson.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
|
||||
return protojson.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
type DRPCAgentSocketClient interface {
|
||||
DRPCConn() drpc.Conn
|
||||
|
||||
Ping(ctx context.Context, in *PingRequest) (*PingResponse, error)
|
||||
SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error)
|
||||
SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error)
|
||||
SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error)
|
||||
SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error)
|
||||
SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error)
|
||||
}
|
||||
|
||||
type drpcAgentSocketClient struct {
|
||||
cc drpc.Conn
|
||||
}
|
||||
|
||||
func NewDRPCAgentSocketClient(cc drpc.Conn) DRPCAgentSocketClient {
|
||||
return &drpcAgentSocketClient{cc}
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) DRPCConn() drpc.Conn { return c.cc }
|
||||
|
||||
func (c *drpcAgentSocketClient) Ping(ctx context.Context, in *PingRequest) (*PingResponse, error) {
|
||||
out := new(PingResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error) {
|
||||
out := new(SyncStartResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error) {
|
||||
out := new(SyncWantResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error) {
|
||||
out := new(SyncCompleteResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error) {
|
||||
out := new(SyncReadyResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error) {
|
||||
out := new(SyncStatusResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type DRPCAgentSocketServer interface {
|
||||
Ping(context.Context, *PingRequest) (*PingResponse, error)
|
||||
SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error)
|
||||
SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error)
|
||||
SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error)
|
||||
SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error)
|
||||
SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error)
|
||||
}
|
||||
|
||||
type DRPCAgentSocketUnimplementedServer struct{}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) Ping(context.Context, *PingRequest) (*PingResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
type DRPCAgentSocketDescription struct{}
|
||||
|
||||
func (DRPCAgentSocketDescription) NumMethods() int { return 6 }
|
||||
|
||||
func (DRPCAgentSocketDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
case 0:
|
||||
return "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
Ping(
|
||||
ctx,
|
||||
in1.(*PingRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.Ping, true
|
||||
case 1:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncStart(
|
||||
ctx,
|
||||
in1.(*SyncStartRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncStart, true
|
||||
case 2:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncWant(
|
||||
ctx,
|
||||
in1.(*SyncWantRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncWant, true
|
||||
case 3:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncComplete(
|
||||
ctx,
|
||||
in1.(*SyncCompleteRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncComplete, true
|
||||
case 4:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncReady(
|
||||
ctx,
|
||||
in1.(*SyncReadyRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncReady, true
|
||||
case 5:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncStatus(
|
||||
ctx,
|
||||
in1.(*SyncStatusRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncStatus, true
|
||||
default:
|
||||
return "", nil, nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func DRPCRegisterAgentSocket(mux drpc.Mux, impl DRPCAgentSocketServer) error {
|
||||
return mux.Register(impl, DRPCAgentSocketDescription{})
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_PingStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*PingResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_PingStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_PingStream) SendAndClose(m *PingResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncStartStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncStartResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncStartStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncStartStream) SendAndClose(m *SyncStartResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncWantStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncWantResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncWantStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncWantStream) SendAndClose(m *SyncWantResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncCompleteStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncCompleteResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncCompleteStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncCompleteStream) SendAndClose(m *SyncCompleteResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncReadyStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncReadyResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncReadyStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncReadyStream) SendAndClose(m *SyncReadyResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncStatusStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncStatusResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncStatusStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncStatusStream) SendAndClose(m *SyncStatusResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
17
agent/agentsocket/proto/version.go
Normal file
17
agent/agentsocket/proto/version.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package proto
|
||||
|
||||
import "github.com/coder/coder/v2/apiversion"
|
||||
|
||||
// Version history:
|
||||
//
|
||||
// API v1.0:
|
||||
// - Initial release
|
||||
// - Ping
|
||||
// - Sync operations: SyncStart, SyncWant, SyncComplete, SyncWait, SyncStatus
|
||||
|
||||
const (
|
||||
CurrentMajor = 1
|
||||
CurrentMinor = 0
|
||||
)
|
||||
|
||||
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)
|
||||
185
agent/agentsocket/server.go
Normal file
185
agent/agentsocket/server.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket/proto"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
)
|
||||
|
||||
// Server provides access to the DRPCAgentSocketService via a Unix domain socket.
|
||||
// Do not invoke Server{} directly. Use NewServer() instead.
|
||||
type Server struct {
|
||||
logger slog.Logger
|
||||
path string
|
||||
listener net.Listener
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
drpcServer *drpcserver.Server
|
||||
service *DRPCAgentSocketService
|
||||
}
|
||||
|
||||
func NewServer(path string, logger slog.Logger) (*Server, error) {
|
||||
logger = logger.Named("agentsocket")
|
||||
server := &Server{
|
||||
logger: logger,
|
||||
path: path,
|
||||
service: &DRPCAgentSocketService{
|
||||
logger: logger,
|
||||
unitManager: unit.NewManager[string, string](),
|
||||
},
|
||||
}
|
||||
|
||||
mux := drpcmux.New()
|
||||
err := proto.DRPCRegisterAgentSocket(mux, server.service)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to register drpc service: %w", err)
|
||||
}
|
||||
|
||||
server.drpcServer = drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
Log: func(err error) {
|
||||
if errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
|
||||
},
|
||||
})
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
var ErrServerAlreadyStarted = xerrors.New("server already started")
|
||||
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener != nil {
|
||||
return ErrServerAlreadyStarted
|
||||
}
|
||||
|
||||
// This context is canceled by s.Stop() when the server is stopped.
|
||||
// canceling it will close all connections.
|
||||
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||
|
||||
if s.path == "" {
|
||||
var err error
|
||||
s.path, err = getDefaultSocketPath()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get default socket path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
listener, err := createSocket(s.path)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create socket: %w", err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
|
||||
s.logger.Info(s.ctx, "agent socket server started", slog.F("path", s.path))
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.acceptConnections()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Info(s.ctx, "stopping agent socket server")
|
||||
|
||||
s.cancel()
|
||||
|
||||
if err := s.listener.Close(); err != nil {
|
||||
s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err))
|
||||
}
|
||||
|
||||
// Wait for all connections to finish
|
||||
s.wg.Wait()
|
||||
|
||||
if err := cleanupSocket(s.path); err != nil {
|
||||
s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err))
|
||||
}
|
||||
|
||||
s.listener = nil
|
||||
s.logger.Info(s.ctx, "agent socket server stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) acceptConnections() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.handleConnection(conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
s.logger.Warn(s.ctx, "failed to set connection deadline", slog.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Debug(s.ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr()))
|
||||
|
||||
config := yamux.DefaultConfig()
|
||||
config.Logger = nil
|
||||
session, err := yamux.Server(conn, config)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "failed to create yamux session", slog.Error(err))
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
err = s.drpcServer.Serve(s.ctx, session)
|
||||
if err != nil {
|
||||
s.logger.Debug(s.ctx, "drpc server finished", slog.Error(err))
|
||||
}
|
||||
}
|
||||
48
agent/agentsocket/server_test.go
Normal file
48
agent/agentsocket/server_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package agentsocket_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
)
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("StartStop", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(socketPath, logger)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Start())
|
||||
require.NoError(t, server.Stop())
|
||||
})
|
||||
|
||||
t.Run("AlreadyStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(socketPath, logger)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Start())
|
||||
require.ErrorIs(t, server.Start(), agentsocket.ErrServerAlreadyStarted)
|
||||
})
|
||||
|
||||
t.Run("AutoSocketPath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(socketPath, logger)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Start())
|
||||
require.NoError(t, server.Stop())
|
||||
})
|
||||
}
|
||||
262
agent/agentsocket/service.go
Normal file
262
agent/agentsocket/service.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket/proto"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
)
|
||||
|
||||
var _ proto.DRPCAgentSocketServer = (*DRPCAgentSocketService)(nil)
|
||||
|
||||
type DRPCAgentSocketService struct {
|
||||
mu sync.RWMutex
|
||||
unitManager *unit.Manager[string, string]
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
func (*DRPCAgentSocketService) Ping(_ context.Context, _ *proto.PingRequest) (*proto.PingResponse, error) {
|
||||
return &proto.PingResponse{
|
||||
Message: "pong",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncStart(_ context.Context, req *proto.SyncStartRequest) (*proto.SyncStartResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "dependency tracker not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.Register(req.Unit); err != nil {
|
||||
// If already registered, that's okay - we can still update status
|
||||
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Failed to register unit: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
isReady, err := s.unitManager.IsReady(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Failed to check readiness: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
if !isReady {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Unit is not ready",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusStarted); err != nil {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Failed to update status: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncStartResponse{
|
||||
Success: true,
|
||||
Message: "Unit " + req.Unit + " started successfully",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncWant(_ context.Context, req *proto.SyncWantRequest) (*proto.SyncWantResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" || req.DependsOn == "" {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "unit and depends_on are required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.Register(req.Unit); err != nil {
|
||||
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "failed to register unit: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.unitManager.Register(req.DependsOn); err != nil {
|
||||
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "failed to register dependency unit: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.unitManager.AddDependency(req.Unit, req.DependsOn, unit.StatusComplete); err != nil {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "failed to add dependency: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncWantResponse{
|
||||
Success: true,
|
||||
Message: "Unit " + req.Unit + " now depends on " + req.DependsOn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncComplete(_ context.Context, req *proto.SyncCompleteRequest) (*proto.SyncCompleteResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: false,
|
||||
Message: "unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusComplete); err != nil {
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: false,
|
||||
Message: "failed to update status: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: true,
|
||||
Message: "unit " + req.Unit + " completed successfully",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncReady(_ context.Context, req *proto.SyncReadyRequest) (*proto.SyncReadyResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: "unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
isReady, err := s.unitManager.IsReady(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: "failed to check readiness: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if !isReady {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: unit.ErrDependenciesNotSatisfied.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: true,
|
||||
Message: "unit " + req.Unit + " dependencies are satisfied",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncStatus(_ context.Context, req *proto.SyncStatusRequest) (*proto.SyncStatusResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
status, err := s.unitManager.GetStatus(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to get unit status: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
isReady, err := s.unitManager.IsReady(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to check readiness: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
dependencies, err := s.unitManager.GetAllDependencies(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to get dependencies: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
var depInfos []*proto.DependencyInfo
|
||||
for _, dep := range dependencies {
|
||||
depInfos = append(depInfos, &proto.DependencyInfo{
|
||||
DependsOn: dep.DependsOn,
|
||||
RequiredStatus: dep.RequiredStatus,
|
||||
CurrentStatus: dep.CurrentStatus,
|
||||
IsSatisfied: dep.IsSatisfied,
|
||||
})
|
||||
}
|
||||
|
||||
var dotStr string
|
||||
if req.Recursive {
|
||||
dotStr, err = s.unitManager.ExportDOT("dependency_graph")
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to export DOT: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: true,
|
||||
Message: "unit status retrieved successfully",
|
||||
Unit: req.Unit,
|
||||
Status: status,
|
||||
IsReady: isReady,
|
||||
Dependencies: depInfos,
|
||||
Dot: dotStr,
|
||||
}, nil
|
||||
}
|
||||
311
agent/agentsocket/service_test.go
Normal file
311
agent/agentsocket/service_test.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package agentsocket_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
response, err := client.Ping(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "pong", response.Message)
|
||||
})
|
||||
|
||||
t.Run("SyncStart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NewUnit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
})
|
||||
|
||||
t.Run("UnitAlreadyStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
// First Start
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Second Start
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.ErrorContains(t, err, unit.ErrSameStatusAlreadySet.Error())
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
})
|
||||
|
||||
t.Run("UnitAlreadyCompleted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// First start
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Complete the unit
|
||||
err = client.SyncComplete(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", status.Status)
|
||||
|
||||
// Second start
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
})
|
||||
|
||||
t.Run("UnitNotReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.ErrorContains(t, err, "Unit is not ready")
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", status.Status)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("SyncWant", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NewUnits", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// If units are not registered, they are registered automatically
|
||||
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
|
||||
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
|
||||
})
|
||||
|
||||
t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Start the dependency unit
|
||||
err = client.SyncStart(context.Background(), "dependency-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "dependency-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Add the dependency after the dependency unit has already started
|
||||
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
|
||||
// Dependencies can be added even if the dependency unit has already started
|
||||
require.NoError(t, err)
|
||||
|
||||
// The dependency is now reflected in the test unit's status
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
|
||||
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
|
||||
})
|
||||
|
||||
t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Start the dependent unit
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Add the dependency after the dependency unit has already started
|
||||
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
|
||||
// Dependencies can be added even if the dependent unit has already started.
|
||||
// The dependency applies the next time a unit is started. The current status is not updated.
|
||||
// This is to allow flexible dependency management. It does mean that users of this API should
|
||||
// take care to add dependencies before they start their dependent units.
|
||||
require.NoError(t, err)
|
||||
|
||||
// The dependency is now reflected in the test unit's status
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
|
||||
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
|
||||
})
|
||||
})
|
||||
}
|
||||
76
agent/agentsocket/socket_unix.go
Normal file
76
agent/agentsocket/socket_unix.go
Normal file
@@ -0,0 +1,76 @@
|
||||
//go:build !windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// createSocket creates a Unix domain socket listener
|
||||
func createSocket(path string) (net.Listener, error) {
|
||||
if !isSocketAvailable(path) {
|
||||
return nil, xerrors.Errorf("socket path %s is not available", path)
|
||||
}
|
||||
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return nil, xerrors.Errorf("remove existing socket: %w", err)
|
||||
}
|
||||
|
||||
// Create parent directory if it doesn't exist
|
||||
parentDir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(parentDir, 0o700); err != nil {
|
||||
return nil, xerrors.Errorf("create socket directory: %w", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen on unix socket: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(path, 0o600); err != nil {
|
||||
_ = listener.Close()
|
||||
return nil, xerrors.Errorf("set socket permissions: %w", err)
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// getDefaultSocketPath returns the default socket path for Unix-like systems
|
||||
func getDefaultSocketPath() (string, error) {
|
||||
// Try XDG_RUNTIME_DIR first
|
||||
if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
|
||||
return filepath.Join(runtimeDir, "coder-agent.sock"), nil
|
||||
}
|
||||
|
||||
// Fall back to /tmp with user-specific path
|
||||
uid := os.Getuid()
|
||||
return filepath.Join("/tmp", fmt.Sprintf("coder-agent-%d.sock", uid)), nil
|
||||
}
|
||||
|
||||
// CleanupSocket removes the socket file
|
||||
func cleanupSocket(path string) error {
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// isSocketAvailable checks if a socket path is available for use
|
||||
func isSocketAvailable(path string) bool {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try to connect to see if it's actually listening
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
// If we can't connect, the socket is not in use
|
||||
// Socket is available for use
|
||||
return true
|
||||
}
|
||||
_ = conn.Close()
|
||||
// Socket is in use
|
||||
return false
|
||||
}
|
||||
111
agent/agentsocket/socket_windows.go
Normal file
111
agent/agentsocket/socket_windows.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// createSocket creates a Unix domain socket listener on Windows
|
||||
// Falls back to named pipe if Unix sockets are not supported
|
||||
func CreateSocket(path string) (net.Listener, error) {
|
||||
// Try Unix domain socket first (Windows 10 build 17063+)
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// Fall back to named pipe
|
||||
pipePath := `\\.\pipe\coder-agent`
|
||||
listener, err = net.Listen("tcp", pipePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// getDefaultSocketPath returns the default socket path for Windows
|
||||
func GetDefaultSocketPath() (string, error) {
|
||||
// Try to use a temporary directory
|
||||
tempDir := os.TempDir()
|
||||
if tempDir == "" {
|
||||
tempDir = "C:\\temp"
|
||||
}
|
||||
|
||||
// Create a user-specific subdirectory
|
||||
uid := os.Getuid()
|
||||
userDir := filepath.Join(tempDir, "coder-agent", strconv.Itoa(uid))
|
||||
|
||||
if err := os.MkdirAll(userDir, 0o700); err != nil {
|
||||
return "", fmt.Errorf("create user directory: %w", err)
|
||||
}
|
||||
|
||||
return filepath.Join(userDir, "agent.sock"), nil
|
||||
}
|
||||
|
||||
// cleanupSocket removes the socket file
|
||||
func CleanupSocket(path string) error {
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// isSocketAvailable checks if a socket path is available for use
|
||||
func IsSocketAvailable(path string, logger slog.Logger) bool {
|
||||
logger.Debug(context.Background(), "Checking socket availability on Windows", slog.F("path", path))
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
logger.Debug(context.Background(), "Socket file does not exist, path is available", slog.F("path", path))
|
||||
return true
|
||||
}
|
||||
logger.Debug(context.Background(), "Socket file exists, checking if it's listening", slog.F("path", path))
|
||||
|
||||
// Try to connect to see if it's actually listening
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
// If we can't connect, the socket is not in use
|
||||
logger.Debug(context.Background(), "Cannot connect to socket, path is available", slog.F("path", path), slog.Error(err))
|
||||
return true
|
||||
}
|
||||
_ = conn.Close()
|
||||
logger.Debug(context.Background(), "Socket is listening, path is not available", slog.F("path", path))
|
||||
return false
|
||||
}
|
||||
|
||||
// getSocketInfo returns information about the socket file
|
||||
func GetSocketInfo(path string) (*SocketInfo, error) {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// On Windows, we'll use a simplified approach for now
|
||||
// In a real implementation, you'd get the security descriptor
|
||||
return &SocketInfo{
|
||||
Path: path,
|
||||
UID: 0, // Simplified for now
|
||||
GID: 0, // Simplified for now
|
||||
Mode: stat.Mode(),
|
||||
ModTime: stat.ModTime(),
|
||||
Owner: "unknown",
|
||||
Group: "unknown",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SocketInfo contains information about a socket file
|
||||
type SocketInfo struct {
|
||||
Path string
|
||||
UID int
|
||||
GID int
|
||||
Mode os.FileMode
|
||||
ModTime time.Time
|
||||
Owner string // Windows SID string
|
||||
Group string // Windows SID string
|
||||
}
|
||||
307
agent/unit/manager.go
Normal file
307
agent/unit/manager.go
Normal file
@@ -0,0 +1,307 @@
|
||||
package unit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ErrConsumerNotFound is returned when a consumer ID is not registered.
|
||||
var ErrConsumerNotFound = xerrors.New("consumer not found")
|
||||
|
||||
// ErrConsumerAlreadyRegistered is returned when a consumer ID is already registered.
|
||||
var ErrConsumerAlreadyRegistered = xerrors.New("consumer already registered")
|
||||
|
||||
// ErrCannotUpdateOtherConsumer is returned when attempting to update another consumer's status.
|
||||
var ErrCannotUpdateOtherConsumer = xerrors.New("cannot update other consumer's status")
|
||||
|
||||
// ErrDependenciesNotSatisfied is returned when a consumer's dependencies are not satisfied.
|
||||
var ErrDependenciesNotSatisfied = xerrors.New("unit dependencies not satisfied")
|
||||
|
||||
// ErrSameStatusAlreadySet is returned when attempting to set the same status as the current status.
|
||||
var ErrSameStatusAlreadySet = xerrors.New("same status already set")
|
||||
|
||||
// Status constants for dependency tracking
|
||||
const (
|
||||
StatusStarted = "started"
|
||||
StatusComplete = "completed"
|
||||
)
|
||||
|
||||
// dependencyVertex represents a vertex in the dependency graph that is associated with a consumer.
|
||||
type dependencyVertex[ConsumerID comparable] struct {
|
||||
ID ConsumerID
|
||||
}
|
||||
|
||||
// Dependency represents a dependency relationship between consumers.
|
||||
type Dependency[StatusType, ConsumerID comparable] struct {
|
||||
Consumer ConsumerID
|
||||
DependsOn ConsumerID
|
||||
RequiredStatus StatusType
|
||||
CurrentStatus StatusType
|
||||
IsSatisfied bool
|
||||
}
|
||||
|
||||
// Manager provides reactive dependency tracking over a Graph.
|
||||
// It manages consumer registration, dependency relationships, and status updates
|
||||
// with automatic recalculation of readiness when dependencies are satisfied.
|
||||
type Manager[StatusType, ConsumerID comparable] struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// The underlying graph that stores dependency relationships
|
||||
graph *Graph[StatusType, *dependencyVertex[ConsumerID]]
|
||||
|
||||
// Track current status of each consumer
|
||||
consumerStatus map[ConsumerID]StatusType
|
||||
|
||||
// Track readiness state (cached to avoid repeated graph traversal)
|
||||
consumerReadiness map[ConsumerID]bool
|
||||
|
||||
// Track which consumers are registered
|
||||
registeredConsumers map[ConsumerID]bool
|
||||
|
||||
// Store vertex instances for each consumer to ensure consistent references
|
||||
consumerVertices map[ConsumerID]*dependencyVertex[ConsumerID]
|
||||
}
|
||||
|
||||
// NewManager creates a new Manager instance.
|
||||
func NewManager[StatusType, ConsumerID comparable]() *Manager[StatusType, ConsumerID] {
|
||||
return &Manager[StatusType, ConsumerID]{
|
||||
graph: &Graph[StatusType, *dependencyVertex[ConsumerID]]{},
|
||||
consumerStatus: make(map[ConsumerID]StatusType),
|
||||
consumerReadiness: make(map[ConsumerID]bool),
|
||||
registeredConsumers: make(map[ConsumerID]bool),
|
||||
consumerVertices: make(map[ConsumerID]*dependencyVertex[ConsumerID]),
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a new consumer as a vertex in the dependency graph.
|
||||
func (dt *Manager[StatusType, ConsumerID]) Register(id ConsumerID) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if dt.registeredConsumers[id] {
|
||||
return ErrConsumerAlreadyRegistered
|
||||
}
|
||||
|
||||
// Create and store the vertex for this consumer
|
||||
vertex := &dependencyVertex[ConsumerID]{ID: id}
|
||||
dt.consumerVertices[id] = vertex
|
||||
dt.registeredConsumers[id] = true
|
||||
dt.consumerReadiness[id] = true // New consumers start as ready (no dependencies)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDependency adds a dependency relationship between consumers.
|
||||
// The consumer depends on the dependsOn consumer reaching the requiredStatus.
|
||||
func (dt *Manager[StatusType, ConsumerID]) AddDependency(consumer ConsumerID, dependsOn ConsumerID, requiredStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return xerrors.Errorf("consumer %v is not registered", consumer)
|
||||
}
|
||||
if !dt.registeredConsumers[dependsOn] {
|
||||
return xerrors.Errorf("consumer %v is not registered", dependsOn)
|
||||
}
|
||||
|
||||
// Get the stored vertices for both consumers
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependsOnVertex := dt.consumerVertices[dependsOn]
|
||||
|
||||
// Add the dependency edge to the graph
|
||||
// The edge goes from consumer to dependsOn, representing the dependency
|
||||
err := dt.graph.AddEdge(consumerVertex, dependsOnVertex, requiredStatus)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to add dependency: %w", err)
|
||||
}
|
||||
|
||||
// Recalculate readiness for the consumer since it now has a dependency
|
||||
dt.recalculateReadinessUnsafe(consumer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates a consumer's status and recalculates readiness for affected dependents.
|
||||
func (dt *Manager[StatusType, ConsumerID]) UpdateStatus(consumer ConsumerID, newStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return ErrConsumerNotFound
|
||||
}
|
||||
|
||||
// Update the consumer's status
|
||||
if dt.consumerStatus[consumer] == newStatus {
|
||||
return ErrSameStatusAlreadySet
|
||||
}
|
||||
dt.consumerStatus[consumer] = newStatus
|
||||
|
||||
// Get all consumers that depend on this one (reverse adjacent vertices)
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependentEdges := dt.graph.GetReverseAdjacentVertices(consumerVertex)
|
||||
|
||||
// Recalculate readiness for all dependents
|
||||
for _, edge := range dependentEdges {
|
||||
dt.recalculateReadinessUnsafe(edge.From.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReady checks if all dependencies for a consumer are satisfied.
|
||||
func (dt *Manager[StatusType, ConsumerID]) IsReady(consumer ConsumerID) (bool, error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return false, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
return dt.consumerReadiness[consumer], nil
|
||||
}
|
||||
|
||||
// GetUnmetDependencies returns a list of unsatisfied dependencies for a consumer.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetUnmetDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return nil, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
var unmetDependencies []Dependency[StatusType, ConsumerID]
|
||||
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists {
|
||||
// If the dependency consumer has no status, it's not satisfied
|
||||
var zeroStatus StatusType
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: zeroStatus, // Zero value
|
||||
IsSatisfied: false,
|
||||
})
|
||||
} else {
|
||||
isSatisfied := currentStatus == requiredStatus
|
||||
if !isSatisfied {
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: currentStatus,
|
||||
IsSatisfied: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return unmetDependencies, nil
|
||||
}
|
||||
|
||||
// recalculateReadinessUnsafe recalculates the readiness state for a consumer.
|
||||
// This method assumes the caller holds the write lock.
|
||||
func (dt *Manager[StatusType, ConsumerID]) recalculateReadinessUnsafe(consumer ConsumerID) {
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
// If there are no dependencies, the consumer is ready
|
||||
if len(forwardEdges) == 0 {
|
||||
dt.consumerReadiness[consumer] = true
|
||||
return
|
||||
}
|
||||
|
||||
// Check if all dependencies are satisfied
|
||||
allSatisfied := true
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists || currentStatus != requiredStatus {
|
||||
allSatisfied = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
dt.consumerReadiness[consumer] = allSatisfied
|
||||
}
|
||||
|
||||
// GetGraph returns the underlying graph for visualization and debugging.
|
||||
// This should be used carefully as it exposes the internal graph structure.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetGraph() *Graph[StatusType, *dependencyVertex[ConsumerID]] {
|
||||
return dt.graph
|
||||
}
|
||||
|
||||
// GetStatus returns the current status of a consumer.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetStatus(consumer ConsumerID) (StatusType, error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
var zeroStatus StatusType
|
||||
return zeroStatus, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
status, exists := dt.consumerStatus[consumer]
|
||||
if !exists {
|
||||
var zeroStatus StatusType
|
||||
return zeroStatus, nil
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// GetAllDependencies returns all dependencies for a consumer, both satisfied and unsatisfied.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetAllDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return nil, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
var allDependencies []Dependency[StatusType, ConsumerID]
|
||||
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists {
|
||||
// If the dependency consumer has no status, it's not satisfied
|
||||
var zeroStatus StatusType
|
||||
allDependencies = append(allDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: zeroStatus, // Zero value
|
||||
IsSatisfied: false,
|
||||
})
|
||||
} else {
|
||||
isSatisfied := currentStatus == requiredStatus
|
||||
allDependencies = append(allDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: currentStatus,
|
||||
IsSatisfied: isSatisfied,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return allDependencies, nil
|
||||
}
|
||||
|
||||
// ExportDOT exports the dependency graph to DOT format for visualization.
|
||||
func (dt *Manager[StatusType, ConsumerID]) ExportDOT(name string) (string, error) {
|
||||
return dt.graph.ToDOT(name)
|
||||
}
|
||||
691
agent/unit/manager_test.go
Normal file
691
agent/unit/manager_test.go
Normal file
@@ -0,0 +1,691 @@
|
||||
package unit_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
)
|
||||
|
||||
type testStatus string
|
||||
|
||||
const (
|
||||
statusStarted testStatus = "started"
|
||||
statusRunning testStatus = "running"
|
||||
statusCompleted testStatus = "completed"
|
||||
)
|
||||
|
||||
type testConsumerID string
|
||||
|
||||
const (
|
||||
consumerA testConsumerID = "serviceA"
|
||||
consumerB testConsumerID = "serviceB"
|
||||
consumerC testConsumerID = "serviceC"
|
||||
consumerD testConsumerID = "serviceD"
|
||||
)
|
||||
|
||||
func TestDependencyTracker_Register(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
t.Run("RegisterNewConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer should be ready initially (no dependencies)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("RegisterDuplicateConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tracker.Register(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already registered")
|
||||
})
|
||||
|
||||
t.Run("RegisterMultipleConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// All should be ready initially
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_AddDependency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AddDependencyBetweenRegisteredConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A should no longer be ready (depends on B)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// B should still be ready (no dependencies)
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("AddDependencyWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add dependency to unregistered consumer
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
|
||||
t.Run("AddDependencyFromUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add dependency from unregistered consumer
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_UpdateStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpdateStatusTriggersReadinessRecalculation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially A is not ready
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should become ready
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
err := tracker.UpdateStatus(consumerA, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("LinearChainDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create chain: A depends on B being "started", B depends on C being "completed"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially only C is ready
|
||||
ready, err := tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "completed" - B should become ready
|
||||
err = tracker.UpdateStatus(consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "started" - A should become ready
|
||||
err = tracker.UpdateStatus(consumerB, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_GetUnmetDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithNoDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithUnsatisfiedDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, unmet, 1)
|
||||
|
||||
assert.Equal(t, consumerA, unmet[0].Consumer)
|
||||
assert.Equal(t, consumerB, unmet[0].DependsOn)
|
||||
assert.Equal(t, statusRunning, unmet[0].RequiredStatus)
|
||||
assert.False(t, unmet[0].IsSatisfied)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithSatisfiedDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update B to "running"
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.Nil(t, unmet)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ConcurrentOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConcurrentStatusUpdates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create dependencies: A depends on B, B depends on C, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 10
|
||||
|
||||
// Launch goroutines that update statuses
|
||||
errors := make([]error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Update D to completed (should make C ready)
|
||||
err := tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update C to started (should make B ready)
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update B to running (should make A ready)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for any errors in goroutines
|
||||
for i, err := range errors {
|
||||
require.NoError(t, err, "goroutine %d had error", i)
|
||||
}
|
||||
|
||||
// All consumers should be ready after the updates
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcurrentReadinessChecks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 20
|
||||
|
||||
// Launch goroutines that check readiness
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Check readiness multiple times
|
||||
for j := 0; j < 10; j++ {
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
// Initially should be false, then true after B is updated
|
||||
_ = ready
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
// B should always be ready (no dependencies)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Update B to "running" in the middle of readiness checks
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_MultipleDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConsumerWithMultipleDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// A depends on B being "running" AND C being "started"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A should not be ready (depends on both B and C)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C too)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("ComplexDependencyChain", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph:
|
||||
// A depends on B being "running" AND C being "started"
|
||||
// B depends on D being "completed"
|
||||
// C depends on D being "completed"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially only D is ready
|
||||
ready, err := tracker.IsReady(consumerD)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update D to "completed" - B and C should become ready
|
||||
err = tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("DifferentStatusTypes", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerC)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running" AND C being "completed"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update B to "running" but not C - A should not be ready
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "completed" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ErrorCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
err := tracker.UpdateStatus(consumerA, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("IsReadyWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.False(t, ready)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.Nil(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("AddDependencyWithUnregisteredConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Try to add dependency with unregistered consumers
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
|
||||
t.Run("CyclicDependencyDetection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to make B depend on A (creates cycle)
|
||||
err = tracker.AddDependency(consumerB, consumerA, statusStarted)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "would create a cycle")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ToDOT(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ExportSimpleGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add dependency
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("test")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
|
||||
t.Run("ExportComplexGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph
|
||||
// A depends on B and C, B depends on D, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("complex")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
}
|
||||
@@ -56,6 +56,7 @@ func workspaceAgent() *serpent.Command {
|
||||
devcontainers bool
|
||||
devcontainerProjectDiscovery bool
|
||||
devcontainerDiscoveryAutostart bool
|
||||
socketPath string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
@@ -297,6 +298,7 @@ func workspaceAgent() *serpent.Command {
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart),
|
||||
},
|
||||
SocketPath: socketPath,
|
||||
})
|
||||
|
||||
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
|
||||
@@ -449,6 +451,12 @@ func workspaceAgent() *serpent.Command {
|
||||
Description: "Allow the agent to autostart devcontainer projects it discovers based on their configuration.",
|
||||
Value: serpent.BoolOf(&devcontainerDiscoveryAutostart),
|
||||
},
|
||||
{
|
||||
Flag: "socket-path",
|
||||
Env: "CODER_AGENT_SOCKET_PATH",
|
||||
Description: "Specify the path for the agent socket.",
|
||||
Value: serpent.StringOf(&socketPath),
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
|
||||
@@ -144,6 +144,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
r.syncCommand(),
|
||||
r.tasksCommand(),
|
||||
r.boundary(),
|
||||
}
|
||||
|
||||
25
cli/sync.go
Normal file
25
cli/sync.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncCommand() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "sync",
|
||||
Short: "Synchronize with the local agent socket",
|
||||
Long: "Commands for interacting with the local Coder agent via socket communication.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.syncPing(),
|
||||
r.syncStart(),
|
||||
r.syncWant(),
|
||||
r.syncComplete(),
|
||||
r.syncWait(),
|
||||
r.syncStatus(),
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
50
cli/sync_complete.go
Normal file
50
cli/sync_complete.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncComplete() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "complete <unit>",
|
||||
Short: "Mark a unit as complete in the dependency graph",
|
||||
Long: "Set a unit's status to complete in the dependency graph.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unit := i.Args[0]
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Completing unit '%s'...\n", unit)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Complete the unit
|
||||
if err := client.SyncComplete(ctx, unit); err != nil {
|
||||
return xerrors.Errorf("complete unit failed: %w", err)
|
||||
}
|
||||
|
||||
// Display success message
|
||||
fmt.Printf("Unit '%s' completed successfully\n", unit)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
53
cli/sync_ping.go
Normal file
53
cli/sync_ping.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncPing() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "ping",
|
||||
Short: "Ping the local agent socket",
|
||||
Long: "Test connectivity to the local Coder agent via socket communication.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Show initial message
|
||||
fmt.Println("Pinging agent socket...")
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Measure round-trip time
|
||||
start := time.Now()
|
||||
resp, err := client.Ping(ctx)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
return xerrors.Errorf("ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Display results
|
||||
fmt.Printf("Response: %s\n", resp.Message)
|
||||
fmt.Printf("Timestamp: %s\n", resp.Timestamp.Format(time.RFC3339))
|
||||
fmt.Printf("Round-trip time: %s\n", duration.Round(time.Microsecond))
|
||||
fmt.Println("Status: healthy")
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
122
cli/sync_start.go
Normal file
122
cli/sync_start.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// SyncPollInterval is the interval between dependency checks for sync start
|
||||
SyncPollInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncStart() *serpent.Command {
|
||||
var timeout time.Duration
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "start <unit>",
|
||||
Short: "Start a unit in the dependency graph",
|
||||
Long: "Register a unit in the dependency graph and set its status to started. Waits for all dependencies to be satisfied before marking as started.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unitName := i.Args[0]
|
||||
|
||||
// Set up context with timeout if specified
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Starting unit '%s'...\n", unitName)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Check if dependencies are satisfied first
|
||||
err = client.SyncReady(ctx, unitName)
|
||||
if err != nil {
|
||||
// Check if it's a "not ready" error (expected if dependencies exist)
|
||||
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
|
||||
// Dependencies exist but aren't satisfied, start polling
|
||||
fmt.Printf("Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
|
||||
|
||||
// Poll until dependencies are satisfied
|
||||
ticker := time.NewTicker(SyncPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
pollLoop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
// Check if dependencies are satisfied
|
||||
err := client.SyncReady(ctx, unitName)
|
||||
if err == nil {
|
||||
// Dependencies are satisfied
|
||||
fmt.Printf("Dependencies satisfied, marking unit '%s' as started\n", unitName)
|
||||
break pollLoop
|
||||
}
|
||||
|
||||
// Check if it's still a "not ready" error (expected while waiting)
|
||||
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
|
||||
// Still waiting, continue polling
|
||||
continue
|
||||
}
|
||||
|
||||
// Some other error occurred
|
||||
return xerrors.Errorf("error checking dependencies: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Some other error occurred
|
||||
return xerrors.Errorf("error checking dependencies: %w", err)
|
||||
}
|
||||
} else {
|
||||
// No dependencies or already satisfied
|
||||
fmt.Printf("Dependencies satisfied, marking unit '%s' as started\n", unitName)
|
||||
}
|
||||
|
||||
// Start the unit
|
||||
if err := client.SyncStart(ctx, unitName); err != nil {
|
||||
return xerrors.Errorf("start unit failed: %w", err)
|
||||
}
|
||||
|
||||
// Display success message
|
||||
fmt.Printf("Unit '%s' started successfully\n", unitName)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = append(cmd.Options, serpent.Option{
|
||||
Flag: "timeout",
|
||||
Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
|
||||
Value: serpent.DurationOf(&timeout),
|
||||
})
|
||||
|
||||
return cmd
|
||||
}
|
||||
134
cli/sync_status.go
Normal file
134
cli/sync_status.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
type outputFormat string
|
||||
|
||||
const (
|
||||
outputFormatHuman outputFormat = "human"
|
||||
outputFormatJSON outputFormat = "json"
|
||||
outputFormatDOT outputFormat = "dot"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncStatus() *serpent.Command {
|
||||
var (
|
||||
output string
|
||||
recursive bool
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "status <unit>",
|
||||
Short: "Show the status of a unit and its dependencies",
|
||||
Long: "Display the current status of a unit and information about its dependencies. Supports multiple output formats.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unit := i.Args[0]
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Get status information
|
||||
statusResp, err := client.SyncStatus(ctx, unit, recursive)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get status failed: %w", err)
|
||||
}
|
||||
|
||||
// Output based on format
|
||||
switch outputFormat(output) {
|
||||
case outputFormatJSON:
|
||||
return outputJSON(statusResp)
|
||||
case outputFormatDOT:
|
||||
return outputDOT(statusResp)
|
||||
default: // outputFormatHuman
|
||||
return outputHuman(statusResp)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = append(cmd.Options,
|
||||
serpent.Option{
|
||||
Flag: "output",
|
||||
FlagShorthand: "o",
|
||||
Description: "Output format: human, json, or dot.",
|
||||
Value: serpent.EnumOf(&output, "human", "json", "dot"),
|
||||
},
|
||||
serpent.Option{
|
||||
Flag: "recursive",
|
||||
FlagShorthand: "r",
|
||||
Description: "Show transitive dependencies and include DOT graph.",
|
||||
Value: serpent.BoolOf(&recursive),
|
||||
},
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func outputJSON(statusResp *agentsdk.SyncStatusResponse) error {
|
||||
encoder := json.NewEncoder(os.Stdout)
|
||||
encoder.SetIndent("", " ")
|
||||
return encoder.Encode(statusResp)
|
||||
}
|
||||
|
||||
func outputDOT(statusResp *agentsdk.SyncStatusResponse) error {
|
||||
if statusResp.DOT == "" {
|
||||
return xerrors.New("DOT output requires --recursive flag")
|
||||
}
|
||||
fmt.Println(statusResp.DOT)
|
||||
return nil
|
||||
}
|
||||
|
||||
func outputHuman(statusResp *agentsdk.SyncStatusResponse) error {
|
||||
// Unit status
|
||||
fmt.Printf("Unit: %s\n", statusResp.Unit)
|
||||
fmt.Printf("Status: %s\n", statusResp.Status)
|
||||
fmt.Printf("Ready: %t\n", statusResp.IsReady)
|
||||
fmt.Println()
|
||||
|
||||
// Dependencies
|
||||
if len(statusResp.Dependencies) == 0 {
|
||||
fmt.Println("No dependencies")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println("Dependencies:")
|
||||
fmt.Println(strings.Repeat("-", 80))
|
||||
fmt.Printf("%-20s %-15s %-15s %-10s\n", "Depends On", "Required", "Current", "Satisfied")
|
||||
fmt.Println(strings.Repeat("-", 80))
|
||||
|
||||
for _, dep := range statusResp.Dependencies {
|
||||
satisfied := "✓"
|
||||
if !dep.IsSatisfied {
|
||||
satisfied = "✗"
|
||||
}
|
||||
fmt.Printf("%-20s %-15s %-15s %-10s\n",
|
||||
dep.DependsOn,
|
||||
dep.RequiredStatus,
|
||||
dep.CurrentStatus,
|
||||
satisfied,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
359
cli/sync_test.go
Normal file
359
cli/sync_test.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
)
|
||||
|
||||
// mockAgentSocketServer simulates the agent socket server for testing
|
||||
type mockAgentSocketServer struct {
|
||||
listener net.Listener
|
||||
handlers map[string]func(string) (string, error)
|
||||
}
|
||||
|
||||
func newMockAgentSocketServer(t *testing.T, socketPath string) *mockAgentSocketServer {
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := &mockAgentSocketServer{
|
||||
listener: listener,
|
||||
handlers: make(map[string]func(string) (string, error)),
|
||||
}
|
||||
|
||||
// Set up default handlers
|
||||
server.handlers["sync.wait"] = func(unitName string) (string, error) {
|
||||
// Always return dependencies not satisfied to trigger polling
|
||||
return "", unit.ErrDependenciesNotSatisfied
|
||||
}
|
||||
|
||||
server.handlers["sync.start"] = func(unitName string) (string, error) {
|
||||
return "Unit " + unitName + " started successfully", nil
|
||||
}
|
||||
|
||||
go server.serve(t)
|
||||
return server
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) serve(t *testing.T) {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
t.Logf("Accept error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go s.handleConnection(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) handleConnection(t *testing.T, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// Simple JSON-RPC-like protocol simulation
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
t.Logf("Read error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
request := string(buf[:n])
|
||||
|
||||
// Parse method from request (simplified)
|
||||
var method string
|
||||
if strings.Contains(request, "sync.wait") {
|
||||
method = "sync.wait"
|
||||
} else if strings.Contains(request, "sync.start") {
|
||||
method = "sync.start"
|
||||
}
|
||||
|
||||
handler, exists := s.handlers[method]
|
||||
if !exists {
|
||||
response := `{"error": {"code": -32601, "message": "Method not found"}}`
|
||||
_, _ = conn.Write([]byte(response))
|
||||
return
|
||||
}
|
||||
|
||||
// Extract unit name from request (simplified)
|
||||
unitName := "test-unit"
|
||||
if strings.Contains(request, "test-unit") {
|
||||
unitName = "test-unit"
|
||||
}
|
||||
|
||||
message, err := handler(unitName)
|
||||
if err != nil {
|
||||
response := fmt.Sprintf(`{"error": {"code": -32603, "message": %q}}`, err.Error())
|
||||
_, _ = conn.Write([]byte(response))
|
||||
return
|
||||
}
|
||||
|
||||
response := fmt.Sprintf(`{"result": {"success": true, "message": %q}}`, message)
|
||||
_, _ = conn.Write([]byte(response))
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) setHandler(method string, handler func(string) (string, error)) {
|
||||
s.handlers[method] = handler
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) close() {
|
||||
_ = s.listener.Close()
|
||||
}
|
||||
|
||||
func TestSyncStartTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with a short timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", "100ms")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately 100ms
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range (100ms + some buffer for test execution)
|
||||
assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
|
||||
assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
|
||||
}
|
||||
|
||||
func TestSyncWaitTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with a short timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", "100ms")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately 100ms
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range (100ms + some buffer for test execution)
|
||||
assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
|
||||
assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
|
||||
}
|
||||
|
||||
func TestSyncStartNoTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Set up handler that will eventually succeed
|
||||
callCount := 0
|
||||
server.setHandler("sync.wait", func(unitName string) (string, error) {
|
||||
callCount++
|
||||
if callCount >= 3 {
|
||||
// After 3 calls, dependencies are satisfied
|
||||
return "Dependencies satisfied", nil
|
||||
}
|
||||
return "", unit.ErrDependenciesNotSatisfied
|
||||
})
|
||||
|
||||
// Test without timeout - should eventually succeed
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should succeed after a few polling cycles
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should take at least 2 seconds (2 polling cycles at 1s interval)
|
||||
assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
|
||||
assert.True(t, callCount >= 3, "Should have made at least 3 calls, got %d", callCount)
|
||||
}
|
||||
|
||||
func TestSyncWaitNoTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Set up handler that will eventually succeed
|
||||
callCount := 0
|
||||
server.setHandler("sync.wait", func(unitName string) (string, error) {
|
||||
callCount++
|
||||
if callCount >= 3 {
|
||||
// After 3 calls, dependencies are satisfied
|
||||
return "Dependencies satisfied", nil
|
||||
}
|
||||
return "", unit.ErrDependenciesNotSatisfied
|
||||
})
|
||||
|
||||
// Test without timeout - should eventually succeed
|
||||
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should succeed after a few polling cycles
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should take at least 2 seconds (2 polling cycles at 1s interval)
|
||||
assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
|
||||
assert.True(t, callCount >= 3, "Should have made at least 3 calls, got %d", callCount)
|
||||
}
|
||||
|
||||
func TestSyncStartTimeoutWithDifferentValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
timeout string
|
||||
expected time.Duration
|
||||
}{
|
||||
{"50ms", "50ms", 50 * time.Millisecond},
|
||||
{"200ms", "200ms", 200 * time.Millisecond},
|
||||
{"500ms", "500ms", 500 * time.Millisecond},
|
||||
{"1s", "1s", 1 * time.Second},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with specified timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", tc.timeout)
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately the specified duration
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range
|
||||
assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
|
||||
assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncWaitTimeoutWithDifferentValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
timeout string
|
||||
expected time.Duration
|
||||
}{
|
||||
{"50ms", "50ms", 50 * time.Millisecond},
|
||||
{"200ms", "200ms", 200 * time.Millisecond},
|
||||
{"500ms", "500ms", 500 * time.Millisecond},
|
||||
{"1s", "1s", 1 * time.Second},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with specified timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", tc.timeout)
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately the specified duration
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range
|
||||
assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
|
||||
assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
|
||||
})
|
||||
}
|
||||
}
|
||||
95
cli/sync_wait.go
Normal file
95
cli/sync_wait.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// PollInterval is the interval between dependency checks
|
||||
PollInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncWait() *serpent.Command {
|
||||
var timeout time.Duration
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "wait <unit>",
|
||||
Short: "Wait for a unit's dependencies to be satisfied",
|
||||
Long: "Poll until all dependencies for a unit are met. Exits when dependencies are satisfied or timeout is reached.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unitName := i.Args[0]
|
||||
|
||||
// Set up context with timeout if specified
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Poll until dependencies are satisfied
|
||||
ticker := time.NewTicker(PollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
// Check if dependencies are satisfied
|
||||
err := client.SyncReady(ctx, unitName)
|
||||
if err == nil {
|
||||
// Dependencies are satisfied
|
||||
fmt.Printf("Dependencies for unit '%s' are now satisfied\n", unitName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if it's a "not ready" error (expected while waiting)
|
||||
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
|
||||
// Still waiting, continue polling
|
||||
continue
|
||||
}
|
||||
|
||||
// Some other error occurred
|
||||
return xerrors.Errorf("error checking dependencies: %w", err)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = append(cmd.Options, serpent.Option{
|
||||
Flag: "timeout",
|
||||
Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
|
||||
Value: serpent.DurationOf(&timeout),
|
||||
})
|
||||
|
||||
return cmd
|
||||
}
|
||||
51
cli/sync_want.go
Normal file
51
cli/sync_want.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncWant() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "want <unit> <depends-on>",
|
||||
Short: "Declare a dependency between units",
|
||||
Long: "Declare that a unit depends on another unit reaching complete status.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 2 {
|
||||
return xerrors.New("exactly two arguments are required: unit and depends-on")
|
||||
}
|
||||
unit := i.Args[0]
|
||||
dependsOn := i.Args[1]
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Declaring dependency: '%s' depends on '%s'...\n", unit, dependsOn)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Declare the dependency
|
||||
if err := client.SyncWant(ctx, unit, dependsOn); err != nil {
|
||||
return xerrors.Errorf("declare dependency failed: %w", err)
|
||||
}
|
||||
|
||||
// Display success message
|
||||
fmt.Printf("Dependency declared: '%s' now depends on '%s'\n", unit, dependsOn)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
8
cli/sync_want_test.go
Normal file
8
cli/sync_want_test.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSyncWant(t *testing.T) {
|
||||
}
|
||||
3
cli/testdata/coder_agent_--help.golden
vendored
3
cli/testdata/coder_agent_--help.golden
vendored
@@ -67,6 +67,9 @@ OPTIONS:
|
||||
--script-data-dir string, $CODER_AGENT_SCRIPT_DATA_DIR (default: /tmp)
|
||||
Specify the location for storing script data.
|
||||
|
||||
--socket-path string, $CODER_AGENT_SOCKET_PATH
|
||||
Specify the path for the agent socket.
|
||||
|
||||
--ssh-max-timeout duration, $CODER_AGENT_SSH_MAX_TIMEOUT (default: 72h)
|
||||
Specify the max timeout for a SSH connection, it is advisable to set
|
||||
it to a minimum of 60s, but no more than 72h.
|
||||
|
||||
256
codersdk/agentsdk/socket_client.go
Normal file
256
codersdk/agentsdk/socket_client.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package agentsdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
"storj.io/drpc"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentsocket/proto"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
)
|
||||
|
||||
// SocketClient provides a client for communicating with the agent socket
|
||||
type SocketClient struct {
|
||||
client proto.DRPCAgentSocketClient
|
||||
conn drpc.Conn
|
||||
}
|
||||
|
||||
// SocketConfig holds configuration for the socket client
|
||||
type SocketConfig struct {
|
||||
Path string // Socket path (optional, will auto-discover if not set)
|
||||
}
|
||||
|
||||
// NewSocketClient creates a new socket client
|
||||
func NewSocketClient(config SocketConfig) (*SocketClient, error) {
|
||||
path := config.Path
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = discoverSocketPath()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("discover socket path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("connect to socket: %w", err)
|
||||
}
|
||||
|
||||
// Create yamux session for multiplexing
|
||||
configYamux := yamux.DefaultConfig()
|
||||
configYamux.Logger = nil // Disable yamux logging
|
||||
session, err := yamux.Client(conn, configYamux)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, xerrors.Errorf("create yamux client: %w", err)
|
||||
}
|
||||
|
||||
// Create drpc connection using the multiplexed connection
|
||||
drpcConn := drpcsdk.MultiplexedConn(session)
|
||||
|
||||
// Create drpc client
|
||||
client := proto.NewDRPCAgentSocketClient(drpcConn)
|
||||
|
||||
return &SocketClient{
|
||||
client: client,
|
||||
conn: drpcConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the socket connection
|
||||
func (c *SocketClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Ping sends a ping request to the agent
|
||||
func (c *SocketClient) Ping(ctx context.Context) (*PingResponse, error) {
|
||||
resp, err := c.client.Ping(ctx, &proto.PingRequest{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PingResponse{
|
||||
Message: resp.Message,
|
||||
Timestamp: resp.Timestamp.AsTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SyncStart starts a unit in the dependency graph
|
||||
func (c *SocketClient) SyncStart(ctx context.Context, unit string) error {
|
||||
resp, err := c.client.SyncStart(ctx, &proto.SyncStartRequest{
|
||||
Unit: unit,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !resp.Success {
|
||||
return xerrors.Errorf("sync start failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncWant declares a dependency between units
|
||||
func (c *SocketClient) SyncWant(ctx context.Context, unit, dependsOn string) error {
|
||||
resp, err := c.client.SyncWant(ctx, &proto.SyncWantRequest{
|
||||
Unit: unit,
|
||||
DependsOn: dependsOn,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !resp.Success {
|
||||
return xerrors.Errorf("sync want failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncComplete marks a unit as complete in the dependency graph
|
||||
func (c *SocketClient) SyncComplete(ctx context.Context, unit string) error {
|
||||
resp, err := c.client.SyncComplete(ctx, &proto.SyncCompleteRequest{
|
||||
Unit: unit,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !resp.Success {
|
||||
return xerrors.Errorf("sync complete failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncReady requests whether a unit is ready to be started. That is, all dependencies are satisfied.
|
||||
func (c *SocketClient) SyncReady(ctx context.Context, unitName string) error {
|
||||
resp, err := c.client.SyncReady(ctx, &proto.SyncReadyRequest{
|
||||
Unit: unitName,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !resp.Success {
|
||||
// Check if this is a dependencies not satisfied error
|
||||
if resp.Message == unit.ErrDependenciesNotSatisfied.Error() {
|
||||
return unit.ErrDependenciesNotSatisfied
|
||||
}
|
||||
return xerrors.Errorf("sync ready failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncStatus gets the status of a unit and its dependencies
|
||||
func (c *SocketClient) SyncStatus(ctx context.Context, unit string, recursive bool) (*SyncStatusResponse, error) {
|
||||
resp, err := c.client.SyncStatus(ctx, &proto.SyncStatusRequest{
|
||||
Unit: unit,
|
||||
Recursive: recursive,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resp.Success {
|
||||
return nil, xerrors.Errorf("sync status failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
// Convert dependencies
|
||||
var dependencies []DependencyInfo
|
||||
for _, dep := range resp.Dependencies {
|
||||
dependencies = append(dependencies, DependencyInfo{
|
||||
DependsOn: dep.DependsOn,
|
||||
RequiredStatus: dep.RequiredStatus,
|
||||
CurrentStatus: dep.CurrentStatus,
|
||||
IsSatisfied: dep.IsSatisfied,
|
||||
})
|
||||
}
|
||||
|
||||
return &SyncStatusResponse{
|
||||
Success: resp.Success,
|
||||
Message: resp.Message,
|
||||
Unit: resp.Unit,
|
||||
Status: resp.Status,
|
||||
IsReady: resp.IsReady,
|
||||
Dependencies: dependencies,
|
||||
DOT: resp.Dot,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// discoverSocketPath discovers the agent socket path
|
||||
func discoverSocketPath() (string, error) {
|
||||
// Check environment variable first
|
||||
if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// Try common socket paths
|
||||
paths := []string{
|
||||
// XDG runtime directory
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "coder-agent.sock"),
|
||||
// User-specific temp directory
|
||||
filepath.Join(os.TempDir(), fmt.Sprintf("coder-agent-%d.sock", os.Getuid())),
|
||||
// Fallback temp directory
|
||||
filepath.Join(os.TempDir(), "coder-agent.sock"),
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", xerrors.New("agent socket not found")
|
||||
}
|
||||
|
||||
// Response types for backward compatibility
|
||||
type PingResponse struct {
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
|
||||
type AgentInfo struct {
|
||||
ID string `json:"id"`
|
||||
Version string `json:"version"`
|
||||
Status string `json:"status"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
|
||||
type SyncStatusResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Unit string `json:"unit"`
|
||||
Status string `json:"status"`
|
||||
IsReady bool `json:"is_ready"`
|
||||
Dependencies []DependencyInfo `json:"dependencies"`
|
||||
DOT string `json:"dot,omitempty"`
|
||||
}
|
||||
|
||||
type DependencyInfo struct {
|
||||
DependsOn string `json:"depends_on"`
|
||||
RequiredStatus string `json:"required_status"`
|
||||
CurrentStatus string `json:"current_status"`
|
||||
IsSatisfied bool `json:"is_satisfied"`
|
||||
}
|
||||
Reference in New Issue
Block a user