Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 34c1370090 | |||
| 851c4f907c | |||
| e3dfe45f35 |
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
|
||||
@@ -91,6 +92,7 @@ type Options struct {
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
Clock quartz.Clock
|
||||
SocketPath string // Path for the agent socket server
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@@ -190,6 +192,7 @@ func New(options Options) Agent {
|
||||
|
||||
devcontainers: options.Devcontainers,
|
||||
containerAPIOptions: options.DevcontainerAPIOptions,
|
||||
socketPath: options.SocketPath,
|
||||
}
|
||||
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
||||
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
||||
@@ -271,6 +274,10 @@ type agent struct {
|
||||
devcontainers bool
|
||||
containerAPIOptions []agentcontainers.Option
|
||||
containerAPI *agentcontainers.API
|
||||
|
||||
// Socket server for CLI communication
|
||||
socketPath string
|
||||
socketServer *agentsocket.Server
|
||||
}
|
||||
|
||||
func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
@@ -350,9 +357,69 @@ func (a *agent) init() {
|
||||
s.ExperimentalContainers = a.devcontainers
|
||||
},
|
||||
)
|
||||
|
||||
// Initialize socket server for CLI communication
|
||||
a.initSocketServer()
|
||||
|
||||
go a.runLoop()
|
||||
}
|
||||
|
||||
// initSocketServer initializes the socket server for CLI communication
|
||||
func (a *agent) initSocketServer() {
|
||||
// Get socket path from options or environment
|
||||
socketPath := a.getSocketPath()
|
||||
if socketPath == "" {
|
||||
a.logger.Debug(a.hardCtx, "socket server disabled (no path configured)")
|
||||
return
|
||||
}
|
||||
|
||||
// Create socket server
|
||||
server := agentsocket.NewServer(agentsocket.Config{
|
||||
Path: socketPath,
|
||||
Logger: a.logger.Named("socket"),
|
||||
})
|
||||
|
||||
// Register default handlers
|
||||
handlerCtx := agentsocket.CreateHandlerContext(
|
||||
"", // Agent ID will be set when manifest is available
|
||||
buildinfo.Version(),
|
||||
"starting",
|
||||
time.Now(),
|
||||
a.logger,
|
||||
)
|
||||
agentsocket.RegisterDefaultHandlers(server, handlerCtx)
|
||||
|
||||
// Start the server
|
||||
if err := server.Start(); err != nil {
|
||||
a.logger.Warn(a.hardCtx, "failed to start socket server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
a.socketServer = server
|
||||
a.logger.Info(a.hardCtx, "socket server started", slog.F("path", socketPath))
|
||||
}
|
||||
|
||||
// getSocketPath returns the socket path from options or environment
|
||||
func (a *agent) getSocketPath() string {
|
||||
// Check if socket path is explicitly configured
|
||||
if a.getSocketPathFromOptions() != "" {
|
||||
return a.getSocketPathFromOptions()
|
||||
}
|
||||
|
||||
// Check environment variable
|
||||
if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
|
||||
return path
|
||||
}
|
||||
|
||||
// Return empty to disable socket server
|
||||
return ""
|
||||
}
|
||||
|
||||
// getSocketPathFromOptions returns the socket path from agent options
|
||||
func (a *agent) getSocketPathFromOptions() string {
|
||||
return a.socketPath
|
||||
}
|
||||
|
||||
// runLoop attempts to start the agent in a retry loop.
|
||||
// Coder may be offline temporarily, a connection issue
|
||||
// may be happening, but regardless after the intermittent
|
||||
@@ -1931,6 +1998,13 @@ func (a *agent) Close() error {
|
||||
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
|
||||
}
|
||||
|
||||
// Close socket server
|
||||
if a.socketServer != nil {
|
||||
if err := a.socketServer.Stop(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "socket server close", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the graceful shutdown to complete, but don't wait forever so
|
||||
// that we don't break user expectations.
|
||||
go func() {
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
# Agent Socket API
|
||||
|
||||
The Agent Socket API provides a local communication channel between CLI commands running within a workspace and the Coder agent process. This enables new CLI commands to interact directly with the agent without going through the control plane.
|
||||
|
||||
## Overview
|
||||
|
||||
The socket server runs within the agent process and listens on a Unix domain socket (or named pipe on Windows). CLI commands can connect to this socket to query agent information, check health status, and perform other operations.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Socket Server
|
||||
- **Location**: `agent/agentsocket/`
|
||||
- **Protocol**: JSON-RPC 2.0 over Unix domain socket
|
||||
- **Platform Support**: Linux, macOS, Windows 10+ (build 17063+)
|
||||
- **Authentication**: Pluggable middleware (no-auth by default)
|
||||
|
||||
### Client Library
|
||||
- **Location**: `codersdk/agentsdk/socket_client.go`
|
||||
- **Auto-discovery**: Automatically finds socket path
|
||||
- **Type-safe**: Go client with proper error handling
|
||||
|
||||
## Socket Path Discovery
|
||||
|
||||
The socket path is determined in the following order:
|
||||
|
||||
1. **Environment Variable**: `CODER_AGENT_SOCKET_PATH`
|
||||
2. **XDG Runtime Directory**: `$XDG_RUNTIME_DIR/coder-agent.sock`
|
||||
3. **User Temp Directory**: `/tmp/coder-agent-{uid}.sock`
|
||||
4. **Fallback**: `/tmp/coder-agent.sock`
|
||||
|
||||
## Protocol
|
||||
|
||||
### Request Format
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"method": "ping",
|
||||
"id": "request-123",
|
||||
"params": {}
|
||||
}
|
||||
```
|
||||
|
||||
### Response Format
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"id": "request-123",
|
||||
"result": {
|
||||
"message": "pong",
|
||||
"timestamp": "2024-01-01T00:00:00Z"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Error Format
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"id": "request-123",
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": "Method not found",
|
||||
"data": "nonexistent"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Available Methods
|
||||
|
||||
### Core Methods
|
||||
- `ping` - Health check with timestamp
|
||||
- `health` - Agent status and uptime
|
||||
- `agent.info` - Detailed agent information
|
||||
- `methods.list` - List available methods
|
||||
|
||||
### Example Usage
|
||||
|
||||
```go
|
||||
// Create client
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Ping the agent
|
||||
pingResp, err := client.Ping(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Agent responded: %s\n", pingResp.Message)
|
||||
|
||||
// Get agent info
|
||||
info, err := client.AgentInfo(ctx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Agent ID: %s, Version: %s\n", info.ID, info.Version)
|
||||
```
|
||||
|
||||
## Adding New Handlers
|
||||
|
||||
### Server Side
|
||||
```go
|
||||
// Register a new handler
|
||||
server.RegisterHandler("custom.method", func(ctx Context, req *Request) (*Response, error) {
|
||||
// Handle the request
|
||||
result := map[string]string{"status": "ok"}
|
||||
return NewResponse(req.ID, result)
|
||||
})
|
||||
```
|
||||
|
||||
### Client Side
|
||||
```go
|
||||
// Add method to client
|
||||
func (c *SocketClient) CustomMethod(ctx context.Context) (*CustomResponse, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "custom.method",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("custom method error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var result CustomResponse
|
||||
if err := json.Unmarshal(resp.Result, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
The socket server supports pluggable authentication middleware. By default, no authentication is performed (suitable for local-only communication).
|
||||
|
||||
### Custom Authentication
|
||||
```go
|
||||
type CustomAuthMiddleware struct {
|
||||
// Add auth fields
|
||||
}
|
||||
|
||||
func (m *CustomAuthMiddleware) Authenticate(ctx context.Context, conn net.Conn) (context.Context, error) {
|
||||
// Implement authentication logic
|
||||
// Return context with auth info or error
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// Use in server config
|
||||
server := agentsocket.NewServer(agentsocket.Config{
|
||||
Path: socketPath,
|
||||
Logger: logger,
|
||||
AuthMiddleware: &CustomAuthMiddleware{},
|
||||
})
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Agent Options
|
||||
```go
|
||||
options := agent.Options{
|
||||
// ... other options
|
||||
SocketPath: "/custom/path/agent.sock", // Optional, uses auto-discovery if empty
|
||||
}
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
- `CODER_AGENT_SOCKET_PATH` - Override socket path
|
||||
- `XDG_RUNTIME_DIR` - Used for socket path discovery
|
||||
|
||||
## Error Codes
|
||||
|
||||
| Code | Description |
|
||||
|------|-------------|
|
||||
| -32700 | Parse error |
|
||||
| -32600 | Invalid request |
|
||||
| -32601 | Method not found |
|
||||
| -32602 | Invalid params |
|
||||
| -32603 | Internal error |
|
||||
|
||||
## Platform Support
|
||||
|
||||
### Unix-like Systems (Linux, macOS)
|
||||
- Uses Unix domain sockets
|
||||
- Socket file permissions: 600 (owner read/write only)
|
||||
- Auto-cleanup on shutdown
|
||||
|
||||
### Windows
|
||||
- Uses Unix domain sockets (Windows 10 build 17063+)
|
||||
- Falls back to named pipes if needed
|
||||
- Simplified permission handling
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Local Only**: Socket is only accessible from within the workspace
|
||||
2. **File Permissions**: Socket file is restricted to owner only
|
||||
3. **No Network Access**: Unix domain sockets don't traverse network
|
||||
4. **Authentication Ready**: Middleware pattern allows future auth implementation
|
||||
|
||||
## Future Extensibility
|
||||
|
||||
The design supports:
|
||||
- **Protocol Versioning**: Request includes version field
|
||||
- **Multiple Transports**: Interface-based design allows TCP/WebSocket later
|
||||
- **Auth Plugins**: Middleware pattern for various auth methods
|
||||
- **Custom Handlers**: Simple registration pattern for new commands
|
||||
@@ -0,0 +1,23 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// AuthMiddleware defines the interface for authentication middleware
|
||||
type AuthMiddleware interface {
|
||||
// Authenticate authenticates a connection and returns a context with auth info
|
||||
Authenticate(ctx context.Context, conn net.Conn) (context.Context, error)
|
||||
}
|
||||
|
||||
// NoAuthMiddleware is a no-op authentication middleware
|
||||
type NoAuthMiddleware struct{}
|
||||
|
||||
// Authenticate implements AuthMiddleware but performs no authentication
|
||||
func (*NoAuthMiddleware) Authenticate(ctx context.Context, conn net.Conn) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// Ensure NoAuthMiddleware implements AuthMiddleware
|
||||
var _ AuthMiddleware = (*NoAuthMiddleware)(nil)
|
||||
@@ -0,0 +1,108 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// AgentInfo represents information about the agent
|
||||
type AgentInfo struct {
|
||||
ID string `json:"id"`
|
||||
Version string `json:"version"`
|
||||
Status string `json:"status"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
|
||||
// PingResponse represents a ping response
|
||||
type PingResponse struct {
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// HealthResponse represents a health check response
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
|
||||
// HandlerContext provides context for handlers
|
||||
type HandlerContext struct {
|
||||
AgentID string
|
||||
Version string
|
||||
Status string
|
||||
StartedAt time.Time
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
// NewHandlers creates the default set of handlers
|
||||
func NewHandlers(handlerCtx HandlerContext) map[string]Handler {
|
||||
handlers := make(map[string]Handler)
|
||||
|
||||
// Ping handler
|
||||
handlers["ping"] = func(_ Context, req *Request) (*Response, error) {
|
||||
resp := PingResponse{
|
||||
Message: "pong",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
return NewResponse(req.ID, resp)
|
||||
}
|
||||
|
||||
// Health check handler
|
||||
handlers["health"] = func(_ Context, req *Request) (*Response, error) {
|
||||
uptime := time.Since(handlerCtx.StartedAt)
|
||||
resp := HealthResponse{
|
||||
Status: handlerCtx.Status,
|
||||
Timestamp: time.Now(),
|
||||
Uptime: uptime.String(),
|
||||
}
|
||||
return NewResponse(req.ID, resp)
|
||||
}
|
||||
|
||||
// Agent info handler
|
||||
handlers["agent.info"] = func(_ Context, req *Request) (*Response, error) {
|
||||
uptime := time.Since(handlerCtx.StartedAt)
|
||||
resp := AgentInfo{
|
||||
ID: handlerCtx.AgentID,
|
||||
Version: handlerCtx.Version,
|
||||
Status: handlerCtx.Status,
|
||||
StartedAt: handlerCtx.StartedAt,
|
||||
Uptime: uptime.String(),
|
||||
}
|
||||
return NewResponse(req.ID, resp)
|
||||
}
|
||||
|
||||
// List methods handler
|
||||
handlers["methods.list"] = func(_ Context, req *Request) (*Response, error) {
|
||||
methods := []string{
|
||||
"ping",
|
||||
"health",
|
||||
"agent.info",
|
||||
"methods.list",
|
||||
}
|
||||
return NewResponse(req.ID, methods)
|
||||
}
|
||||
|
||||
return handlers
|
||||
}
|
||||
|
||||
// RegisterDefaultHandlers registers the default set of handlers with a server
|
||||
func RegisterDefaultHandlers(server *Server, ctx HandlerContext) {
|
||||
handlers := NewHandlers(ctx)
|
||||
for method, handler := range handlers {
|
||||
server.RegisterHandler(method, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateHandlerContext creates a handler context from agent information
|
||||
func CreateHandlerContext(agentID, version, status string, startedAt time.Time, logger slog.Logger) HandlerContext {
|
||||
return HandlerContext{
|
||||
AgentID: agentID,
|
||||
Version: version,
|
||||
Status: status,
|
||||
StartedAt: startedAt,
|
||||
Logger: logger,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Protocol version for the agent socket API
|
||||
const ProtocolVersion = "1.0"
|
||||
|
||||
// Request represents an incoming request to the agent socket
|
||||
type Request struct {
|
||||
Version string `json:"version"`
|
||||
Method string `json:"method"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Response represents a response from the agent socket
|
||||
type Response struct {
|
||||
Version string `json:"version"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Error represents an error in the response
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Standard error codes
|
||||
const (
|
||||
ErrCodeParseError = -32700
|
||||
ErrCodeInvalidRequest = -32600
|
||||
ErrCodeMethodNotFound = -32601
|
||||
ErrCodeInvalidParams = -32602
|
||||
ErrCodeInternalError = -32603
|
||||
)
|
||||
|
||||
// NewError creates a new error response
|
||||
func NewError(code int, message string, data any) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// NewResponse creates a successful response
|
||||
func NewResponse(id string, result any) (*Response, error) {
|
||||
resultBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal result: %w", err)
|
||||
}
|
||||
|
||||
return &Response{
|
||||
Version: ProtocolVersion,
|
||||
ID: id,
|
||||
Result: resultBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewErrorResponse creates an error response
|
||||
func NewErrorResponse(id string, err *Error) *Response {
|
||||
return &Response{
|
||||
Version: ProtocolVersion,
|
||||
ID: id,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler represents a function that can handle a request
|
||||
type Handler func(ctx Context, req *Request) (*Response, error)
|
||||
|
||||
// Context provides context for request handling
|
||||
type Context struct {
|
||||
// Additional context can be added here in the future
|
||||
// For now, this is a placeholder for future auth context, etc.
|
||||
}
|
||||
@@ -0,0 +1,266 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// Server represents the agent socket server
|
||||
type Server struct {
|
||||
logger slog.Logger
|
||||
path string
|
||||
listener net.Listener
|
||||
handlers map[string]Handler
|
||||
authMiddleware AuthMiddleware
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Config holds configuration for the socket server
|
||||
type Config struct {
|
||||
Path string
|
||||
Logger slog.Logger
|
||||
AuthMiddleware AuthMiddleware
|
||||
}
|
||||
|
||||
// NewServer creates a new agent socket server
|
||||
func NewServer(config Config) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
server := &Server{
|
||||
logger: config.Logger.Named("agentsocket"),
|
||||
path: config.Path,
|
||||
handlers: make(map[string]Handler),
|
||||
authMiddleware: config.AuthMiddleware,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set default auth middleware if none provided
|
||||
if server.authMiddleware == nil {
|
||||
server.authMiddleware = &NoAuthMiddleware{}
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// Start starts the socket server
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener != nil {
|
||||
return xerrors.New("server already started")
|
||||
}
|
||||
|
||||
// Get socket path
|
||||
path := s.path
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = getDefaultSocketPath()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get default socket path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if socket is available
|
||||
if !isSocketAvailable(path) {
|
||||
return xerrors.Errorf("socket path %s is not available", path)
|
||||
}
|
||||
|
||||
// Create socket listener
|
||||
listener, err := createSocket(s.ctx, path)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create socket: %w", err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
s.path = path
|
||||
|
||||
s.logger.Info(s.ctx, "agent socket server started", slog.F("path", path))
|
||||
|
||||
// Start accepting connections
|
||||
s.wg.Add(1)
|
||||
go s.acceptConnections()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the socket server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Info(s.ctx, "stopping agent socket server")
|
||||
|
||||
// Cancel context to stop accepting new connections
|
||||
s.cancel()
|
||||
|
||||
// Close listener
|
||||
if err := s.listener.Close(); err != nil {
|
||||
s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err))
|
||||
}
|
||||
|
||||
// Wait for all connections to finish
|
||||
s.wg.Wait()
|
||||
|
||||
// Clean up socket file
|
||||
if err := cleanupSocket(s.path); err != nil {
|
||||
s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err))
|
||||
}
|
||||
|
||||
s.listener = nil
|
||||
s.logger.Info(s.ctx, "agent socket server stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a method
|
||||
func (s *Server) RegisterHandler(method string, handler Handler) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.handlers[method] = handler
|
||||
}
|
||||
|
||||
// GetPath returns the socket path
|
||||
func (s *Server) GetPath() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.path
|
||||
}
|
||||
|
||||
// acceptConnections accepts incoming connections
|
||||
func (s *Server) acceptConnections() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Handle connection in a goroutine
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.handleConnection(conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles a single connection
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// Authenticate connection first to get context
|
||||
ctx, err := s.authMiddleware.Authenticate(s.ctx, conn)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "authentication failed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Set connection deadline
|
||||
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
s.logger.Warn(ctx, "failed to set connection deadline", slog.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Debug(ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr()))
|
||||
|
||||
// Handle requests
|
||||
decoder := json.NewDecoder(conn)
|
||||
encoder := json.NewEncoder(conn)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Set read deadline
|
||||
if err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
s.logger.Warn(ctx, "failed to set read deadline", slog.Error(err))
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
if err == io.EOF {
|
||||
s.logger.Debug(ctx, "connection closed by client")
|
||||
return
|
||||
}
|
||||
s.logger.Warn(ctx, "error decoding request", slog.Error(err))
|
||||
|
||||
// Send error response
|
||||
resp := NewErrorResponse("", NewError(ErrCodeParseError, "Parse error", err.Error()))
|
||||
encoder.Encode(resp)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle request
|
||||
resp := s.handleRequest(ctx, &req)
|
||||
|
||||
// Send response
|
||||
if err := encoder.Encode(resp); err != nil {
|
||||
s.logger.Warn(ctx, "error sending response", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest handles a single request
|
||||
func (s *Server) handleRequest(ctx context.Context, req *Request) *Response {
|
||||
// Validate request
|
||||
if req.Version != ProtocolVersion {
|
||||
return NewErrorResponse(req.ID, NewError(ErrCodeInvalidRequest, "Unsupported version", req.Version))
|
||||
}
|
||||
|
||||
if req.Method == "" {
|
||||
return NewErrorResponse(req.ID, NewError(ErrCodeInvalidRequest, "Missing method", nil))
|
||||
}
|
||||
|
||||
// Get handler
|
||||
s.mu.RLock()
|
||||
handler, exists := s.handlers[req.Method]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return NewErrorResponse(req.ID, NewError(ErrCodeMethodNotFound, "Method not found", req.Method))
|
||||
}
|
||||
|
||||
// Call handler
|
||||
type requestIDKey struct{}
|
||||
ctx = context.WithValue(ctx, requestIDKey{}, req.ID)
|
||||
resp, err := handler(Context{}, req)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "handler execution failed", slog.Error(err))
|
||||
return NewErrorResponse(req.ID, NewError(ErrCodeInternalError, "Internal error", err.Error()))
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -0,0 +1,250 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
func TestServer_StartStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary socket path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Create server
|
||||
server := NewServer(Config{
|
||||
Path: socketPath,
|
||||
Logger: slog.Make().Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
||||
// Register a test handler
|
||||
server.RegisterHandler("test", func(ctx Context, req *Request) (*Response, error) {
|
||||
return NewResponse(req.ID, map[string]string{"message": "test response"})
|
||||
})
|
||||
|
||||
// Start server
|
||||
err := server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
// Verify socket file exists
|
||||
_, err = os.Stat(socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test connection
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send test request
|
||||
req := Request{
|
||||
Version: "1.0",
|
||||
Method: "test",
|
||||
ID: "test-1",
|
||||
}
|
||||
|
||||
err = json.NewEncoder(conn).Encode(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
var resp Response
|
||||
err = json.NewDecoder(conn).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test-1", resp.ID)
|
||||
assert.Nil(t, resp.Error)
|
||||
assert.NotNil(t, resp.Result)
|
||||
|
||||
// Verify response content
|
||||
var result map[string]string
|
||||
err = json.Unmarshal(resp.Result, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test response", result["message"])
|
||||
}
|
||||
|
||||
func TestServer_ErrorHandling(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary socket path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Create server
|
||||
server := NewServer(Config{
|
||||
Path: socketPath,
|
||||
Logger: slog.Make().Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
||||
// Start server
|
||||
err := server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
// Test connection
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send request for non-existent method
|
||||
req := Request{
|
||||
Version: "1.0",
|
||||
Method: "nonexistent",
|
||||
ID: "test-1",
|
||||
}
|
||||
|
||||
err = json.NewEncoder(conn).Encode(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
var resp Response
|
||||
err = json.NewDecoder(conn).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test-1", resp.ID)
|
||||
assert.NotNil(t, resp.Error)
|
||||
assert.Equal(t, ErrCodeMethodNotFound, resp.Error.Code)
|
||||
assert.Equal(t, "Method not found", resp.Error.Message)
|
||||
}
|
||||
|
||||
func TestServer_DefaultHandlers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary socket path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Create server
|
||||
server := NewServer(Config{
|
||||
Path: socketPath,
|
||||
Logger: slog.Make().Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
||||
// Register default handlers
|
||||
handlerCtx := CreateHandlerContext(
|
||||
"test-agent-id",
|
||||
"1.0.0",
|
||||
"ready",
|
||||
time.Now().Add(-time.Hour),
|
||||
slog.Make(),
|
||||
)
|
||||
RegisterDefaultHandlers(server, handlerCtx)
|
||||
|
||||
// Start server
|
||||
err := server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
// Test ping
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
req := Request{
|
||||
Version: "1.0",
|
||||
Method: "ping",
|
||||
ID: "ping-1",
|
||||
}
|
||||
|
||||
err = json.NewEncoder(conn).Encode(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp Response
|
||||
err = json.NewDecoder(conn).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "ping-1", resp.ID)
|
||||
assert.Nil(t, resp.Error)
|
||||
|
||||
var pingResp PingResponse
|
||||
err = json.Unmarshal(resp.Result, &pingResp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pong", pingResp.Message)
|
||||
}
|
||||
|
||||
func TestServer_ConcurrentConnections(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create temporary socket path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Create server
|
||||
server := NewServer(Config{
|
||||
Path: socketPath,
|
||||
Logger: slog.Make().Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
||||
// Register a test handler
|
||||
server.RegisterHandler("test", func(ctx Context, req *Request) (*Response, error) {
|
||||
time.Sleep(10 * time.Millisecond) // Simulate some work
|
||||
return NewResponse(req.ID, map[string]string{"message": "test response"})
|
||||
})
|
||||
|
||||
// Start server
|
||||
err := server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
// Test multiple concurrent connections
|
||||
const numConnections = 5
|
||||
results := make(chan error, numConnections)
|
||||
|
||||
for i := 0; i < numConnections; i++ {
|
||||
go func(i int) {
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
req := Request{
|
||||
Version: "1.0",
|
||||
Method: "test",
|
||||
ID: fmt.Sprintf("test-%d", i),
|
||||
}
|
||||
|
||||
err = json.NewEncoder(conn).Encode(req)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
var resp Response
|
||||
err = json.NewDecoder(conn).Decode(&resp)
|
||||
if err != nil {
|
||||
results <- err
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
results <- xerrors.Errorf("server error: %s", resp.Error.Message)
|
||||
return
|
||||
}
|
||||
|
||||
results <- nil
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all connections to complete
|
||||
for i := 0; i < numConnections; i++ {
|
||||
select {
|
||||
case err := <-results:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for concurrent connections")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
//go:build !windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// createSocket creates a Unix domain socket listener
|
||||
func createSocket(ctx context.Context, path string) (net.Listener, error) {
|
||||
// Remove existing socket file if it exists
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return nil, xerrors.Errorf("remove existing socket: %w", err)
|
||||
}
|
||||
|
||||
// Create parent directory if it doesn't exist
|
||||
parentDir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(parentDir, 0o700); err != nil {
|
||||
return nil, xerrors.Errorf("create socket directory: %w", err)
|
||||
}
|
||||
|
||||
// Create Unix domain socket listener
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen on unix socket: %w", err)
|
||||
}
|
||||
|
||||
// Set socket permissions to be accessible only by the current user
|
||||
if err := os.Chmod(path, 0o600); err != nil {
|
||||
listener.Close()
|
||||
return nil, xerrors.Errorf("set socket permissions: %w", err)
|
||||
}
|
||||
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// getDefaultSocketPath returns the default socket path for Unix-like systems
|
||||
func getDefaultSocketPath() (string, error) {
|
||||
// Try XDG_RUNTIME_DIR first
|
||||
if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" {
|
||||
return filepath.Join(runtimeDir, "coder-agent.sock"), nil
|
||||
}
|
||||
|
||||
// Fall back to /tmp with user-specific path
|
||||
uid := os.Getuid()
|
||||
return filepath.Join("/tmp", fmt.Sprintf("coder-agent-%d.sock", uid)), nil
|
||||
}
|
||||
|
||||
// cleanupSocket removes the socket file
|
||||
func cleanupSocket(path string) error {
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// isSocketAvailable checks if a socket path is available for use
|
||||
func isSocketAvailable(path string) bool {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try to connect to see if it's actually listening
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
// If we can't connect, the socket is not in use
|
||||
return true
|
||||
}
|
||||
conn.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
// getSocketInfo returns information about the socket file
|
||||
func getSocketInfo(path string) (*SocketInfo, error) {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sys, ok := stat.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return nil, xerrors.New("unable to get stat_t from file info")
|
||||
}
|
||||
return &SocketInfo{
|
||||
Path: path,
|
||||
UID: int(sys.Uid),
|
||||
GID: int(sys.Gid),
|
||||
Mode: stat.Mode(),
|
||||
ModTime: stat.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SocketInfo contains information about a socket file
|
||||
type SocketInfo struct {
|
||||
Path string
|
||||
UID int
|
||||
GID int
|
||||
Mode os.FileMode
|
||||
ModTime time.Time
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
//go:build windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// createSocket creates a Unix domain socket listener on Windows
|
||||
// Falls back to named pipe if Unix sockets are not supported
|
||||
func createSocket(ctx context.Context, path string) (net.Listener, error) {
|
||||
// Try Unix domain socket first (Windows 10 build 17063+)
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// Fall back to named pipe
|
||||
pipePath := `\\.\pipe\coder-agent`
|
||||
return net.Listen("tcp", pipePath)
|
||||
}
|
||||
|
||||
// getDefaultSocketPath returns the default socket path for Windows
|
||||
func getDefaultSocketPath() (string, error) {
|
||||
// Try to use a temporary directory
|
||||
tempDir := os.TempDir()
|
||||
if tempDir == "" {
|
||||
tempDir = "C:\\temp"
|
||||
}
|
||||
|
||||
// Create a user-specific subdirectory
|
||||
uid := os.Getuid()
|
||||
userDir := filepath.Join(tempDir, "coder-agent", strconv.Itoa(uid))
|
||||
|
||||
if err := os.MkdirAll(userDir, 0o700); err != nil {
|
||||
return "", fmt.Errorf("create user directory: %w", err)
|
||||
}
|
||||
|
||||
return filepath.Join(userDir, "agent.sock"), nil
|
||||
}
|
||||
|
||||
// cleanupSocket removes the socket file
|
||||
func cleanupSocket(path string) error {
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// isSocketAvailable checks if a socket path is available for use
|
||||
func isSocketAvailable(path string) bool {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try to connect to see if it's actually listening
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
// If we can't connect, the socket is not in use
|
||||
return true
|
||||
}
|
||||
conn.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
// getSocketInfo returns information about the socket file
|
||||
func getSocketInfo(path string) (*SocketInfo, error) {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// On Windows, we'll use a simplified approach for now
|
||||
// In a real implementation, you'd get the security descriptor
|
||||
return &SocketInfo{
|
||||
Path: path,
|
||||
UID: 0, // Simplified for now
|
||||
GID: 0, // Simplified for now
|
||||
Mode: stat.Mode(),
|
||||
ModTime: stat.ModTime(),
|
||||
Owner: "unknown",
|
||||
Group: "unknown",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SocketInfo contains information about a socket file
|
||||
type SocketInfo struct {
|
||||
Path string
|
||||
UID int
|
||||
GID int
|
||||
Mode os.FileMode
|
||||
ModTime time.Time
|
||||
Owner string // Windows SID string
|
||||
Group string // Windows SID string
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package unit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ErrConsumerNotFound is returned when a consumer ID is not registered.
|
||||
var ErrConsumerNotFound = xerrors.New("consumer not found")
|
||||
|
||||
// ErrCannotUpdateOtherConsumer is returned when attempting to update another consumer's status.
|
||||
var ErrCannotUpdateOtherConsumer = xerrors.New("cannot update other consumer's status")
|
||||
|
||||
// dependencyVertex represents a vertex in the dependency graph that is associated with a consumer.
|
||||
type dependencyVertex[ConsumerID comparable] struct {
|
||||
ID ConsumerID
|
||||
}
|
||||
|
||||
// Dependency represents a dependency relationship between consumers.
|
||||
type Dependency[StatusType, ConsumerID comparable] struct {
|
||||
Consumer ConsumerID
|
||||
DependsOn ConsumerID
|
||||
RequiredStatus StatusType
|
||||
CurrentStatus StatusType
|
||||
IsSatisfied bool
|
||||
}
|
||||
|
||||
// DependencyTracker provides reactive dependency tracking over a Graph.
|
||||
// It manages consumer registration, dependency relationships, and status updates
|
||||
// with automatic recalculation of readiness when dependencies are satisfied.
|
||||
type DependencyTracker[StatusType, ConsumerID comparable] struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// The underlying graph that stores dependency relationships
|
||||
graph *Graph[StatusType, *dependencyVertex[ConsumerID]]
|
||||
|
||||
// Track current status of each consumer
|
||||
consumerStatus map[ConsumerID]StatusType
|
||||
|
||||
// Track readiness state (cached to avoid repeated graph traversal)
|
||||
consumerReadiness map[ConsumerID]bool
|
||||
|
||||
// Track which consumers are registered
|
||||
registeredConsumers map[ConsumerID]bool
|
||||
|
||||
// Store vertex instances for each consumer to ensure consistent references
|
||||
consumerVertices map[ConsumerID]*dependencyVertex[ConsumerID]
|
||||
}
|
||||
|
||||
// NewDependencyTracker creates a new DependencyTracker instance.
|
||||
func NewDependencyTracker[StatusType, ConsumerID comparable]() *DependencyTracker[StatusType, ConsumerID] {
|
||||
return &DependencyTracker[StatusType, ConsumerID]{
|
||||
graph: &Graph[StatusType, *dependencyVertex[ConsumerID]]{},
|
||||
consumerStatus: make(map[ConsumerID]StatusType),
|
||||
consumerReadiness: make(map[ConsumerID]bool),
|
||||
registeredConsumers: make(map[ConsumerID]bool),
|
||||
consumerVertices: make(map[ConsumerID]*dependencyVertex[ConsumerID]),
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a new consumer as a vertex in the dependency graph.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) Register(id ConsumerID) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if dt.registeredConsumers[id] {
|
||||
return xerrors.Errorf("consumer %v is already registered", id)
|
||||
}
|
||||
|
||||
// Create and store the vertex for this consumer
|
||||
vertex := &dependencyVertex[ConsumerID]{ID: id}
|
||||
dt.consumerVertices[id] = vertex
|
||||
dt.registeredConsumers[id] = true
|
||||
dt.consumerReadiness[id] = true // New consumers start as ready (no dependencies)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDependency adds a dependency relationship between consumers.
|
||||
// The consumer depends on the dependsOn consumer reaching the requiredStatus.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) AddDependency(consumer ConsumerID, dependsOn ConsumerID, requiredStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return xerrors.Errorf("consumer %v is not registered", consumer)
|
||||
}
|
||||
if !dt.registeredConsumers[dependsOn] {
|
||||
return xerrors.Errorf("consumer %v is not registered", dependsOn)
|
||||
}
|
||||
|
||||
// Get the stored vertices for both consumers
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependsOnVertex := dt.consumerVertices[dependsOn]
|
||||
|
||||
// Add the dependency edge to the graph
|
||||
// The edge goes from consumer to dependsOn, representing the dependency
|
||||
err := dt.graph.AddEdge(consumerVertex, dependsOnVertex, requiredStatus)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to add dependency: %w", err)
|
||||
}
|
||||
|
||||
// Recalculate readiness for the consumer since it now has a dependency
|
||||
dt.recalculateReadinessUnsafe(consumer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates a consumer's status and recalculates readiness for affected dependents.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) UpdateStatus(consumer ConsumerID, newStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return ErrConsumerNotFound
|
||||
}
|
||||
|
||||
// Update the consumer's status
|
||||
dt.consumerStatus[consumer] = newStatus
|
||||
|
||||
// Get all consumers that depend on this one (reverse adjacent vertices)
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependentEdges := dt.graph.GetReverseAdjacentVertices(consumerVertex)
|
||||
|
||||
// Recalculate readiness for all dependents
|
||||
for _, edge := range dependentEdges {
|
||||
dt.recalculateReadinessUnsafe(edge.From.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReady checks if all dependencies for a consumer are satisfied.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) IsReady(consumer ConsumerID) (bool, error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return false, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
return dt.consumerReadiness[consumer], nil
|
||||
}
|
||||
|
||||
// GetUnmetDependencies returns a list of unsatisfied dependencies for a consumer.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) GetUnmetDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return nil, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
var unmetDependencies []Dependency[StatusType, ConsumerID]
|
||||
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists {
|
||||
// If the dependency consumer has no status, it's not satisfied
|
||||
var zeroStatus StatusType
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: zeroStatus, // Zero value
|
||||
IsSatisfied: false,
|
||||
})
|
||||
} else {
|
||||
isSatisfied := currentStatus == requiredStatus
|
||||
if !isSatisfied {
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: currentStatus,
|
||||
IsSatisfied: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return unmetDependencies, nil
|
||||
}
|
||||
|
||||
// recalculateReadinessUnsafe recalculates the readiness state for a consumer.
|
||||
// This method assumes the caller holds the write lock.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) recalculateReadinessUnsafe(consumer ConsumerID) {
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
// If there are no dependencies, the consumer is ready
|
||||
if len(forwardEdges) == 0 {
|
||||
dt.consumerReadiness[consumer] = true
|
||||
return
|
||||
}
|
||||
|
||||
// Check if all dependencies are satisfied
|
||||
allSatisfied := true
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists || currentStatus != requiredStatus {
|
||||
allSatisfied = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
dt.consumerReadiness[consumer] = allSatisfied
|
||||
}
|
||||
|
||||
// GetGraph returns the underlying graph for visualization and debugging.
|
||||
// This should be used carefully as it exposes the internal graph structure.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) GetGraph() *Graph[StatusType, *dependencyVertex[ConsumerID]] {
|
||||
return dt.graph
|
||||
}
|
||||
|
||||
// ExportDOT exports the dependency graph to DOT format for visualization.
|
||||
func (dt *DependencyTracker[StatusType, ConsumerID]) ExportDOT(name string) (string, error) {
|
||||
return dt.graph.ToDOT(name)
|
||||
}
|
||||
@@ -0,0 +1,692 @@
|
||||
package unit_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
)
|
||||
|
||||
type testStatus string
|
||||
|
||||
const (
|
||||
statusInitialized testStatus = "initialized"
|
||||
statusStarted testStatus = "started"
|
||||
statusRunning testStatus = "running"
|
||||
statusCompleted testStatus = "completed"
|
||||
)
|
||||
|
||||
type testConsumerID string
|
||||
|
||||
const (
|
||||
consumerA testConsumerID = "serviceA"
|
||||
consumerB testConsumerID = "serviceB"
|
||||
consumerC testConsumerID = "serviceC"
|
||||
consumerD testConsumerID = "serviceD"
|
||||
)
|
||||
|
||||
func TestDependencyTracker_Register(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
t.Run("RegisterNewConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer should be ready initially (no dependencies)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("RegisterDuplicateConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tracker.Register(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already registered")
|
||||
})
|
||||
|
||||
t.Run("RegisterMultipleConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// All should be ready initially
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_AddDependency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AddDependencyBetweenRegisteredConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A should no longer be ready (depends on B)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// B should still be ready (no dependencies)
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("AddDependencyWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add dependency to unregistered consumer
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
|
||||
t.Run("AddDependencyFromUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add dependency from unregistered consumer
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_UpdateStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpdateStatusTriggersReadinessRecalculation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially A is not ready
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should become ready
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
err := tracker.UpdateStatus(consumerA, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("LinearChainDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create chain: A depends on B being "started", B depends on C being "completed"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially only C is ready
|
||||
ready, err := tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "completed" - B should become ready
|
||||
err = tracker.UpdateStatus(consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "started" - A should become ready
|
||||
err = tracker.UpdateStatus(consumerB, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_GetUnmetDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithNoDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithUnsatisfiedDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, unmet, 1)
|
||||
|
||||
assert.Equal(t, consumerA, unmet[0].Consumer)
|
||||
assert.Equal(t, consumerB, unmet[0].DependsOn)
|
||||
assert.Equal(t, statusRunning, unmet[0].RequiredStatus)
|
||||
assert.False(t, unmet[0].IsSatisfied)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForConsumerWithSatisfiedDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update B to "running"
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesForUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.Nil(t, unmet)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ConcurrentOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConcurrentStatusUpdates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create dependencies: A depends on B, B depends on C, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 10
|
||||
|
||||
// Launch goroutines that update statuses
|
||||
errors := make([]error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Update D to completed (should make C ready)
|
||||
err := tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update C to started (should make B ready)
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update B to running (should make A ready)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for any errors in goroutines
|
||||
for i, err := range errors {
|
||||
require.NoError(t, err, "goroutine %d had error", i)
|
||||
}
|
||||
|
||||
// All consumers should be ready after the updates
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcurrentReadinessChecks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 20
|
||||
|
||||
// Launch goroutines that check readiness
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Check readiness multiple times
|
||||
for j := 0; j < 10; j++ {
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
// Initially should be false, then true after B is updated
|
||||
_ = ready
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
// B should always be ready (no dependencies)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Update B to "running" in the middle of readiness checks
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_MultipleDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConsumerWithMultipleDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// A depends on B being "running" AND C being "started"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A should not be ready (depends on both B and C)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C too)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("ComplexDependencyChain", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph:
|
||||
// A depends on B being "running" AND C being "started"
|
||||
// B depends on D being "completed"
|
||||
// C depends on D being "completed"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially only D is ready
|
||||
ready, err := tracker.IsReady(consumerD)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update D to "completed" - B and C should become ready
|
||||
err = tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("DifferentStatusTypes", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerC)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running" AND C being "completed"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update B to "running" but not C - A should not be ready
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "completed" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ErrorCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
err := tracker.UpdateStatus(consumerA, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("IsReadyWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.False(t, ready)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.Nil(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("AddDependencyWithUnregisteredConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Try to add dependency with unregistered consumers
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
|
||||
t.Run("CyclicDependencyDetection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to make B depend on A (creates cycle)
|
||||
err = tracker.AddDependency(consumerB, consumerA, statusStarted)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "would create a cycle")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ToDOT(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ExportSimpleGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add dependency
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("test")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
|
||||
t.Run("ExportComplexGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewDependencyTracker[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph
|
||||
// A depends on B and C, B depends on D, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("complex")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
package agentsdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// SocketClient provides a client for communicating with the agent socket
|
||||
type SocketClient struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// SocketConfig holds configuration for the socket client
|
||||
type SocketConfig struct {
|
||||
Path string // Socket path (optional, will auto-discover if not set)
|
||||
}
|
||||
|
||||
// NewSocketClient creates a new socket client
|
||||
func NewSocketClient(config SocketConfig) (*SocketClient, error) {
|
||||
path := config.Path
|
||||
if path == "" {
|
||||
var err error
|
||||
path, err = discoverSocketPath()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("discover socket path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("connect to socket: %w", err)
|
||||
}
|
||||
|
||||
return &SocketClient{
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the socket connection
|
||||
func (c *SocketClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Ping sends a ping request to the agent
|
||||
func (c *SocketClient) Ping(ctx context.Context) (*PingResponse, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "ping",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, xerrors.Errorf("ping error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var pingResp PingResponse
|
||||
if err := json.Unmarshal(resp.Result, &pingResp); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal ping response: %w", err)
|
||||
}
|
||||
|
||||
return &pingResp, nil
|
||||
}
|
||||
|
||||
// Health sends a health check request to the agent
|
||||
func (c *SocketClient) Health(ctx context.Context) (*HealthResponse, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "health",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, xerrors.Errorf("health error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var healthResp HealthResponse
|
||||
if err := json.Unmarshal(resp.Result, &healthResp); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal health response: %w", err)
|
||||
}
|
||||
|
||||
return &healthResp, nil
|
||||
}
|
||||
|
||||
// AgentInfo sends an agent info request
|
||||
func (c *SocketClient) AgentInfo(ctx context.Context) (*AgentInfo, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "agent.info",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, xerrors.Errorf("agent info error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var agentInfo AgentInfo
|
||||
if err := json.Unmarshal(resp.Result, &agentInfo); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal agent info response: %w", err)
|
||||
}
|
||||
|
||||
return &agentInfo, nil
|
||||
}
|
||||
|
||||
// ListMethods lists available methods
|
||||
func (c *SocketClient) ListMethods(ctx context.Context) ([]string, error) {
|
||||
req := &Request{
|
||||
Version: "1.0",
|
||||
Method: "methods.list",
|
||||
ID: generateRequestID(),
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, xerrors.Errorf("list methods error: %s", resp.Error.Message)
|
||||
}
|
||||
|
||||
var methods []string
|
||||
if err := json.Unmarshal(resp.Result, &methods); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal methods response: %w", err)
|
||||
}
|
||||
|
||||
return methods, nil
|
||||
}
|
||||
|
||||
// sendRequest sends a request and returns the response
|
||||
func (c *SocketClient) sendRequest(_ context.Context, req *Request) (*Response, error) {
|
||||
// Set write deadline
|
||||
if err := c.conn.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
return nil, xerrors.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
|
||||
// Send request
|
||||
if err := json.NewEncoder(c.conn).Encode(req); err != nil {
|
||||
return nil, xerrors.Errorf("send request: %w", err)
|
||||
}
|
||||
|
||||
// Set read deadline
|
||||
if err := c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
return nil, xerrors.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
var resp Response
|
||||
if err := json.NewDecoder(c.conn).Decode(&resp); err != nil {
|
||||
return nil, xerrors.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// discoverSocketPath discovers the agent socket path
|
||||
func discoverSocketPath() (string, error) {
|
||||
// Check environment variable first
|
||||
if path := os.Getenv("CODER_AGENT_SOCKET_PATH"); path != "" {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// Try common socket paths
|
||||
paths := []string{
|
||||
// XDG runtime directory
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "coder-agent.sock"),
|
||||
// User-specific temp directory
|
||||
filepath.Join(os.TempDir(), fmt.Sprintf("coder-agent-%d.sock", os.Getuid())),
|
||||
// Fallback temp directory
|
||||
filepath.Join(os.TempDir(), "coder-agent.sock"),
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", xerrors.New("agent socket not found")
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID
|
||||
func generateRequestID() string {
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// Request represents a socket request
|
||||
type Request struct {
|
||||
Version string `json:"version"`
|
||||
Method string `json:"method"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Response represents a socket response
|
||||
type Response struct {
|
||||
Version string `json:"version"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Error represents a socket error
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PingResponse represents a ping response
|
||||
type PingResponse struct {
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// HealthResponse represents a health check response
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
|
||||
// AgentInfo represents agent information
|
||||
type AgentInfo struct {
|
||||
ID string `json:"id"`
|
||||
Version string `json:"version"`
|
||||
Status string `json:"status"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package agentsdk_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func TestSocketClient_Integration(t *testing.T) {
|
||||
// Create temporary socket path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Set environment variable for socket discovery
|
||||
t.Setenv("CODER_AGENT_SOCKET_PATH", socketPath)
|
||||
|
||||
// Start a real socket server
|
||||
server := startSocketServer(t, socketPath)
|
||||
defer server.Stop()
|
||||
|
||||
// Create client
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Test ping
|
||||
ctx := context.Background()
|
||||
pingResp, err := client.Ping(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pong", pingResp.Message)
|
||||
assert.False(t, pingResp.Timestamp.IsZero())
|
||||
|
||||
// Test health
|
||||
healthResp, err := client.Health(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ready", healthResp.Status)
|
||||
assert.NotEmpty(t, healthResp.Uptime)
|
||||
|
||||
// Test agent info
|
||||
agentInfo, err := client.AgentInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-agent", agentInfo.ID)
|
||||
assert.Equal(t, "1.0.0", agentInfo.Version)
|
||||
assert.Equal(t, "ready", agentInfo.Status)
|
||||
|
||||
// Test list methods
|
||||
methods, err := client.ListMethods(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, methods, "ping")
|
||||
assert.Contains(t, methods, "health")
|
||||
assert.Contains(t, methods, "agent.info")
|
||||
}
|
||||
|
||||
func TestSocketClient_Discovery(t *testing.T) {
|
||||
// Test with explicit path
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
server := startSocketServer(t, socketPath)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{Path: socketPath})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = client.Ping(ctx)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSocketClient_ErrorHandling(t *testing.T) {
|
||||
// Test with non-existent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{Path: "/nonexistent/socket"})
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, client)
|
||||
}
|
||||
|
||||
// startSocketServer starts a real socket server for testing
|
||||
func startSocketServer(t *testing.T, path string) *agentsocket.Server {
|
||||
// Create server
|
||||
server := agentsocket.NewServer(agentsocket.Config{
|
||||
Path: path,
|
||||
Logger: slog.Make().Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
||||
// Register default handlers with test data
|
||||
handlerCtx := agentsocket.CreateHandlerContext(
|
||||
"test-agent",
|
||||
"1.0.0",
|
||||
"ready",
|
||||
time.Now().Add(-time.Hour),
|
||||
slog.Make(),
|
||||
)
|
||||
agentsocket.RegisterDefaultHandlers(server, handlerCtx)
|
||||
|
||||
// Start server
|
||||
err := server.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
return server
|
||||
}
|
||||
Reference in New Issue
Block a user