Compare commits

...

3 Commits

Author SHA1 Message Date
Sas Swart 34c1370090 fix agent socket tests 2025-10-28 06:30:29 +00:00
Sas Swart 851c4f907c add a socket to the agent for local IPC 2025-10-28 06:26:49 +00:00
Sas Swart e3dfe45f35 LLM generated implementation of unit status change communication 2025-10-27 11:10:22 +00:00
13 changed files with 2506 additions and 0 deletions
+74
View File
@@ -40,6 +40,7 @@ import (
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
@@ -91,6 +92,7 @@ type Options struct {
Devcontainers bool
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
Clock quartz.Clock
SocketPath string // Path for the agent socket server
}
type Client interface {
@@ -190,6 +192,7 @@ func New(options Options) Agent {
devcontainers: options.Devcontainers,
containerAPIOptions: options.DevcontainerAPIOptions,
socketPath: options.SocketPath,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@@ -271,6 +274,10 @@ type agent struct {
devcontainers bool
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API
// Socket server for CLI communication
socketPath string
socketServer *agentsocket.Server
}
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -350,9 +357,69 @@ func (a *agent) init() {
s.ExperimentalContainers = a.devcontainers
},
)
// Initialize socket server for CLI communication
a.initSocketServer()
go a.runLoop()
}
// initSocketServer initializes the socket server for CLI communication
func (a *agent) initSocketServer() {
// Get socket path from options or environment
socketPath := a.getSocketPath()
if socketPath == "" {
a.logger.Debug(a.hardCtx, "socket server disabled (no path configured)")
return
}
// Create socket server
server := agentsocket.NewServer(agentsocket.Config{
Path: socketPath,
Logger: a.logger.Named("socket"),
})
// Register default handlers
handlerCtx := agentsocket.CreateHandlerContext(
"", // Agent ID will be set when manifest is available
buildinfo.Version(),
"starting",
time.Now(),
a.logger,
)
agentsocket.RegisterDefaultHandlers(server, handlerCtx)
// Start the server
if err := server.Start(); err != nil {
a.logger.Warn(a.hardCtx, "failed to start socket server", slog.Error(err))
return
}
a.socketServer = server
a.logger.Info(a.hardCtx, "socket server started", slog.F("path", socketPath))
}
// getSocketPath returns the socket path from options or environment
func (a *agent) getSocketPath() string {
// Check if socket path is explicitly configured
if a.getSocketPathFromOptions() != "" {
return a.getSocketPathFromOptions()
}
// Check environment variable
if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
return path
}
// Return empty to disable socket server
return ""
}
// getSocketPathFromOptions returns the socket path from agent options
func (a *agent) getSocketPathFromOptions() string {
return a.socketPath
}
// runLoop attempts to start the agent in a retry loop.
// Coder may be offline temporarily, a connection issue
// may be happening, but regardless after the intermittent
@@ -1931,6 +1998,13 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
}
// Close socket server
if a.socketServer != nil {
if err := a.socketServer.Stop(); err != nil {
a.logger.Error(a.hardCtx, "socket server close", slog.Error(err))
}
}
// Wait for the graceful shutdown to complete, but don't wait forever so
// that we don't break user expectations.
go func() {
+214
View File
@@ -0,0 +1,214 @@
# Agent Socket API
The Agent Socket API provides a local communication channel between CLI commands running within a workspace and the Coder agent process. This enables new CLI commands to interact directly with the agent without going through the control plane.
## Overview
The socket server runs within the agent process and listens on a Unix domain socket (or named pipe on Windows). CLI commands can connect to this socket to query agent information, check health status, and perform other operations.
## Architecture
### Socket Server
- **Location**: `agent/agentsocket/`
- **Protocol**: JSON-RPC 2.0 over Unix domain socket
- **Platform Support**: Linux, macOS, Windows 10+ (build 17063+)
- **Authentication**: Pluggable middleware (no-auth by default)
### Client Library
- **Location**: `codersdk/agentsdk/socket_client.go`
- **Auto-discovery**: Automatically finds socket path
- **Type-safe**: Go client with proper error handling
## Socket Path Discovery
The socket path is determined in the following order:
1. **Environment Variable**: `CODER_AGENT_SOCKET_PATH`
2. **XDG Runtime Directory**: `$XDG_RUNTIME_DIR/coder-agent.sock`
3. **User Temp Directory**: `/tmp/coder-agent-{uid}.sock`
4. **Fallback**: `/tmp/coder-agent.sock`
## Protocol
### Request Format
```json
{
"version": "1.0",
"method": "ping",
"id": "request-123",
"params": {}
}
```
### Response Format
```json
{
"version": "1.0",
"id": "request-123",
"result": {
"message": "pong",
"timestamp": "2024-01-01T00:00:00Z"
}
}
```
### Error Format
```json
{
"version": "1.0",
"id": "request-123",
"error": {
"code": -32601,
"message": "Method not found",
"data": "nonexistent"
}
}
```
## Available Methods
### Core Methods
- `ping` - Health check with timestamp
- `health` - Agent status and uptime
- `agent.info` - Detailed agent information
- `methods.list` - List available methods
### Example Usage
```go
// Create client
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{})
if err != nil {
log.Fatal(err)
}
defer client.Close()
// Ping the agent
pingResp, err := client.Ping(ctx)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Agent responded: %s\n", pingResp.Message)
// Get agent info
info, err := client.AgentInfo(ctx)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Agent ID: %s, Version: %s\n", info.ID, info.Version)
```
## Adding New Handlers
### Server Side
```go
// Register a new handler
server.RegisterHandler("custom.method", func(ctx Context, req *Request) (*Response, error) {
// Handle the request
result := map[string]string{"status": "ok"}
return NewResponse(req.ID, result)
})
```
### Client Side
```go
// Add method to client
func (c *SocketClient) CustomMethod(ctx context.Context) (*CustomResponse, error) {
req := &Request{
Version: "1.0",
Method: "custom.method",
ID: generateRequestID(),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("custom method error: %s", resp.Error.Message)
}
var result CustomResponse
if err := json.Unmarshal(resp.Result, &result); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &result, nil
}
```
## Authentication
The socket server supports pluggable authentication middleware. By default, no authentication is performed (suitable for local-only communication).
### Custom Authentication
```go
type CustomAuthMiddleware struct {
// Add auth fields
}
func (m *CustomAuthMiddleware) Authenticate(ctx context.Context, conn net.Conn) (context.Context, error) {
// Implement authentication logic
// Return context with auth info or error
return ctx, nil
}
// Use in server config
server := agentsocket.NewServer(agentsocket.Config{
Path: socketPath,
Logger: logger,
AuthMiddleware: &CustomAuthMiddleware{},
})
```
## Configuration
### Agent Options
```go
options := agent.Options{
// ... other options
SocketPath: "/custom/path/agent.sock", // Optional, uses auto-discovery if empty
}
```
### Environment Variables
- `CODER_AGENT_SOCKET_PATH` - Override socket path
- `XDG_RUNTIME_DIR` - Used for socket path discovery
## Error Codes
| Code | Description |
|------|-------------|
| -32700 | Parse error |
| -32600 | Invalid request |
| -32601 | Method not found |
| -32602 | Invalid params |
| -32603 | Internal error |
## Platform Support
### Unix-like Systems (Linux, macOS)
- Uses Unix domain sockets
- Socket file permissions: 600 (owner read/write only)
- Auto-cleanup on shutdown
### Windows
- Uses Unix domain sockets (Windows 10 build 17063+)
- Falls back to named pipes if needed
- Simplified permission handling
## Security Considerations
1. **Local Only**: Socket is only accessible from within the workspace
2. **File Permissions**: Socket file is restricted to owner only
3. **No Network Access**: Unix domain sockets don't traverse network
4. **Authentication Ready**: Middleware pattern allows future auth implementation
## Future Extensibility
The design supports:
- **Protocol Versioning**: Request includes version field
- **Multiple Transports**: Interface-based design allows TCP/WebSocket later
- **Auth Plugins**: Middleware pattern for various auth methods
- **Custom Handlers**: Simple registration pattern for new commands
+23
View File
@@ -0,0 +1,23 @@
package agentsocket
import (
"context"
"net"
)
// AuthMiddleware defines the interface for authentication middleware
type AuthMiddleware interface {
// Authenticate authenticates a connection and returns a context with auth info
Authenticate(ctx context.Context, conn net.Conn) (context.Context, error)
}
// NoAuthMiddleware is a no-op authentication middleware
type NoAuthMiddleware struct{}
// Authenticate implements AuthMiddleware but performs no authentication
func (*NoAuthMiddleware) Authenticate(ctx context.Context, conn net.Conn) (context.Context, error) {
return ctx, nil
}
// Ensure NoAuthMiddleware implements AuthMiddleware
var _ AuthMiddleware = (*NoAuthMiddleware)(nil)
+108
View File
@@ -0,0 +1,108 @@
package agentsocket
import (
"time"
"cdr.dev/slog"
)
// AgentInfo represents information about the agent
type AgentInfo struct {
ID string `json:"id"`
Version string `json:"version"`
Status string `json:"status"`
StartedAt time.Time `json:"started_at"`
Uptime string `json:"uptime"`
}
// PingResponse represents a ping response
type PingResponse struct {
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
}
// HealthResponse represents a health check response
type HealthResponse struct {
Status string `json:"status"`
Timestamp time.Time `json:"timestamp"`
Uptime string `json:"uptime"`
}
// HandlerContext provides context for handlers
type HandlerContext struct {
AgentID string
Version string
Status string
StartedAt time.Time
Logger slog.Logger
}
// NewHandlers creates the default set of handlers
func NewHandlers(handlerCtx HandlerContext) map[string]Handler {
handlers := make(map[string]Handler)
// Ping handler
handlers["ping"] = func(_ Context, req *Request) (*Response, error) {
resp := PingResponse{
Message: "pong",
Timestamp: time.Now(),
}
return NewResponse(req.ID, resp)
}
// Health check handler
handlers["health"] = func(_ Context, req *Request) (*Response, error) {
uptime := time.Since(handlerCtx.StartedAt)
resp := HealthResponse{
Status: handlerCtx.Status,
Timestamp: time.Now(),
Uptime: uptime.String(),
}
return NewResponse(req.ID, resp)
}
// Agent info handler
handlers["agent.info"] = func(_ Context, req *Request) (*Response, error) {
uptime := time.Since(handlerCtx.StartedAt)
resp := AgentInfo{
ID: handlerCtx.AgentID,
Version: handlerCtx.Version,
Status: handlerCtx.Status,
StartedAt: handlerCtx.StartedAt,
Uptime: uptime.String(),
}
return NewResponse(req.ID, resp)
}
// List methods handler
handlers["methods.list"] = func(_ Context, req *Request) (*Response, error) {
methods := []string{
"ping",
"health",
"agent.info",
"methods.list",
}
return NewResponse(req.ID, methods)
}
return handlers
}
// RegisterDefaultHandlers registers the default set of handlers with a server
func RegisterDefaultHandlers(server *Server, ctx HandlerContext) {
handlers := NewHandlers(ctx)
for method, handler := range handlers {
server.RegisterHandler(method, handler)
}
}
// CreateHandlerContext creates a handler context from agent information
func CreateHandlerContext(agentID, version, status string, startedAt time.Time, logger slog.Logger) HandlerContext {
return HandlerContext{
AgentID: agentID,
Version: version,
Status: status,
StartedAt: startedAt,
Logger: logger,
}
}
+83
View File
@@ -0,0 +1,83 @@
package agentsocket
import (
"encoding/json"
"golang.org/x/xerrors"
)
// Protocol version for the agent socket API
const ProtocolVersion = "1.0"
// Request represents an incoming request to the agent socket
type Request struct {
Version string `json:"version"`
Method string `json:"method"`
ID string `json:"id,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
}
// Response represents a response from the agent socket
type Response struct {
Version string `json:"version"`
ID string `json:"id,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
}
// Error represents an error in the response
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data,omitempty"`
}
// Standard error codes
const (
ErrCodeParseError = -32700
ErrCodeInvalidRequest = -32600
ErrCodeMethodNotFound = -32601
ErrCodeInvalidParams = -32602
ErrCodeInternalError = -32603
)
// NewError creates a new error response
func NewError(code int, message string, data any) *Error {
return &Error{
Code: code,
Message: message,
Data: data,
}
}
// NewResponse creates a successful response
func NewResponse(id string, result any) (*Response, error) {
resultBytes, err := json.Marshal(result)
if err != nil {
return nil, xerrors.Errorf("marshal result: %w", err)
}
return &Response{
Version: ProtocolVersion,
ID: id,
Result: resultBytes,
}, nil
}
// NewErrorResponse creates an error response
func NewErrorResponse(id string, err *Error) *Response {
return &Response{
Version: ProtocolVersion,
ID: id,
Error: err,
}
}
// Handler represents a function that can handle a request
type Handler func(ctx Context, req *Request) (*Response, error)
// Context provides context for request handling
type Context struct {
// Additional context can be added here in the future
// For now, this is a placeholder for future auth context, etc.
}
+266
View File
@@ -0,0 +1,266 @@
package agentsocket
import (
"context"
"encoding/json"
"io"
"net"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
// Server represents the agent socket server
type Server struct {
logger slog.Logger
path string
listener net.Listener
handlers map[string]Handler
authMiddleware AuthMiddleware
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// Config holds configuration for the socket server
type Config struct {
Path string
Logger slog.Logger
AuthMiddleware AuthMiddleware
}
// NewServer creates a new agent socket server
func NewServer(config Config) *Server {
ctx, cancel := context.WithCancel(context.Background())
server := &Server{
logger: config.Logger.Named("agentsocket"),
path: config.Path,
handlers: make(map[string]Handler),
authMiddleware: config.AuthMiddleware,
ctx: ctx,
cancel: cancel,
}
// Set default auth middleware if none provided
if server.authMiddleware == nil {
server.authMiddleware = &NoAuthMiddleware{}
}
return server
}
// Start starts the socket server
func (s *Server) Start() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.listener != nil {
return xerrors.New("server already started")
}
// Get socket path
path := s.path
if path == "" {
var err error
path, err = getDefaultSocketPath()
if err != nil {
return xerrors.Errorf("get default socket path: %w", err)
}
}
// Check if socket is available
if !isSocketAvailable(path) {
return xerrors.Errorf("socket path %s is not available", path)
}
// Create socket listener
listener, err := createSocket(s.ctx, path)
if err != nil {
return xerrors.Errorf("create socket: %w", err)
}
s.listener = listener
s.path = path
s.logger.Info(s.ctx, "agent socket server started", slog.F("path", path))
// Start accepting connections
s.wg.Add(1)
go s.acceptConnections()
return nil
}
// Stop stops the socket server
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.listener == nil {
return nil
}
s.logger.Info(s.ctx, "stopping agent socket server")
// Cancel context to stop accepting new connections
s.cancel()
// Close listener
if err := s.listener.Close(); err != nil {
s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err))
}
// Wait for all connections to finish
s.wg.Wait()
// Clean up socket file
if err := cleanupSocket(s.path); err != nil {
s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err))
}
s.listener = nil
s.logger.Info(s.ctx, "agent socket server stopped")
return nil
}
// RegisterHandler registers a handler for a method
func (s *Server) RegisterHandler(method string, handler Handler) {
s.mu.Lock()
defer s.mu.Unlock()
s.handlers[method] = handler
}
// GetPath returns the socket path
func (s *Server) GetPath() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.path
}
// acceptConnections accepts incoming connections
func (s *Server) acceptConnections() {
defer s.wg.Done()
for {
select {
case <-s.ctx.Done():
return
default:
}
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.ctx.Done():
return
default:
s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err))
continue
}
}
// Handle connection in a goroutine
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleConnection(conn)
}()
}
}
// handleConnection handles a single connection
func (s *Server) handleConnection(conn net.Conn) {
defer conn.Close()
// Authenticate connection first to get context
ctx, err := s.authMiddleware.Authenticate(s.ctx, conn)
if err != nil {
s.logger.Warn(s.ctx, "authentication failed", slog.Error(err))
return
}
// Set connection deadline
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
s.logger.Warn(ctx, "failed to set connection deadline", slog.Error(err))
}
s.logger.Debug(ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr()))
// Handle requests
decoder := json.NewDecoder(conn)
encoder := json.NewEncoder(conn)
for {
select {
case <-ctx.Done():
return
default:
}
// Set read deadline
if err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
s.logger.Warn(ctx, "failed to set read deadline", slog.Error(err))
}
var req Request
if err := decoder.Decode(&req); err != nil {
if err == io.EOF {
s.logger.Debug(ctx, "connection closed by client")
return
}
s.logger.Warn(ctx, "error decoding request", slog.Error(err))
// Send error response
resp := NewErrorResponse("", NewError(ErrCodeParseError, "Parse error", err.Error()))
encoder.Encode(resp)
return
}
// Handle request
resp := s.handleRequest(ctx, &req)
// Send response
if err := encoder.Encode(resp); err != nil {
s.logger.Warn(ctx, "error sending response", slog.Error(err))
return
}
}
}
// handleRequest handles a single request
func (s *Server) handleRequest(ctx context.Context, req *Request) *Response {
// Validate request
if req.Version != ProtocolVersion {
return NewErrorResponse(req.ID, NewError(ErrCodeInvalidRequest, "Unsupported version", req.Version))
}
if req.Method == "" {
return NewErrorResponse(req.ID, NewError(ErrCodeInvalidRequest, "Missing method", nil))
}
// Get handler
s.mu.RLock()
handler, exists := s.handlers[req.Method]
s.mu.RUnlock()
if !exists {
return NewErrorResponse(req.ID, NewError(ErrCodeMethodNotFound, "Method not found", req.Method))
}
// Call handler
type requestIDKey struct{}
ctx = context.WithValue(ctx, requestIDKey{}, req.ID)
resp, err := handler(Context{}, req)
if err != nil {
s.logger.Warn(ctx, "handler execution failed", slog.Error(err))
return NewErrorResponse(req.ID, NewError(ErrCodeInternalError, "Internal error", err.Error()))
}
return resp
}
+250
View File
@@ -0,0 +1,250 @@
package agentsocket
import (
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
func TestServer_StartStop(t *testing.T) {
t.Parallel()
// Create temporary socket path
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test.sock")
// Create server
server := NewServer(Config{
Path: socketPath,
Logger: slog.Make().Leveled(slog.LevelDebug),
})
// Register a test handler
server.RegisterHandler("test", func(ctx Context, req *Request) (*Response, error) {
return NewResponse(req.ID, map[string]string{"message": "test response"})
})
// Start server
err := server.Start()
require.NoError(t, err)
defer server.Stop()
// Verify socket file exists
_, err = os.Stat(socketPath)
require.NoError(t, err)
// Test connection
conn, err := net.Dial("unix", socketPath)
require.NoError(t, err)
defer conn.Close()
// Send test request
req := Request{
Version: "1.0",
Method: "test",
ID: "test-1",
}
err = json.NewEncoder(conn).Encode(req)
require.NoError(t, err)
// Read response
var resp Response
err = json.NewDecoder(conn).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "test-1", resp.ID)
assert.Nil(t, resp.Error)
assert.NotNil(t, resp.Result)
// Verify response content
var result map[string]string
err = json.Unmarshal(resp.Result, &result)
require.NoError(t, err)
assert.Equal(t, "test response", result["message"])
}
func TestServer_ErrorHandling(t *testing.T) {
t.Parallel()
// Create temporary socket path
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test.sock")
// Create server
server := NewServer(Config{
Path: socketPath,
Logger: slog.Make().Leveled(slog.LevelDebug),
})
// Start server
err := server.Start()
require.NoError(t, err)
defer server.Stop()
// Test connection
conn, err := net.Dial("unix", socketPath)
require.NoError(t, err)
defer conn.Close()
// Send request for non-existent method
req := Request{
Version: "1.0",
Method: "nonexistent",
ID: "test-1",
}
err = json.NewEncoder(conn).Encode(req)
require.NoError(t, err)
// Read response
var resp Response
err = json.NewDecoder(conn).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "test-1", resp.ID)
assert.NotNil(t, resp.Error)
assert.Equal(t, ErrCodeMethodNotFound, resp.Error.Code)
assert.Equal(t, "Method not found", resp.Error.Message)
}
func TestServer_DefaultHandlers(t *testing.T) {
t.Parallel()
// Create temporary socket path
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test.sock")
// Create server
server := NewServer(Config{
Path: socketPath,
Logger: slog.Make().Leveled(slog.LevelDebug),
})
// Register default handlers
handlerCtx := CreateHandlerContext(
"test-agent-id",
"1.0.0",
"ready",
time.Now().Add(-time.Hour),
slog.Make(),
)
RegisterDefaultHandlers(server, handlerCtx)
// Start server
err := server.Start()
require.NoError(t, err)
defer server.Stop()
// Test ping
conn, err := net.Dial("unix", socketPath)
require.NoError(t, err)
defer conn.Close()
req := Request{
Version: "1.0",
Method: "ping",
ID: "ping-1",
}
err = json.NewEncoder(conn).Encode(req)
require.NoError(t, err)
var resp Response
err = json.NewDecoder(conn).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "ping-1", resp.ID)
assert.Nil(t, resp.Error)
var pingResp PingResponse
err = json.Unmarshal(resp.Result, &pingResp)
require.NoError(t, err)
assert.Equal(t, "pong", pingResp.Message)
}
func TestServer_ConcurrentConnections(t *testing.T) {
t.Parallel()
// Create temporary socket path
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test.sock")
// Create server
server := NewServer(Config{
Path: socketPath,
Logger: slog.Make().Leveled(slog.LevelDebug),
})
// Register a test handler
server.RegisterHandler("test", func(ctx Context, req *Request) (*Response, error) {
time.Sleep(10 * time.Millisecond) // Simulate some work
return NewResponse(req.ID, map[string]string{"message": "test response"})
})
// Start server
err := server.Start()
require.NoError(t, err)
defer server.Stop()
// Test multiple concurrent connections
const numConnections = 5
results := make(chan error, numConnections)
for i := 0; i < numConnections; i++ {
go func(i int) {
conn, err := net.Dial("unix", socketPath)
if err != nil {
results <- err
return
}
defer conn.Close()
req := Request{
Version: "1.0",
Method: "test",
ID: fmt.Sprintf("test-%d", i),
}
err = json.NewEncoder(conn).Encode(req)
if err != nil {
results <- err
return
}
var resp Response
err = json.NewDecoder(conn).Decode(&resp)
if err != nil {
results <- err
return
}
if resp.Error != nil {
results <- xerrors.Errorf("server error: %s", resp.Error.Message)
return
}
results <- nil
}(i)
}
// Wait for all connections to complete
for i := 0; i < numConnections; i++ {
select {
case err := <-results:
require.NoError(t, err)
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for concurrent connections")
}
}
}
+106
View File
@@ -0,0 +1,106 @@
//go:build !windows
package agentsocket
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"syscall"
"time"
"golang.org/x/xerrors"
)
// createSocket creates a Unix domain socket listener
func createSocket(ctx context.Context, path string) (net.Listener, error) {
// Remove existing socket file if it exists
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return nil, xerrors.Errorf("remove existing socket: %w", err)
}
// Create parent directory if it doesn't exist
parentDir := filepath.Dir(path)
if err := os.MkdirAll(parentDir, 0o700); err != nil {
return nil, xerrors.Errorf("create socket directory: %w", err)
}
// Create Unix domain socket listener
listener, err := net.Listen("unix", path)
if err != nil {
return nil, xerrors.Errorf("listen on unix socket: %w", err)
}
// Set socket permissions to be accessible only by the current user
if err := os.Chmod(path, 0o600); err != nil {
listener.Close()
return nil, xerrors.Errorf("set socket permissions: %w", err)
}
return listener, nil
}
// getDefaultSocketPath returns the default socket path for Unix-like systems
func getDefaultSocketPath() (string, error) {
// Try XDG_RUNTIME_DIR first
if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
return filepath.Join(runtimeDir, "coder-agent.sock"), nil
}
// Fall back to /tmp with user-specific path
uid := os.Getuid()
return filepath.Join("/tmp", fmt.Sprintf("coder-agent-%d.sock", uid)), nil
}
// cleanupSocket removes the socket file
func cleanupSocket(path string) error {
return os.Remove(path)
}
// isSocketAvailable checks if a socket path is available for use
func isSocketAvailable(path string) bool {
// Check if file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
return true
}
// Try to connect to see if it's actually listening
conn, err := net.Dial("unix", path)
if err != nil {
// If we can't connect, the socket is not in use
return true
}
conn.Close()
return false
}
// getSocketInfo returns information about the socket file
func getSocketInfo(path string) (*SocketInfo, error) {
stat, err := os.Stat(path)
if err != nil {
return nil, err
}
sys, ok := stat.Sys().(*syscall.Stat_t)
if !ok {
return nil, xerrors.New("unable to get stat_t from file info")
}
return &SocketInfo{
Path: path,
UID: int(sys.Uid),
GID: int(sys.Gid),
Mode: stat.Mode(),
ModTime: stat.ModTime(),
}, nil
}
// SocketInfo contains information about a socket file
type SocketInfo struct {
Path string
UID int
GID int
Mode os.FileMode
ModTime time.Time
}
+99
View File
@@ -0,0 +1,99 @@
//go:build windows
package agentsocket
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"time"
)
// createSocket creates a Unix domain socket listener on Windows
// Falls back to named pipe if Unix sockets are not supported
func createSocket(ctx context.Context, path string) (net.Listener, error) {
// Try Unix domain socket first (Windows 10 build 17063+)
listener, err := net.Listen("unix", path)
if err == nil {
return listener, nil
}
// Fall back to named pipe
pipePath := `\\.\pipe\coder-agent`
return net.Listen("tcp", pipePath)
}
// getDefaultSocketPath returns the default socket path for Windows
func getDefaultSocketPath() (string, error) {
// Try to use a temporary directory
tempDir := os.TempDir()
if tempDir == "" {
tempDir = "C:\\temp"
}
// Create a user-specific subdirectory
uid := os.Getuid()
userDir := filepath.Join(tempDir, "coder-agent", strconv.Itoa(uid))
if err := os.MkdirAll(userDir, 0o700); err != nil {
return "", fmt.Errorf("create user directory: %w", err)
}
return filepath.Join(userDir, "agent.sock"), nil
}
// cleanupSocket removes the socket file
func cleanupSocket(path string) error {
return os.Remove(path)
}
// isSocketAvailable checks if a socket path is available for use
func isSocketAvailable(path string) bool {
// Check if file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
return true
}
// Try to connect to see if it's actually listening
conn, err := net.Dial("unix", path)
if err != nil {
// If we can't connect, the socket is not in use
return true
}
conn.Close()
return false
}
// getSocketInfo returns information about the socket file
func getSocketInfo(path string) (*SocketInfo, error) {
stat, err := os.Stat(path)
if err != nil {
return nil, err
}
// On Windows, we'll use a simplified approach for now
// In a real implementation, you'd get the security descriptor
return &SocketInfo{
Path: path,
UID: 0, // Simplified for now
GID: 0, // Simplified for now
Mode: stat.Mode(),
ModTime: stat.ModTime(),
Owner: "unknown",
Group: "unknown",
}, nil
}
// SocketInfo contains information about a socket file
type SocketInfo struct {
Path string
UID int
GID int
Mode os.FileMode
ModTime time.Time
Owner string // Windows SID string
Group string // Windows SID string
}
+227
View File
@@ -0,0 +1,227 @@
package unit
import (
"sync"
"golang.org/x/xerrors"
)
// ErrConsumerNotFound is returned when a consumer ID is not registered.
var ErrConsumerNotFound = xerrors.New("consumer not found")
// ErrCannotUpdateOtherConsumer is returned when attempting to update another consumer's status.
var ErrCannotUpdateOtherConsumer = xerrors.New("cannot update other consumer's status")
// dependencyVertex represents a vertex in the dependency graph that is associated with a consumer.
type dependencyVertex[ConsumerID comparable] struct {
ID ConsumerID
}
// Dependency represents a dependency relationship between consumers.
type Dependency[StatusType, ConsumerID comparable] struct {
Consumer ConsumerID
DependsOn ConsumerID
RequiredStatus StatusType
CurrentStatus StatusType
IsSatisfied bool
}
// DependencyTracker provides reactive dependency tracking over a Graph.
// It manages consumer registration, dependency relationships, and status updates
// with automatic recalculation of readiness when dependencies are satisfied.
type DependencyTracker[StatusType, ConsumerID comparable] struct {
mu sync.RWMutex
// The underlying graph that stores dependency relationships
graph *Graph[StatusType, *dependencyVertex[ConsumerID]]
// Track current status of each consumer
consumerStatus map[ConsumerID]StatusType
// Track readiness state (cached to avoid repeated graph traversal)
consumerReadiness map[ConsumerID]bool
// Track which consumers are registered
registeredConsumers map[ConsumerID]bool
// Store vertex instances for each consumer to ensure consistent references
consumerVertices map[ConsumerID]*dependencyVertex[ConsumerID]
}
// NewDependencyTracker creates a new DependencyTracker instance.
func NewDependencyTracker[StatusType, ConsumerID comparable]() *DependencyTracker[StatusType, ConsumerID] {
return &DependencyTracker[StatusType, ConsumerID]{
graph: &Graph[StatusType, *dependencyVertex[ConsumerID]]{},
consumerStatus: make(map[ConsumerID]StatusType),
consumerReadiness: make(map[ConsumerID]bool),
registeredConsumers: make(map[ConsumerID]bool),
consumerVertices: make(map[ConsumerID]*dependencyVertex[ConsumerID]),
}
}
// Register registers a new consumer as a vertex in the dependency graph.
func (dt *DependencyTracker[StatusType, ConsumerID]) Register(id ConsumerID) error {
dt.mu.Lock()
defer dt.mu.Unlock()
if dt.registeredConsumers[id] {
return xerrors.Errorf("consumer %v is already registered", id)
}
// Create and store the vertex for this consumer
vertex := &dependencyVertex[ConsumerID]{ID: id}
dt.consumerVertices[id] = vertex
dt.registeredConsumers[id] = true
dt.consumerReadiness[id] = true // New consumers start as ready (no dependencies)
return nil
}
// AddDependency adds a dependency relationship between consumers.
// The consumer depends on the dependsOn consumer reaching the requiredStatus.
func (dt *DependencyTracker[StatusType, ConsumerID]) AddDependency(consumer ConsumerID, dependsOn ConsumerID, requiredStatus StatusType) error {
dt.mu.Lock()
defer dt.mu.Unlock()
if !dt.registeredConsumers[consumer] {
return xerrors.Errorf("consumer %v is not registered", consumer)
}
if !dt.registeredConsumers[dependsOn] {
return xerrors.Errorf("consumer %v is not registered", dependsOn)
}
// Get the stored vertices for both consumers
consumerVertex := dt.consumerVertices[consumer]
dependsOnVertex := dt.consumerVertices[dependsOn]
// Add the dependency edge to the graph
// The edge goes from consumer to dependsOn, representing the dependency
err := dt.graph.AddEdge(consumerVertex, dependsOnVertex, requiredStatus)
if err != nil {
return xerrors.Errorf("failed to add dependency: %w", err)
}
// Recalculate readiness for the consumer since it now has a dependency
dt.recalculateReadinessUnsafe(consumer)
return nil
}
// UpdateStatus updates a consumer's status and recalculates readiness for affected dependents.
func (dt *DependencyTracker[StatusType, ConsumerID]) UpdateStatus(consumer ConsumerID, newStatus StatusType) error {
dt.mu.Lock()
defer dt.mu.Unlock()
if !dt.registeredConsumers[consumer] {
return ErrConsumerNotFound
}
// Update the consumer's status
dt.consumerStatus[consumer] = newStatus
// Get all consumers that depend on this one (reverse adjacent vertices)
consumerVertex := dt.consumerVertices[consumer]
dependentEdges := dt.graph.GetReverseAdjacentVertices(consumerVertex)
// Recalculate readiness for all dependents
for _, edge := range dependentEdges {
dt.recalculateReadinessUnsafe(edge.From.ID)
}
return nil
}
// IsReady checks if all dependencies for a consumer are satisfied.
func (dt *DependencyTracker[StatusType, ConsumerID]) IsReady(consumer ConsumerID) (bool, error) {
dt.mu.RLock()
defer dt.mu.RUnlock()
if !dt.registeredConsumers[consumer] {
return false, ErrConsumerNotFound
}
return dt.consumerReadiness[consumer], nil
}
// GetUnmetDependencies returns a list of unsatisfied dependencies for a consumer.
func (dt *DependencyTracker[StatusType, ConsumerID]) GetUnmetDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
dt.mu.RLock()
defer dt.mu.RUnlock()
if !dt.registeredConsumers[consumer] {
return nil, ErrConsumerNotFound
}
consumerVertex := dt.consumerVertices[consumer]
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
var unmetDependencies []Dependency[StatusType, ConsumerID]
for _, edge := range forwardEdges {
dependsOnConsumer := edge.To.ID
requiredStatus := edge.Edge
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
if !exists {
// If the dependency consumer has no status, it's not satisfied
var zeroStatus StatusType
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
Consumer: consumer,
DependsOn: dependsOnConsumer,
RequiredStatus: requiredStatus,
CurrentStatus: zeroStatus, // Zero value
IsSatisfied: false,
})
} else {
isSatisfied := currentStatus == requiredStatus
if !isSatisfied {
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
Consumer: consumer,
DependsOn: dependsOnConsumer,
RequiredStatus: requiredStatus,
CurrentStatus: currentStatus,
IsSatisfied: false,
})
}
}
}
return unmetDependencies, nil
}
// recalculateReadinessUnsafe recalculates the readiness state for a consumer.
// This method assumes the caller holds the write lock.
func (dt *DependencyTracker[StatusType, ConsumerID]) recalculateReadinessUnsafe(consumer ConsumerID) {
consumerVertex := dt.consumerVertices[consumer]
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
// If there are no dependencies, the consumer is ready
if len(forwardEdges) == 0 {
dt.consumerReadiness[consumer] = true
return
}
// Check if all dependencies are satisfied
allSatisfied := true
for _, edge := range forwardEdges {
dependsOnConsumer := edge.To.ID
requiredStatus := edge.Edge
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
if !exists || currentStatus != requiredStatus {
allSatisfied = false
break
}
}
dt.consumerReadiness[consumer] = allSatisfied
}
// GetGraph returns the underlying graph for visualization and debugging.
// This should be used carefully as it exposes the internal graph structure.
func (dt *DependencyTracker[StatusType, ConsumerID]) GetGraph() *Graph[StatusType, *dependencyVertex[ConsumerID]] {
return dt.graph
}
// ExportDOT exports the dependency graph to DOT format for visualization.
func (dt *DependencyTracker[StatusType, ConsumerID]) ExportDOT(name string) (string, error) {
return dt.graph.ToDOT(name)
}
+692
View File
@@ -0,0 +1,692 @@
package unit_test
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/unit"
)
type testStatus string
const (
statusInitialized testStatus = "initialized"
statusStarted testStatus = "started"
statusRunning testStatus = "running"
statusCompleted testStatus = "completed"
)
type testConsumerID string
const (
consumerA testConsumerID = "serviceA"
consumerB testConsumerID = "serviceB"
consumerC testConsumerID = "serviceC"
consumerD testConsumerID = "serviceD"
)
func TestDependencyTracker_Register(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
t.Run("RegisterNewConsumer", func(t *testing.T) {
t.Parallel()
err := tracker.Register(consumerA)
require.NoError(t, err)
// Consumer should be ready initially (no dependencies)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("RegisterDuplicateConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerA)
require.Error(t, err)
assert.Contains(t, err.Error(), "already registered")
})
t.Run("RegisterMultipleConsumers", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
consumers := []testConsumerID{consumerA, consumerB, consumerC}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// All should be ready initially
for _, consumer := range consumers {
ready, err := tracker.IsReady(consumer)
require.NoError(t, err)
assert.True(t, ready)
}
})
}
func TestDependencyTracker_AddDependency(t *testing.T) {
t.Parallel()
t.Run("AddDependencyBetweenRegisteredConsumers", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
// A should no longer be ready (depends on B)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// B should still be ready (no dependencies)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("AddDependencyWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
// Try to add dependency to unregistered consumer
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.Error(t, err)
assert.Contains(t, err.Error(), "not registered")
})
t.Run("AddDependencyFromUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.Register(consumerB)
require.NoError(t, err)
// Try to add dependency from unregistered consumer
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.Error(t, err)
assert.Contains(t, err.Error(), "not registered")
})
}
func TestDependencyTracker_UpdateStatus(t *testing.T) {
t.Parallel()
t.Run("UpdateStatusTriggersReadinessRecalculation", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
// Initially A is not ready
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "running" - A should become ready
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.UpdateStatus(consumerA, statusRunning)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
})
t.Run("LinearChainDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// Create chain: A depends on B being "started", B depends on C being "completed"
err := tracker.AddDependency(consumerA, consumerB, statusStarted)
require.NoError(t, err)
err = tracker.AddDependency(consumerB, consumerC, statusCompleted)
require.NoError(t, err)
// Initially only C is ready
ready, err := tracker.IsReady(consumerC)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.False(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "completed" - B should become ready
err = tracker.UpdateStatus(consumerC, statusCompleted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "started" - A should become ready
err = tracker.UpdateStatus(consumerB, statusStarted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
}
func TestDependencyTracker_GetUnmetDependencies(t *testing.T) {
t.Parallel()
t.Run("GetUnmetDependenciesForConsumerWithNoDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.NoError(t, err)
assert.Empty(t, unmet)
})
t.Run("GetUnmetDependenciesForConsumerWithUnsatisfiedDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.NoError(t, err)
require.Len(t, unmet, 1)
assert.Equal(t, consumerA, unmet[0].Consumer)
assert.Equal(t, consumerB, unmet[0].DependsOn)
assert.Equal(t, statusRunning, unmet[0].RequiredStatus)
assert.False(t, unmet[0].IsSatisfied)
})
t.Run("GetUnmetDependenciesForConsumerWithSatisfiedDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
// Update B to "running"
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.NoError(t, err)
assert.Empty(t, unmet)
})
t.Run("GetUnmetDependenciesForUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
assert.Nil(t, unmet)
})
}
func TestDependencyTracker_ConcurrentOperations(t *testing.T) {
t.Parallel()
t.Run("ConcurrentStatusUpdates", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// Create dependencies: A depends on B, B depends on C, C depends on D
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerB, consumerC, statusStarted)
require.NoError(t, err)
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
require.NoError(t, err)
var wg sync.WaitGroup
const numGoroutines = 10
// Launch goroutines that update statuses
errors := make([]error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
// Update D to completed (should make C ready)
err := tracker.UpdateStatus(consumerD, statusCompleted)
if err != nil {
errors[goroutineID] = err
return
}
// Update C to started (should make B ready)
err = tracker.UpdateStatus(consumerC, statusStarted)
if err != nil {
errors[goroutineID] = err
return
}
// Update B to running (should make A ready)
err = tracker.UpdateStatus(consumerB, statusRunning)
if err != nil {
errors[goroutineID] = err
return
}
}(i)
}
wg.Wait()
// Check for any errors in goroutines
for i, err := range errors {
require.NoError(t, err, "goroutine %d had error", i)
}
// All consumers should be ready after the updates
for _, consumer := range consumers {
ready, err := tracker.IsReady(consumer)
require.NoError(t, err)
assert.True(t, ready)
}
})
t.Run("ConcurrentReadinessChecks", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register consumers
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B being "running"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
var wg sync.WaitGroup
const numGoroutines = 20
// Launch goroutines that check readiness
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
// Check readiness multiple times
for j := 0; j < 10; j++ {
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
// Initially should be false, then true after B is updated
_ = ready
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
// B should always be ready (no dependencies)
assert.True(t, ready)
}
}(i)
}
// Update B to "running" in the middle of readiness checks
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
wg.Wait()
})
}
func TestDependencyTracker_MultipleDependencies(t *testing.T) {
t.Parallel()
t.Run("ConsumerWithMultipleDependencies", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// A depends on B being "running" AND C being "started"
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
require.NoError(t, err)
// A should not be ready (depends on both B and C)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "running" - A should still not be ready (needs C too)
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "started" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusStarted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("ComplexDependencyChain", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// Create complex dependency graph:
// A depends on B being "running" AND C being "started"
// B depends on D being "completed"
// C depends on D being "completed"
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
require.NoError(t, err)
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
require.NoError(t, err)
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
require.NoError(t, err)
// Initially only D is ready
ready, err := tracker.IsReady(consumerD)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.False(t, ready)
ready, err = tracker.IsReady(consumerC)
require.NoError(t, err)
assert.False(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update D to "completed" - B and C should become ready
err = tracker.UpdateStatus(consumerD, statusCompleted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerB)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerC)
require.NoError(t, err)
assert.True(t, ready)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update B to "running" - A should still not be ready (needs C)
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "started" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusStarted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
t.Run("DifferentStatusTypes", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register consumers
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
err = tracker.Register(consumerC)
require.NoError(t, err)
// A depends on B being "running" AND C being "completed"
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusCompleted)
require.NoError(t, err)
// Update B to "running" but not C - A should not be ready
err = tracker.UpdateStatus(consumerB, statusRunning)
require.NoError(t, err)
ready, err := tracker.IsReady(consumerA)
require.NoError(t, err)
assert.False(t, ready)
// Update C to "completed" - A should now be ready
err = tracker.UpdateStatus(consumerC, statusCompleted)
require.NoError(t, err)
ready, err = tracker.IsReady(consumerA)
require.NoError(t, err)
assert.True(t, ready)
})
}
func TestDependencyTracker_ErrorCases(t *testing.T) {
t.Parallel()
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
err := tracker.UpdateStatus(consumerA, statusRunning)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
})
t.Run("IsReadyWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
ready, err := tracker.IsReady(consumerA)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
assert.False(t, ready)
})
t.Run("GetUnmetDependenciesWithUnregisteredConsumer", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
unmet, err := tracker.GetUnmetDependencies(consumerA)
require.Error(t, err)
assert.Equal(t, unit.ErrConsumerNotFound, err)
assert.Nil(t, unmet)
})
t.Run("AddDependencyWithUnregisteredConsumers", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Try to add dependency with unregistered consumers
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.Error(t, err)
assert.Contains(t, err.Error(), "not registered")
})
t.Run("CyclicDependencyDetection", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register consumers
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// A depends on B
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
// Try to make B depend on A (creates cycle)
err = tracker.AddDependency(consumerB, consumerA, statusStarted)
require.Error(t, err)
assert.Contains(t, err.Error(), "would create a cycle")
})
}
func TestDependencyTracker_ToDOT(t *testing.T) {
t.Parallel()
t.Run("ExportSimpleGraph", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register consumers
err := tracker.Register(consumerA)
require.NoError(t, err)
err = tracker.Register(consumerB)
require.NoError(t, err)
// Add dependency
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
dot, err := tracker.ExportDOT("test")
require.NoError(t, err)
assert.NotEmpty(t, dot)
assert.Contains(t, dot, "digraph")
})
t.Run("ExportComplexGraph", func(t *testing.T) {
t.Parallel()
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
// Register all consumers
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
for _, consumer := range consumers {
err := tracker.Register(consumer)
require.NoError(t, err)
}
// Create complex dependency graph
// A depends on B and C, B depends on D, C depends on D
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
require.NoError(t, err)
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
require.NoError(t, err)
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
require.NoError(t, err)
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
require.NoError(t, err)
dot, err := tracker.ExportDOT("complex")
require.NoError(t, err)
assert.NotEmpty(t, dot)
assert.Contains(t, dot, "digraph")
})
}
+254
View File
@@ -0,0 +1,254 @@
package agentsdk
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"time"
"golang.org/x/xerrors"
)
// SocketClient provides a client for communicating with the agent socket
type SocketClient struct {
conn net.Conn
}
// SocketConfig holds configuration for the socket client
type SocketConfig struct {
Path string // Socket path (optional, will auto-discover if not set)
}
// NewSocketClient creates a new socket client
func NewSocketClient(config SocketConfig) (*SocketClient, error) {
path := config.Path
if path == "" {
var err error
path, err = discoverSocketPath()
if err != nil {
return nil, xerrors.Errorf("discover socket path: %w", err)
}
}
conn, err := net.Dial("unix", path)
if err != nil {
return nil, xerrors.Errorf("connect to socket: %w", err)
}
return &SocketClient{
conn: conn,
}, nil
}
// Close closes the socket connection
func (c *SocketClient) Close() error {
return c.conn.Close()
}
// Ping sends a ping request to the agent
func (c *SocketClient) Ping(ctx context.Context) (*PingResponse, error) {
req := &Request{
Version: "1.0",
Method: "ping",
ID: generateRequestID(),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, xerrors.Errorf("ping error: %s", resp.Error.Message)
}
var pingResp PingResponse
if err := json.Unmarshal(resp.Result, &pingResp); err != nil {
return nil, xerrors.Errorf("unmarshal ping response: %w", err)
}
return &pingResp, nil
}
// Health sends a health check request to the agent
func (c *SocketClient) Health(ctx context.Context) (*HealthResponse, error) {
req := &Request{
Version: "1.0",
Method: "health",
ID: generateRequestID(),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, xerrors.Errorf("health error: %s", resp.Error.Message)
}
var healthResp HealthResponse
if err := json.Unmarshal(resp.Result, &healthResp); err != nil {
return nil, xerrors.Errorf("unmarshal health response: %w", err)
}
return &healthResp, nil
}
// AgentInfo sends an agent info request
func (c *SocketClient) AgentInfo(ctx context.Context) (*AgentInfo, error) {
req := &Request{
Version: "1.0",
Method: "agent.info",
ID: generateRequestID(),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, xerrors.Errorf("agent info error: %s", resp.Error.Message)
}
var agentInfo AgentInfo
if err := json.Unmarshal(resp.Result, &agentInfo); err != nil {
return nil, xerrors.Errorf("unmarshal agent info response: %w", err)
}
return &agentInfo, nil
}
// ListMethods lists available methods
func (c *SocketClient) ListMethods(ctx context.Context) ([]string, error) {
req := &Request{
Version: "1.0",
Method: "methods.list",
ID: generateRequestID(),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, xerrors.Errorf("list methods error: %s", resp.Error.Message)
}
var methods []string
if err := json.Unmarshal(resp.Result, &methods); err != nil {
return nil, xerrors.Errorf("unmarshal methods response: %w", err)
}
return methods, nil
}
// sendRequest sends a request and returns the response
func (c *SocketClient) sendRequest(_ context.Context, req *Request) (*Response, error) {
// Set write deadline
if err := c.conn.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
return nil, xerrors.Errorf("set write deadline: %w", err)
}
// Send request
if err := json.NewEncoder(c.conn).Encode(req); err != nil {
return nil, xerrors.Errorf("send request: %w", err)
}
// Set read deadline
if err := c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
return nil, xerrors.Errorf("set read deadline: %w", err)
}
// Read response
var resp Response
if err := json.NewDecoder(c.conn).Decode(&resp); err != nil {
return nil, xerrors.Errorf("read response: %w", err)
}
return &resp, nil
}
// discoverSocketPath discovers the agent socket path
func discoverSocketPath() (string, error) {
// Check environment variable first
if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
return path, nil
}
// Try common socket paths
paths := []string{
// XDG runtime directory
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "coder-agent.sock"),
// User-specific temp directory
filepath.Join(os.TempDir(), fmt.Sprintf("coder-agent-%d.sock", os.Getuid())),
// Fallback temp directory
filepath.Join(os.TempDir(), "coder-agent.sock"),
}
for _, path := range paths {
if path == "" {
continue
}
if _, err := os.Stat(path); err == nil {
return path, nil
}
}
return "", xerrors.New("agent socket not found")
}
// generateRequestID generates a unique request ID
func generateRequestID() string {
return fmt.Sprintf("%d", time.Now().UnixNano())
}
// Request represents a socket request
type Request struct {
Version string `json:"version"`
Method string `json:"method"`
ID string `json:"id,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
}
// Response represents a socket response
type Response struct {
Version string `json:"version"`
ID string `json:"id,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
}
// Error represents a socket error
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data,omitempty"`
}
// PingResponse represents a ping response
type PingResponse struct {
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
}
// HealthResponse represents a health check response
type HealthResponse struct {
Status string `json:"status"`
Timestamp time.Time `json:"timestamp"`
Uptime string `json:"uptime"`
}
// AgentInfo represents agent information
type AgentInfo struct {
ID string `json:"id"`
Version string `json:"version"`
Status string `json:"status"`
StartedAt time.Time `json:"started_at"`
Uptime string `json:"uptime"`
}
+110
View File
@@ -0,0 +1,110 @@
package agentsdk_test
import (
"context"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
func TestSocketClient_Integration(t *testing.T) {
// Create temporary socket path
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test.sock")
// Set environment variable for socket discovery
t.Setenv("CODER_AGENT_SOCKET_PATH", socketPath)
// Start a real socket server
server := startSocketServer(t, socketPath)
defer server.Stop()
// Create client
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{})
require.NoError(t, err)
defer client.Close()
// Test ping
ctx := context.Background()
pingResp, err := client.Ping(ctx)
require.NoError(t, err)
assert.Equal(t, "pong", pingResp.Message)
assert.False(t, pingResp.Timestamp.IsZero())
// Test health
healthResp, err := client.Health(ctx)
require.NoError(t, err)
assert.Equal(t, "ready", healthResp.Status)
assert.NotEmpty(t, healthResp.Uptime)
// Test agent info
agentInfo, err := client.AgentInfo(ctx)
require.NoError(t, err)
assert.Equal(t, "test-agent", agentInfo.ID)
assert.Equal(t, "1.0.0", agentInfo.Version)
assert.Equal(t, "ready", agentInfo.Status)
// Test list methods
methods, err := client.ListMethods(ctx)
require.NoError(t, err)
assert.Contains(t, methods, "ping")
assert.Contains(t, methods, "health")
assert.Contains(t, methods, "agent.info")
}
func TestSocketClient_Discovery(t *testing.T) {
// Test with explicit path
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test.sock")
server := startSocketServer(t, socketPath)
defer server.Stop()
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{Path: socketPath})
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
_, err = client.Ping(ctx)
require.NoError(t, err)
}
func TestSocketClient_ErrorHandling(t *testing.T) {
// Test with non-existent socket
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{Path: "/nonexistent/socket"})
assert.Error(t, err)
assert.Nil(t, client)
}
// startSocketServer starts a real socket server for testing
func startSocketServer(t *testing.T, path string) *agentsocket.Server {
// Create server
server := agentsocket.NewServer(agentsocket.Config{
Path: path,
Logger: slog.Make().Leveled(slog.LevelDebug),
})
// Register default handlers with test data
handlerCtx := agentsocket.CreateHandlerContext(
"test-agent",
"1.0.0",
"ready",
time.Now().Add(-time.Hour),
slog.Make(),
)
agentsocket.RegisterDefaultHandlers(server, handlerCtx)
// Start server
err := server.Start()
require.NoError(t, err)
return server
}