diff --git a/aibridge/api.go b/aibridge/api.go index 46f39f80ee..809d452fe9 100644 --- a/aibridge/api.go +++ b/aibridge/api.go @@ -62,5 +62,5 @@ func NewMetrics(reg prometheus.Registerer) *metrics.Metrics { } func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) Recorder { - return recorder.NewRecorder(logger, tracer, clientFn) + return recorder.NewWrappedRecorder(logger, tracer, clientFn) } diff --git a/aibridge/bridge.go b/aibridge/bridge.go index 8adff4d605..f604d0a38a 100644 --- a/aibridge/bridge.go +++ b/aibridge/bridge.go @@ -180,8 +180,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC // We execute this before CreateInterceptor since the interceptors // read the request body and don't reset them. - client := guessClient(r) - sessionID := guessSessionID(client, r) + client := GuessClient(r) + sessionID := GuessSessionID(client, r) interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { @@ -276,7 +276,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC log.Debug(ctx, "interception ended") } - asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()}) + _ = asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()}) // Ensure all recording have completed before completing request. asyncRecorder.Wait() diff --git a/aibridge/bridge_test.go b/aibridge/bridge_test.go index 666abd31c0..f2657ab80f 100644 --- a/aibridge/bridge_test.go +++ b/aibridge/bridge_test.go @@ -1,4 +1,4 @@ -package aibridge +package aibridge_test import ( "net/http" @@ -7,16 +7,22 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/aibridge" "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/provider" ) -func TestValidateProvider_Names(t *testing.T) { +var bridgeTestTracer = otel.Tracer("bridge_test") + +func TestValidateProviders(t *testing.T) { t.Parallel() + logger := slogtest.Make(t, nil) + tests := []struct { name string providers []provider.Provider @@ -25,94 +31,69 @@ func TestValidateProvider_Names(t *testing.T) { { name: "all_supported_providers", providers: []provider.Provider{ - NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}), - NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil), - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), + aibridge.NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}), + aibridge.NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), }, }, { name: "default_names_and_base_urls", providers: []provider.Provider{ - NewOpenAIProvider(config.OpenAI{}), - NewAnthropicProvider(config.Anthropic{}, nil), - NewCopilotProvider(config.Copilot{}), + aibridge.NewOpenAIProvider(config.OpenAI{}), + aibridge.NewAnthropicProvider(config.Anthropic{}, nil), + aibridge.NewCopilotProvider(config.Copilot{}), }, }, { name: "multiple_copilot_instances", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{}), - NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), }, }, { name: "name_with_slashes", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}), }, expectErr: "invalid provider name", }, { name: "name_with_spaces", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}), }, expectErr: "invalid provider name", }, { name: "name_with_uppercase", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}), }, expectErr: "invalid provider name", }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - err := validateProviders(tc.providers) - if tc.expectErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectErr) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestValidateProvider_DuplicateNames(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - providers []provider.Provider - expectErr string - }{ { name: "unique_names", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), }, }, { name: "duplicate_base_url_different_names", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}), }, }, { name: "duplicate_name", providers: []provider.Provider{ - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), - NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}), }, expectErr: "duplicate provider name", }, @@ -122,7 +103,7 @@ func TestValidateProvider_DuplicateNames(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - err := validateProviders(tc.providers) + _, err := aibridge.NewRequestBridge(t.Context(), tc.providers, nil, nil, logger, nil, bridgeTestTracer) if tc.expectErr != "" { require.Error(t, err) assert.Contains(t, err.Error(), tc.expectErr) @@ -148,7 +129,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { name: "openAI_no_base_path", requestPath: "/openai/v1/conversations", provider: func(baseURL string) provider.Provider { - return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL}) + return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL}) }, expectPath: "/conversations", }, @@ -157,7 +138,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { baseURLPath: "/v1", requestPath: "/openai/v1/conversations", provider: func(baseURL string) provider.Provider { - return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL}) + return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL}) }, expectPath: "/v1/conversations", }, @@ -165,7 +146,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { name: "anthropic_no_base_path", requestPath: "/anthropic/v1/models", provider: func(baseURL string) provider.Provider { - return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil) + return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil) }, expectPath: "/v1/models", }, @@ -174,7 +155,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { baseURLPath: "/v1", requestPath: "/anthropic/v1/models", provider: func(baseURL string) provider.Provider { - return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil) + return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil) }, expectPath: "/v1/v1/models", }, @@ -182,7 +163,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { name: "copilot_no_base_path", requestPath: "/copilot/models", provider: func(baseURL string) provider.Provider { - return NewCopilotProvider(config.Copilot{BaseURL: baseURL}) + return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL}) }, expectPath: "/models", }, @@ -191,7 +172,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { baseURLPath: "/v1", requestPath: "/copilot/models", provider: func(baseURL string) provider.Provider { - return NewCopilotProvider(config.Copilot{BaseURL: baseURL}) + return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL}) }, expectPath: "/v1/models", }, @@ -210,14 +191,14 @@ func TestPassthroughRoutesForProviders(t *testing.T) { })) t.Cleanup(upstream.Close) - recorder := testutil.MockRecorder{} + rec := testutil.MockRecorder{} prov := tc.provider(upstream.URL + tc.baseURLPath) - bridge, err := NewRequestBridge(t.Context(), []provider.Provider{prov}, &recorder, nil, logger, nil, testTracer) + bridge, err := aibridge.NewRequestBridge(t.Context(), []provider.Provider{prov}, &rec, nil, logger, nil, bridgeTestTracer) require.NoError(t, err) req := httptest.NewRequest("", tc.requestPath, nil) resp := httptest.NewRecorder() - bridge.mux.ServeHTTP(resp, req) + bridge.ServeHTTP(resp, req) assert.Equal(t, http.StatusOK, resp.Code) assert.Contains(t, resp.Body.String(), upstreamRespBody) diff --git a/aibridge/circuitbreaker/circuitbreaker.go b/aibridge/circuitbreaker/circuitbreaker.go index 6c908377fc..0f0880b192 100644 --- a/aibridge/circuitbreaker/circuitbreaker.go +++ b/aibridge/circuitbreaker/circuitbreaker.go @@ -1,8 +1,10 @@ package circuitbreaker import ( + "bufio" "errors" "fmt" + "net" "net/http" "sync" "time" @@ -65,8 +67,8 @@ func (p *ProviderCircuitBreakers) isFailure(statusCode int) bool { return DefaultIsFailure(statusCode) } -// openErrorResponse returns the error response body when the circuit is open. -func (p *ProviderCircuitBreakers) openErrorResponse() []byte { +// openErrBody returns the error response body when the circuit is open. +func (p *ProviderCircuitBreakers) openErrBody() []byte { if p.config.OpenErrorResponse != nil { return p.config.OpenErrorResponse() } @@ -77,7 +79,7 @@ func (p *ProviderCircuitBreakers) openErrorResponse() []byte { func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.CircuitBreaker[struct{}] { key := endpoint + ":" + model if v, ok := p.breakers.Load(key); ok { - return v.(*gobreaker.CircuitBreaker[struct{}]) + return v.(*gobreaker.CircuitBreaker[struct{}]) //nolint:forcetypeassert // sync.Map always stores this type } settings := gobreaker.Settings{ @@ -97,11 +99,12 @@ func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.Circuit cb := gobreaker.NewCircuitBreaker[struct{}](settings) actual, _ := p.breakers.LoadOrStore(key, cb) - return actual.(*gobreaker.CircuitBreaker[struct{}]) + return actual.(*gobreaker.CircuitBreaker[struct{}]) //nolint:forcetypeassert // sync.Map always stores this type } // statusCapturingWriter wraps http.ResponseWriter to capture the status code. -// It also implements http.Flusher to support streaming responses. +// It implements http.Flusher to support streaming and http.Hijacker to +// satisfy the FullResponseWriter lint rule. type statusCapturingWriter struct { http.ResponseWriter statusCode int @@ -130,6 +133,14 @@ func (w *statusCapturingWriter) Flush() { } } +func (w *statusCapturingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := w.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, xerrors.New("upstream ResponseWriter does not support hijacking") + } + return h.Hijack() +} + // Unwrap returns the underlying ResponseWriter for interface checks. func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter @@ -167,7 +178,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons w.Header().Set("Content-Type", "application/json") w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds()))) w.WriteHeader(http.StatusServiceUnavailable) - _, _ = w.Write(p.openErrorResponse()) + _, _ = w.Write(p.openErrBody()) return ErrCircuitOpen } @@ -187,7 +198,7 @@ func (p *ProviderCircuitBreakers) Provider() string { // OpenErrorResponse returns the error response body when the circuit is open. // This is exposed for handlers to use when responding to rejected requests. func (p *ProviderCircuitBreakers) OpenErrorResponse() []byte { - return p.openErrorResponse() + return p.openErrBody() } // StateToGaugeValue converts gobreaker.State to a gauge value. diff --git a/aibridge/circuitbreaker/circuitbreaker_test.go b/aibridge/circuitbreaker/circuitbreaker_test.go index 7a0957adc1..b80bfa2deb 100644 --- a/aibridge/circuitbreaker/circuitbreaker_test.go +++ b/aibridge/circuitbreaker/circuitbreaker_test.go @@ -1,4 +1,4 @@ -package circuitbreaker +package circuitbreaker_test import ( "errors" @@ -11,6 +11,7 @@ import ( "github.com/sony/gobreaker/v2" "github.com/stretchr/testify/assert" + "github.com/coder/coder/v2/aibridge/circuitbreaker" "github.com/coder/coder/v2/aibridge/config" ) @@ -20,7 +21,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { sonnetCalls := atomic.Int32{} haikuCalls := atomic.Int32{} - cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, @@ -48,7 +49,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen)) assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call assert.Equal(t, http.StatusServiceUnavailable, w.Code) @@ -69,7 +70,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { messagesCalls := atomic.Int32{} completionsCalls := atomic.Int32{} - cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, @@ -95,7 +96,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen)) assert.Equal(t, int32(1), messagesCalls.Load()) // No new call assert.Equal(t, http.StatusServiceUnavailable, w.Code) @@ -116,7 +117,7 @@ func TestExecute_CustomIsFailure(t *testing.T) { var calls atomic.Int32 // Custom IsFailure that treats 502 as failure - cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, @@ -143,7 +144,7 @@ func TestExecute_CustomIsFailure(t *testing.T) { rw.WriteHeader(http.StatusOK) return nil }) - assert.True(t, errors.Is(err, ErrCircuitOpen)) + assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen)) assert.Equal(t, int32(1), calls.Load()) // No new call assert.Equal(t, http.StatusServiceUnavailable, w.Code) } @@ -158,7 +159,7 @@ func TestExecute_OnStateChange(t *testing.T) { to gobreaker.State } - cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ + cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, @@ -177,10 +178,11 @@ func TestExecute_OnStateChange(t *testing.T) { // Trip circuit w := httptest.NewRecorder() - cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error { + err := cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error { rw.WriteHeader(http.StatusTooManyRequests) return nil }) + assert.NoError(t, err) // Verify state change callback was called with correct parameters assert.Len(t, stateChanges, 1) @@ -208,14 +210,14 @@ func TestDefaultIsFailure(t *testing.T) { } for _, tt := range tests { - assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode) + assert.Equal(t, tt.isFailure, circuitbreaker.DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode) } } func TestStateToGaugeValue(t *testing.T) { t.Parallel() - assert.Equal(t, float64(0), StateToGaugeValue(gobreaker.StateClosed)) - assert.Equal(t, float64(0.5), StateToGaugeValue(gobreaker.StateHalfOpen)) - assert.Equal(t, float64(1), StateToGaugeValue(gobreaker.StateOpen)) + assert.Equal(t, float64(0), circuitbreaker.StateToGaugeValue(gobreaker.StateClosed)) + assert.Equal(t, float64(0.5), circuitbreaker.StateToGaugeValue(gobreaker.StateHalfOpen)) + assert.Equal(t, float64(1), circuitbreaker.StateToGaugeValue(gobreaker.StateOpen)) } diff --git a/aibridge/client.go b/aibridge/client.go index a5c84f8497..3e9e277bec 100644 --- a/aibridge/client.go +++ b/aibridge/client.go @@ -24,10 +24,10 @@ const ( ClientUnknown Client = "Unknown" ) -// guessClient attempts to guess the client application from the request headers. +// GuessClient attempts to guess the client application from the request headers. // Not all clients set proper user agent headers, so this is a best-effort approach. // Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101. -func guessClient(r *http.Request) Client { +func GuessClient(r *http.Request) Client { userAgent := strings.ToLower(r.UserAgent()) originator := r.Header.Get("originator") diff --git a/aibridge/client_test.go b/aibridge/client_test.go index 5c1d101728..e3fa82866e 100644 --- a/aibridge/client_test.go +++ b/aibridge/client_test.go @@ -1,10 +1,12 @@ -package aibridge +package aibridge_test import ( "net/http" "testing" "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/aibridge" ) func TestGuessClient(t *testing.T) { @@ -14,93 +16,93 @@ func TestGuessClient(t *testing.T) { name string userAgent string headers map[string]string - wantClient Client + wantClient aibridge.Client }{ { name: "mux", userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22", - wantClient: ClientMux, + wantClient: aibridge.ClientMux, }, { name: "claude_code", userAgent: "claude-cli/2.0.67 (external, cli)", - wantClient: ClientClaudeCode, + wantClient: aibridge.ClientClaudeCode, }, { name: "codex_cli", userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef", - wantClient: ClientCodex, + wantClient: aibridge.ClientCodex, }, { name: "zed", userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)", - wantClient: ClientZed, + wantClient: aibridge.ClientZed, }, { name: "github_copilot_vsc", userAgent: "GitHubCopilotChat/0.37.2026011603", - wantClient: ClientCopilotVSC, + wantClient: aibridge.ClientCopilotVSC, }, { name: "github_copilot_cli", userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)", - wantClient: ClientCopilotCLI, + wantClient: aibridge.ClientCopilotCLI, }, { name: "kilo_code_user_agent", userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1", - wantClient: ClientKilo, + wantClient: aibridge.ClientKilo, }, { name: "kilo_code_originator", headers: map[string]string{"Originator": "kilo-code"}, - wantClient: ClientKilo, + wantClient: aibridge.ClientKilo, }, { name: "roo_code_user_agent", userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1", - wantClient: ClientRoo, + wantClient: aibridge.ClientRoo, }, { name: "roo_code_originator", headers: map[string]string{"Originator": "roo-code"}, - wantClient: ClientRoo, + wantClient: aibridge.ClientRoo, }, { name: "coder_agents", userAgent: "coder-agents/v2.24.0 (linux/amd64)", - wantClient: ClientCoderAgents, + wantClient: aibridge.ClientCoderAgents, }, { name: "coder_agents_dev", userAgent: "coder-agents/v0.0.0-devel (darwin/arm64)", - wantClient: ClientCoderAgents, + wantClient: aibridge.ClientCoderAgents, }, { name: "charm_crush", userAgent: "Charm Crush/0.1.11", - wantClient: ClientCrush, + wantClient: aibridge.ClientCrush, }, { name: "cursor_x_cursor_client_version", userAgent: "connect-es/1.6.1", headers: map[string]string{"X-Cursor-client-version": "0.50.0"}, - wantClient: ClientCursor, + wantClient: aibridge.ClientCursor, }, { name: "cursor_x_cursor_some_other_header", headers: map[string]string{"x-cursor-client-version": "abc123"}, - wantClient: ClientCursor, + wantClient: aibridge.ClientCursor, }, { name: "unknown_client", userAgent: "ccclaude-cli/calude-with-wrong-prefix", - wantClient: ClientUnknown, + wantClient: aibridge.ClientUnknown, }, { name: "empty_user_agent", userAgent: "", - wantClient: ClientUnknown, + wantClient: aibridge.ClientUnknown, }, } @@ -108,7 +110,7 @@ func TestGuessClient(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - req, err := http.NewRequest(http.MethodGet, "", nil) + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "", nil) require.NoError(t, err) req.Header.Set("User-Agent", tt.userAgent) @@ -116,7 +118,7 @@ func TestGuessClient(t *testing.T) { req.Header.Set(key, value) } - got := guessClient(req) + got := aibridge.GuessClient(req) require.Equal(t, tt.wantClient, got) }) } diff --git a/aibridge/context/context_test.go b/aibridge/context/context_test.go index 92ead40894..039b3a9a25 100644 --- a/aibridge/context/context_test.go +++ b/aibridge/context/context_test.go @@ -1,4 +1,4 @@ -package context +package context_test import ( "context" @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + aibcontext "github.com/coder/coder/v2/aibridge/context" "github.com/coder/coder/v2/aibridge/recorder" ) @@ -17,10 +18,10 @@ func TestAsActor(t *testing.T) { metadata := recorder.Metadata{"key": "value"} // When: storing an actor in the context - ctx := AsActor(context.Background(), "actor-123", metadata) + ctx := aibcontext.AsActor(context.Background(), "actor-123", metadata) // Then: the actor should be retrievable with correct ID and metadata - actor := ActorFromContext(ctx) + actor := aibcontext.ActorFromContext(ctx) require.NotNil(t, actor) assert.Equal(t, "actor-123", actor.ID) assert.Equal(t, "value", actor.Metadata["key"]) @@ -33,10 +34,10 @@ func TestActorFromContext(t *testing.T) { t.Parallel() // Given: a context with an actor - ctx := AsActor(context.Background(), "test-id", recorder.Metadata{}) + ctx := aibcontext.AsActor(context.Background(), "test-id", recorder.Metadata{}) // When: extracting the actor from context - actor := ActorFromContext(ctx) + actor := aibcontext.ActorFromContext(ctx) // Then: the actor should be returned with correct ID require.NotNil(t, actor) @@ -50,7 +51,7 @@ func TestActorFromContext(t *testing.T) { ctx := context.Background() // When: extracting the actor from context - actor := ActorFromContext(ctx) + actor := aibcontext.ActorFromContext(ctx) // Then: nil should be returned assert.Nil(t, actor) @@ -64,10 +65,10 @@ func TestActorIDFromContext(t *testing.T) { t.Parallel() // Given: a context with an actor - ctx := AsActor(context.Background(), "test-actor-id", recorder.Metadata{}) + ctx := aibcontext.AsActor(context.Background(), "test-actor-id", recorder.Metadata{}) // When: extracting the actor ID from context - got := ActorIDFromContext(ctx) + got := aibcontext.ActorIDFromContext(ctx) // Then: the actor ID should be returned assert.Equal(t, "test-actor-id", got) @@ -80,7 +81,7 @@ func TestActorIDFromContext(t *testing.T) { ctx := context.Background() // When: extracting the actor ID from context - got := ActorIDFromContext(ctx) + got := aibcontext.ActorIDFromContext(ctx) // Then: an empty string should be returned assert.Empty(t, got) diff --git a/aibridge/fixtures/README.md b/aibridge/fixtures/README.md deleted file mode 100644 index 075eaed0a3..0000000000 --- a/aibridge/fixtures/README.md +++ /dev/null @@ -1,25 +0,0 @@ -These fixtures were created by adding logging middleware to API calls to view the raw requests/responses. - -```go -... -opts = append(opts, option.WithMiddleware(LoggingMiddleware)) -... - -func LoggingMiddleware(req *http.Request, next option.MiddlewareNext) (res *http.Response, err error) { - reqOut, _ := httputil.DumpRequest(req, true) - - // Forward the request to the next handler - res, err = next(req) - fmt.Printf("[req] %s\n", reqOut) - - // Handle stuff after the request - if err != nil { - return res, err - } - - respOut, _ := httputil.DumpResponse(res, true) - fmt.Printf("[resp] %s\n", respOut) - - return res, err -} -``` diff --git a/aibridge/fixtures/fixtures.go b/aibridge/fixtures/fixtures.go index eeb1e6aedd..c731e0fb9c 100644 --- a/aibridge/fixtures/fixtures.go +++ b/aibridge/fixtures/fixtures.go @@ -92,7 +92,7 @@ var ( OaiResponsesBlockingConversation []byte //go:embed openai/responses/blocking/http_error.txtar - OaiResponsesBlockingHttpErr []byte + OaiResponsesBlockingHTTPErr []byte //go:embed openai/responses/blocking/prev_response_id.txtar OaiResponsesBlockingPrevResponseID []byte @@ -139,7 +139,7 @@ var ( OaiResponsesStreamingConversation []byte //go:embed openai/responses/streaming/http_error.txtar - OaiResponsesStreamingHttpErr []byte + OaiResponsesStreamingHTTPErr []byte //go:embed openai/responses/streaming/prev_response_id.txtar OaiResponsesStreamingPrevResponseID []byte diff --git a/aibridge/intercept/actor_headers_test.go b/aibridge/intercept/actor_headers_test.go index 917896e40c..aa2b1a7771 100644 --- a/aibridge/intercept/actor_headers_test.go +++ b/aibridge/intercept/actor_headers_test.go @@ -1,4 +1,4 @@ -package intercept +package intercept_test import ( "testing" @@ -7,14 +7,15 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/v2/aibridge/context" + "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/recorder" ) func TestNilActor(t *testing.T) { t.Parallel() - require.Nil(t, ActorHeadersAsOpenAIOpts(nil)) - require.Nil(t, ActorHeadersAsAnthropicOpts(nil)) + require.Nil(t, intercept.ActorHeadersAsOpenAIOpts(nil)) + require.Nil(t, intercept.ActorHeadersAsAnthropicOpts(nil)) } func TestBasic(t *testing.T) { @@ -28,9 +29,9 @@ func TestBasic(t *testing.T) { // We can't peek inside since these opts require an internal type to apply onto. // All we can do is check the length. // See TestActorHeaders for an integration test. - oaiOpts := ActorHeadersAsOpenAIOpts(actor) + oaiOpts := intercept.ActorHeadersAsOpenAIOpts(actor) require.Len(t, oaiOpts, 1) - antOpts := ActorHeadersAsAnthropicOpts(actor) + antOpts := intercept.ActorHeadersAsAnthropicOpts(actor) require.Len(t, antOpts, 1) } @@ -49,8 +50,8 @@ func TestBasicAndMetadata(t *testing.T) { // We can't peek inside since these opts require an internal type to apply onto. // All we can do is check the length. // See TestActorHeaders for an integration test. - oaiOpts := ActorHeadersAsOpenAIOpts(actor) + oaiOpts := intercept.ActorHeadersAsOpenAIOpts(actor) require.Len(t, oaiOpts, 1+len(actor.Metadata)) - antOpts := ActorHeadersAsAnthropicOpts(actor) + antOpts := intercept.ActorHeadersAsAnthropicOpts(actor) require.Len(t, antOpts, 1+len(actor.Metadata)) } diff --git a/aibridge/intercept/apidump/apidump.go b/aibridge/intercept/apidump/apidump.go index 732e5e37ca..19a2b4d52a 100644 --- a/aibridge/intercept/apidump/apidump.go +++ b/aibridge/intercept/apidump/apidump.go @@ -105,10 +105,11 @@ func (d *dumper) dumpRequest(req *http.Request) error { if err != nil { return xerrors.Errorf("write request header terminator: %w", err) } - buf.Write(prettyBody) - buf.WriteByte('\n') + // bytes.Buffer writes to in-memory storage and never return errors. + _, _ = buf.Write(prettyBody) + _ = buf.WriteByte('\n') - return os.WriteFile(dumpPath, buf.Bytes(), 0o644) + return os.WriteFile(dumpPath, buf.Bytes(), 0o600) } func (d *dumper) dumpResponse(resp *http.Response) error { @@ -129,19 +130,19 @@ func (d *dumper) dumpResponse(resp *http.Response) error { return xerrors.Errorf("write response header terminator: %w", err) } - // Wrap the response body to capture it as it streams - if resp.Body != nil { - resp.Body = &streamingBodyDumper{ - body: resp.Body, - dumpPath: dumpPath, - headerData: headerBuf.Bytes(), - logger: func(err error) { - d.logger.Named("apidump").Warn(context.Background(), "failed to initialize response dump", slog.Error(err)) - }, - } - } else { + if resp.Body == nil { // No body, just write headers - return os.WriteFile(dumpPath, headerBuf.Bytes(), 0o644) + return os.WriteFile(dumpPath, headerBuf.Bytes(), 0o600) + } + + // Wrap the response body to capture it as it streams + resp.Body = &streamingBodyDumper{ + body: resp.Body, + dumpPath: dumpPath, + headerData: headerBuf.Bytes(), + logger: func(err error) { + d.logger.Named("apidump").Warn(context.Background(), "failed to initialize response dump", slog.Error(err)) + }, } return nil @@ -152,7 +153,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error { // for deterministic output. // `sensitive` and `overrides` must both supply keys in canoncialized form. // See [textproto.MIMEHeader]. -func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error { +func (*dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error { // Collect all header keys including overrides. headerKeys := make([]string, 0, len(headers)+len(overrides)) seen := make(map[string]struct{}, len(headers)+len(overrides)) diff --git a/aibridge/intercept/apidump/apidump_test.go b/aibridge/intercept/apidump/apidump_test.go index 5fb8aa2e6c..ac85035afe 100644 --- a/aibridge/intercept/apidump/apidump_test.go +++ b/aibridge/intercept/apidump/apidump_test.go @@ -1,4 +1,4 @@ -package apidump +package apidump //nolint:testpackage // tests unexported internals import ( "bytes" @@ -39,7 +39,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`))) require.NoError(t, err) // Add sensitive headers that should be redacted @@ -52,7 +52,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { req.Header.Set("User-Agent", "test-client") // Call middleware with a mock next function - _, err = middleware(req, func(r *http.Request) (*http.Response, error) { + resp, err := middleware(req, func(r *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Status: "200 OK", @@ -62,6 +62,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { }, nil }) require.NoError(t, err) + defer resp.Body.Close() // Read the request dump file modelDir := filepath.Join(tmpDir, "openai", "gpt-4") @@ -96,7 +97,7 @@ func TestBridgedMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) // Call middleware with a response containing sensitive headers @@ -166,11 +167,11 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) { require.NotNil(t, middleware) originalBody := `{"messages": [{"role": "user", "content": "hello"}]}` - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody))) require.NoError(t, err) var capturedBody []byte - _, err = middleware(req, func(r *http.Request) (*http.Response, error) { + resp2, err := middleware(req, func(r *http.Request) (*http.Response, error) { // Read the body in the next handler to verify it's still available capturedBody, _ = io.ReadAll(r.Body) return &http.Response{ @@ -182,6 +183,7 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) { }, nil }) require.NoError(t, err) + defer resp2.Body.Close() // Verify the body was preserved for the next handler require.Equal(t, originalBody, string(capturedBody)) @@ -199,10 +201,10 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) - _, err = middleware(req, func(r *http.Request) (*http.Response, error) { + resp3, err := middleware(req, func(r *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Status: "200 OK", @@ -212,6 +214,7 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) { }, nil }) require.NoError(t, err) + defer resp3.Body.Close() // Verify files are created with sanitized model name modelDir := filepath.Join(tmpDir, "google", "gemini-1.5-pro") @@ -278,7 +281,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) // Set all sensitive headers @@ -290,7 +293,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) { req.Header.Set("Proxy-Authorization", "Basic proxy-creds") req.Header.Set("X-Amz-Security-Token", "aws-security-token") - _, err = middleware(req, func(r *http.Request) (*http.Response, error) { + resp4, err := middleware(req, func(r *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Status: "200 OK", @@ -300,6 +303,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) { }, nil }) require.NoError(t, err) + defer resp4.Body.Close() modelDir := filepath.Join(tmpDir, "openai", "gpt-4") reqDumpPath := findDumpFile(t, modelDir, SuffixRequest) @@ -355,10 +359,10 @@ func TestPassthroughMiddleware(t *testing.T) { rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk) - req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil) + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "https://api.openai.com/v1/models", nil) require.NoError(t, err) - resp, err := rt.RoundTrip(req) + resp, err := rt.RoundTrip(req) //nolint:bodyclose // resp is nil on error require.ErrorIs(t, err, innerErr) require.Nil(t, resp) }) @@ -399,7 +403,7 @@ func TestPassthroughMiddleware(t *testing.T) { rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk) - req, err := http.NewRequest(http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body))) require.NoError(t, err) req.Header.Set("Authorization", "Bearer sk-secret-key-12345") resp, err := rt.RoundTrip(req) @@ -409,7 +413,7 @@ func TestPassthroughMiddleware(t *testing.T) { require.NoError(t, resp.Body.Close()) // Second request should create new req/resp files - req2, err := http.NewRequest(http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body))) + req2, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body))) require.NoError(t, err) resp2, err := rt.RoundTrip(req2) require.NoError(t, err) diff --git a/aibridge/intercept/apidump/headers_test.go b/aibridge/intercept/apidump/headers_test.go index e9046541a3..cef10c0614 100644 --- a/aibridge/intercept/apidump/headers_test.go +++ b/aibridge/intercept/apidump/headers_test.go @@ -1,4 +1,4 @@ -package apidump +package apidump //nolint:testpackage // tests unexported internals import ( "bytes" diff --git a/aibridge/intercept/apidump/streaming.go b/aibridge/intercept/apidump/streaming.go index e2db42ac12..ef9805d86d 100644 --- a/aibridge/intercept/apidump/streaming.go +++ b/aibridge/intercept/apidump/streaming.go @@ -37,7 +37,7 @@ func (s *streamingBodyDumper) init() { // Write headers first. if _, err := s.file.Write(s.headerData); err != nil { s.initErr = xerrors.Errorf("write headers: %w", err) - s.file.Close() + _ = s.file.Close() // best-effort cleanup on header write failure s.file = nil } }) diff --git a/aibridge/intercept/apidump/streaming_test.go b/aibridge/intercept/apidump/streaming_test.go index 2a39c1b81a..7bdac2a96c 100644 --- a/aibridge/intercept/apidump/streaming_test.go +++ b/aibridge/intercept/apidump/streaming_test.go @@ -1,4 +1,4 @@ -package apidump +package apidump //nolint:testpackage // shares test helpers with apidump_test.go import ( "bytes" @@ -28,7 +28,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) // Simulate a streaming response with multiple chunks @@ -42,10 +42,12 @@ func TestMiddleware_StreamingResponse(t *testing.T) { // Create a pipe to simulate streaming pr, pw := io.Pipe() go func() { + defer pw.Close() //nolint:revive // error handled via pipe read side for _, chunk := range chunks { - pw.Write([]byte(chunk)) + if _, err := pw.Write([]byte(chunk)); err != nil { + return + } } - pw.Close() }() resp, err := middleware(req, func(r *http.Request) (*http.Response, error) { @@ -65,7 +67,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) { for { n, err := resp.Body.Read(buf) if n > 0 { - receivedData.Write(buf[:n]) + _, _ = receivedData.Write(buf[:n]) // bytes.Buffer.Write never fails } if err == io.EOF { break @@ -104,7 +106,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) originalRespBody := `{"choices": [{"message": {"content": "hi"}}]}` @@ -118,6 +120,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) { }, nil }) require.NoError(t, err) + defer resp.Body.Close() // Verify the response body is still readable after middleware capturedBody, err := io.ReadAll(resp.Body) diff --git a/aibridge/intercept/chatcompletions/base.go b/aibridge/intercept/chatcompletions/base.go index f27b070b3f..40922fae3b 100644 --- a/aibridge/intercept/chatcompletions/base.go +++ b/aibridge/intercept/chatcompletions/base.go @@ -79,9 +79,9 @@ func (i *interceptionBase) Credential() intercept.CredentialInfo { return i.credential } -func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { +func (i *interceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) { i.logger = logger - i.recorder = recorder + i.recorder = rec i.mcpProxy = mcpProxy } @@ -98,13 +98,13 @@ func (i *interceptionBase) CorrelatingToolCallID() *string { return &msg.OfTool.ToolCallID } -func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { +func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), - attribute.String(tracing.InterceptionID, s.id.String()), + attribute.String(tracing.InterceptionID, i.id.String()), attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), - attribute.String(tracing.Provider, s.providerName), - attribute.String(tracing.Model, s.Model()), + attribute.String(tracing.Provider, i.providerName), + attribute.String(tracing.Model, i.Model()), attribute.Bool(tracing.Streaming, streaming), } } @@ -114,10 +114,10 @@ func (i *interceptionBase) Model() string { return "coder-aibridge-unknown" } - return string(i.req.Model) + return i.req.Model } -func (i *interceptionBase) newErrorResponse(err error) map[string]any { +func (*interceptionBase) newErrorResponse(err error) map[string]any { return map[string]any{ "error": true, "message": err.Error(), @@ -172,7 +172,7 @@ func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) { } // writeUpstreamError marshals and writes a given error. -func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *errorResponse) { +func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *chatCompletionResponseError) { if oaiErr == nil { return } @@ -182,7 +182,7 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *err out, err := json.Marshal(oaiErr) if err != nil { - i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", oaiErr))) + i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", oaiErr)) // Response has to match expected format. _, _ = w.Write([]byte(`{ "error": { @@ -227,13 +227,13 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 { in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */ } -func getErrorResponse(err error) *errorResponse { +func getErrorResponse(err error) *chatCompletionResponseError { var apiErr *openai.Error if !errors.As(err, &apiErr) { return nil } - return &errorResponse{ + return &chatCompletionResponseError{ ErrorObject: &shared.ErrorObject{ Code: apiErr.Code, Message: apiErr.Message, @@ -243,15 +243,15 @@ func getErrorResponse(err error) *errorResponse { } } -var _ error = &errorResponse{} +var _ error = &chatCompletionResponseError{} -type errorResponse struct { +type chatCompletionResponseError struct { ErrorObject *shared.ErrorObject `json:"error"` StatusCode int `json:"-"` } -func newErrorResponse(msg error) *errorResponse { - return &errorResponse{ +func newErrorResponse(msg error) *chatCompletionResponseError { + return &chatCompletionResponseError{ ErrorObject: &shared.ErrorObject{ Code: "error", Message: msg.Error(), @@ -260,7 +260,7 @@ func newErrorResponse(msg error) *errorResponse { } } -func (a *errorResponse) Error() string { +func (a *chatCompletionResponseError) Error() string { if a.ErrorObject == nil { return "" } diff --git a/aibridge/intercept/chatcompletions/base_test.go b/aibridge/intercept/chatcompletions/base_test.go index 3bfbe82998..67104b9085 100644 --- a/aibridge/intercept/chatcompletions/base_test.go +++ b/aibridge/intercept/chatcompletions/base_test.go @@ -1,4 +1,4 @@ -package chatcompletions +package chatcompletions //nolint:testpackage // tests unexported internals import ( "testing" diff --git a/aibridge/intercept/chatcompletions/blocking.go b/aibridge/intercept/chatcompletions/blocking.go index 98574c277e..59c8bbb731 100644 --- a/aibridge/intercept/chatcompletions/blocking.go +++ b/aibridge/intercept/chatcompletions/blocking.go @@ -50,16 +50,16 @@ func NewBlockingInterceptor( }} } -func (s *BlockingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { - s.interceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) +func (i *BlockingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) { + i.interceptionBase.Setup(logger.Named("blocking"), rec, mcpProxy) } -func (s *BlockingInterception) Streaming() bool { +func (*BlockingInterception) Streaming() bool { return false } -func (s *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { - return s.interceptionBase.baseTraceAttributes(r, false) +func (i *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return i.interceptionBase.baseTraceAttributes(r, false) } func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) { diff --git a/aibridge/intercept/chatcompletions/paramswrap.go b/aibridge/intercept/chatcompletions/paramswrap.go index 58851459aa..8b9efbbf4f 100644 --- a/aibridge/intercept/chatcompletions/paramswrap.go +++ b/aibridge/intercept/chatcompletions/paramswrap.go @@ -52,7 +52,7 @@ func (c *ChatCompletionNewParamsWrapper) lastUserPrompt() (*string, error) { // We only care if the last message was issued by a user. msg := c.Messages[len(c.Messages)-1] if msg.OfUser == nil { - return nil, nil + return nil, nil //nolint:nilnil // no user prompt found is not an error } if msg.OfUser.Content.OfString.String() != "" { @@ -69,5 +69,5 @@ func (c *ChatCompletionNewParamsWrapper) lastUserPrompt() (*string, error) { } } - return nil, nil + return nil, nil //nolint:nilnil // no text content found is not an error } diff --git a/aibridge/intercept/chatcompletions/paramswrap_test.go b/aibridge/intercept/chatcompletions/paramswrap_test.go index 230664e463..1e7c61f3b8 100644 --- a/aibridge/intercept/chatcompletions/paramswrap_test.go +++ b/aibridge/intercept/chatcompletions/paramswrap_test.go @@ -1,4 +1,4 @@ -package chatcompletions +package chatcompletions //nolint:testpackage // tests unexported internals import ( "fmt" @@ -114,6 +114,8 @@ func TestOpenAILastUserPrompt(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result, err := tt.wrapper.lastUserPrompt() if tt.expectError { @@ -144,7 +146,7 @@ func generatePayload(messageCount int) []byte { } // Use realistic message content size content := fmt.Sprintf("This is message number %d with some realistic content that might appear in a conversation.", i+1) - messages = append(messages, fmt.Sprintf(`{"role": "%s", "content": "%s"}`, role, content)) + messages = append(messages, fmt.Sprintf(`{"role": %q, "content": %q}`, role, content)) } return []byte(fmt.Sprintf(`{ diff --git a/aibridge/intercept/chatcompletions/streaming.go b/aibridge/intercept/chatcompletions/streaming.go index 9534780c33..7ce31f0ee5 100644 --- a/aibridge/intercept/chatcompletions/streaming.go +++ b/aibridge/intercept/chatcompletions/streaming.go @@ -54,16 +54,16 @@ func NewStreamingInterceptor( }} } -func (i *StreamingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { - i.interceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy) +func (i *StreamingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) { + i.interceptionBase.Setup(logger.Named("streaming"), rec, mcpProxy) } -func (i *StreamingInterception) Streaming() bool { +func (*StreamingInterception) Streaming() bool { return true } -func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { - return s.interceptionBase.baseTraceAttributes(r, true) +func (i *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return i.interceptionBase.baseTraceAttributes(r, true) } // ProcessRequest handles a request to /v1/chat/completions. @@ -189,16 +189,14 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re }) toolCall = nil - } else { + } else if stream.Err() == nil { // When the provider responds with only tool calls (no text content), // no chunks are relayed to the client, so the stream is not yet // initiated. Initiate it here so the SSE headers are sent and the // ping ticker is started, preventing client timeout during tool invocation. // Only initiate if no stream error, if there's an error, we'll return // an HTTP error response instead of starting an SSE stream. - if stream.Err() == nil { - events.InitiateStream(w) - } + events.InitiateStream(w) } } @@ -231,43 +229,43 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re }) } - if events.IsStreaming() { - // Check if the stream encountered any errors. - if streamErr := stream.Err(); streamErr != nil { - if eventstream.IsUnrecoverableError(streamErr) { - logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) - // We can't reflect an error back if there's a connection error or the request context was canceled. - } else if oaiErr := getErrorResponse(streamErr); oaiErr != nil { - logger.Warn(ctx, "openai stream error", slog.Error(streamErr)) - interceptionErr = oaiErr - } else { - logger.Warn(ctx, "unknown error", slog.Error(streamErr)) - // Unfortunately, the OpenAI SDK does not support parsing errors received in the stream - // into known types (i.e. [shared.OverloadedError]). - // See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171 - // All it does is wrap the payload in an error - which is all we can return, currently. - interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr)) - } - } else if lastErr != nil { - // Otherwise check if any logical errors occurred during processing. - logger.Warn(ctx, "stream failed", slog.Error(lastErr)) - interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr)) - } - - if interceptionErr != nil { - payload, err := i.marshalErr(interceptionErr) - if err != nil { - logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) - } else if err := events.Send(streamCtx, payload); err != nil { - logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) - } - } - } else { + if !events.IsStreaming() { // response/downstream Stream has not started yet; write error response and exit. i.writeUpstreamError(w, getErrorResponse(stream.Err())) return stream.Err() } + // Check if the stream encountered any errors. + if streamErr := stream.Err(); streamErr != nil { + if eventstream.IsUnrecoverableError(streamErr) { + logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) + // We can't reflect an error back if there's a connection error or the request context was canceled. + } else if oaiErr := getErrorResponse(streamErr); oaiErr != nil { + logger.Warn(ctx, "openai stream error", slog.Error(streamErr)) + interceptionErr = oaiErr + } else { + logger.Warn(ctx, "unknown stream error encountered", slog.Error(streamErr)) + // Unfortunately, the OpenAI SDK does not support parsing errors received in the stream + // into known types (i.e. [shared.OverloadedError]). + // See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171 + // All it does is wrap the payload in an error - which is all we can return, currently. + interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr)) + } + } else if lastErr != nil { + // Otherwise check if any logical errors occurred during processing. + logger.Warn(ctx, "stream processing failed", slog.Error(lastErr)) + interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr)) + } + + if interceptionErr != nil { + payload, err := i.marshalErr(interceptionErr) + if err != nil { + logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error())) + } else if err := events.Send(streamCtx, payload); err != nil { + logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) + } + } + // No tool call, nothing more to do. if toolCall == nil { break @@ -390,11 +388,12 @@ func (i *StreamingInterception) marshalErr(err error) ([]byte, error) { return i.encodeForStream(data), nil } -func (i *StreamingInterception) encodeForStream(payload []byte) []byte { +func (*StreamingInterception) encodeForStream(payload []byte) []byte { + // bytes.Buffer writes to in-memory storage and never return errors. var buf bytes.Buffer - buf.WriteString("data: ") - buf.Write(payload) - buf.WriteString("\n\n") + _, _ = buf.WriteString("data: ") + _, _ = buf.Write(payload) + _, _ = buf.WriteString("\n\n") return buf.Bytes() } diff --git a/aibridge/intercept/chatcompletions/streaming_test.go b/aibridge/intercept/chatcompletions/streaming_test.go index 83c081097b..640ad197c5 100644 --- a/aibridge/intercept/chatcompletions/streaming_test.go +++ b/aibridge/intercept/chatcompletions/streaming_test.go @@ -1,4 +1,4 @@ -package chatcompletions +package chatcompletions_test import ( "net/http" @@ -16,6 +16,7 @@ import ( "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/intercept" + "github.com/coder/coder/v2/aibridge/intercept/chatcompletions" "github.com/coder/coder/v2/aibridge/internal/testutil" ) @@ -73,7 +74,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { Key: "test-key", } - req := &ChatCompletionNewParamsWrapper{ + req := &chatcompletions.ChatCompletionNewParamsWrapper{ ChatCompletionNewParams: openai.ChatCompletionNewParams{ Model: "gpt-4", Messages: []openai.ChatCompletionMessageParamUnion{ @@ -88,7 +89,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) tracer := otel.Tracer("test") - interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{}) + interceptor := chatcompletions.NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{}) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) interceptor.Setup(logger, &testutil.MockRecorder{}, nil) diff --git a/aibridge/intercept/client_headers_test.go b/aibridge/intercept/client_headers_test.go index ecd2f018aa..f811fbecb0 100644 --- a/aibridge/intercept/client_headers_test.go +++ b/aibridge/intercept/client_headers_test.go @@ -1,4 +1,4 @@ -package intercept +package intercept_test import ( "net/http" @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/aibridge/intercept" ) func TestPrepareClientHeaders(t *testing.T) { @@ -14,7 +16,7 @@ func TestPrepareClientHeaders(t *testing.T) { t.Run("nil input returns empty header", func(t *testing.T) { t.Parallel() - result := PrepareClientHeaders(nil) + result := intercept.PrepareClientHeaders(nil) require.Empty(t, result) }) @@ -29,7 +31,7 @@ func TestPrepareClientHeaders(t *testing.T) { "X-Custom": {"preserved"}, } - result := PrepareClientHeaders(input) + result := intercept.PrepareClientHeaders(input) assert.Empty(t, result.Get("Connection")) assert.Empty(t, result.Get("Keep-Alive")) @@ -48,7 +50,7 @@ func TestPrepareClientHeaders(t *testing.T) { "X-Custom": {"preserved"}, } - result := PrepareClientHeaders(input) + result := intercept.PrepareClientHeaders(input) assert.Empty(t, result.Get("Host")) assert.Empty(t, result.Get("Accept-Encoding")) @@ -65,7 +67,7 @@ func TestPrepareClientHeaders(t *testing.T) { "X-Custom": {"preserved"}, } - result := PrepareClientHeaders(input) + result := intercept.PrepareClientHeaders(input) assert.Empty(t, result.Get("Authorization")) assert.Empty(t, result.Get("X-Api-Key")) @@ -79,7 +81,7 @@ func TestPrepareClientHeaders(t *testing.T) { "X-Custom": {"value-1", "value-2"}, } - result := PrepareClientHeaders(input) + result := intercept.PrepareClientHeaders(input) require.Equal(t, []string{"value-1", "value-2"}, result["X-Custom"]) }) @@ -93,7 +95,7 @@ func TestPrepareClientHeaders(t *testing.T) { } originalCopy := input.Clone() - _ = PrepareClientHeaders(input) + _ = intercept.PrepareClientHeaders(input) require.Equal(t, originalCopy, input) }) @@ -113,7 +115,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "User-Agent": {"claude-code/1.0"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") assert.Equal(t, "Bearer sk-provider-key", result.Get("Authorization")) assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) @@ -131,7 +133,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "Anthropic-Beta": {"prompt-caching-2024-07-31"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key") assert.Equal(t, "sk-ant-provider-key", result.Get("X-Api-Key")) assert.Empty(t, result.Get("Authorization")) @@ -151,7 +153,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "User-Agent": {"claude-code/1.0"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") assert.Equal(t, "Bearer sk-key", result.Get("Authorization")) assert.Equal(t, "user-123", result.Get("X-Ai-Bridge-Actor-Id")) @@ -174,7 +176,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "User-Agent": {"claude-code/1.0"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") assert.Empty(t, result.Get("Connection")) assert.Empty(t, result.Get("Host")) @@ -192,7 +194,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { "User-Agent": {"claude-code/1.0"}, } - result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") assert.Empty(t, result.Get("Authorization")) assert.Equal(t, "claude-code/1.0", result.Get("User-Agent")) @@ -211,7 +213,7 @@ func TestBuildUpstreamHeaders(t *testing.T) { sdkCopy := sdkHeader.Clone() clientCopy := clientHeaders.Clone() - _ = BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") + _ = intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization") require.Equal(t, sdkCopy, sdkHeader) require.Equal(t, clientCopy, clientHeaders) diff --git a/aibridge/intercept/eventstream/eventstream.go b/aibridge/intercept/eventstream/eventstream.go index 562e385c2c..9baeadede5 100644 --- a/aibridge/intercept/eventstream/eventstream.go +++ b/aibridge/intercept/eventstream/eventstream.go @@ -32,7 +32,6 @@ type EventStream struct { initiated atomic.Bool initiateOnce sync.Once - closeOnce sync.Once shutdownOnce sync.Once eventsCh chan event @@ -133,7 +132,7 @@ func (s *EventStream) Start(w http.ResponseWriter, r *http.Request) { return } if err := flush(w); err != nil { - s.logger.Warn(ctx, "failed to flush", slog.Error(err)) + s.logger.Warn(ctx, "failed to flush event stream", slog.Error(err)) return } @@ -240,8 +239,7 @@ func flush(w http.ResponseWriter) (err error) { } defer func() { - if r := recover(); r != nil { - // Likely a broken connection, don't spam the logs. + if r := recover(); r != nil { //nolint:revive,staticcheck // Intentionally swallowed; likely a broken connection. } }() diff --git a/aibridge/intercept/interceptor.go b/aibridge/intercept/interceptor.go index 270d1aa866..33cbc51dff 100644 --- a/aibridge/intercept/interceptor.go +++ b/aibridge/intercept/interceptor.go @@ -17,7 +17,7 @@ type Interceptor interface { ID() uuid.UUID // Setup injects some required dependencies. This MUST be called before using the interceptor // to process requests. - Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) + Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) // Model returns the model in use for this [Interceptor]. Model() string // ProcessRequest handles the HTTP request. diff --git a/aibridge/intercept/messages/base.go b/aibridge/intercept/messages/base.go index 3a62db322e..3cdc584d33 100644 --- a/aibridge/intercept/messages/base.go +++ b/aibridge/intercept/messages/base.go @@ -65,7 +65,7 @@ var bedrockSupportedBetaFlags = map[string]bool{ type interceptionBase struct { id uuid.UUID providerName string - reqPayload MessagesRequestPayload + reqPayload RequestPayload cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock @@ -90,9 +90,9 @@ func (i *interceptionBase) Credential() intercept.CredentialInfo { return i.credential } -func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { +func (i *interceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) { i.logger = logger - i.recorder = recorder + i.recorder = rec i.mcpProxy = mcpProxy } @@ -116,15 +116,15 @@ func (i *interceptionBase) Model() string { return i.reqPayload.model() } -func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { +func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue { return []attribute.KeyValue{ attribute.String(tracing.RequestPath, r.URL.Path), - attribute.String(tracing.InterceptionID, s.id.String()), + attribute.String(tracing.InterceptionID, i.id.String()), attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), - attribute.String(tracing.Provider, s.providerName), - attribute.String(tracing.Model, s.Model()), + attribute.String(tracing.Provider, i.providerName), + attribute.String(tracing.Model, i.Model()), attribute.Bool(tracing.Streaming, streaming), - attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil), + attribute.Bool(tracing.IsBedrock, i.bedrockCfg != nil), } } @@ -174,24 +174,22 @@ func (i *interceptionBase) disableParallelToolCalls() { } // extractModelThoughts returns any thinking blocks that were returned in the response. -func (i *interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recorder.ModelThoughtRecord { +func (*interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recorder.ModelThoughtRecord { if msg == nil { return nil } var thoughtRecords []*recorder.ModelThoughtRecord for _, block := range msg.Content { - switch variant := block.AsAny().(type) { - case anthropic.ThinkingBlock: - if variant.Thinking == "" { - continue - } - thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{ - Content: variant.Thinking, - Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking}, - }) - } // anthropic.RedactedThinkingBlock also exists, but there's nothing useful we can capture. + variant, ok := block.AsAny().(anthropic.ThinkingBlock) + if !ok || variant.Thinking == "" { + continue + } + thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{ + Content: variant.Thinking, + Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking}, + }) } return thoughtRecords } @@ -264,7 +262,7 @@ func (i *interceptionBase) withBody() option.RequestOption { return option.WithRequestBody("application/json", []byte(i.reqPayload)) } -func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) { +func (*interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) { if cfg == nil { return nil, xerrors.New("nil config given") } @@ -405,7 +403,7 @@ func filterBedrockBetaFlags(headers http.Header, model string) { } // writeUpstreamError marshals and writes a given error. -func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *ErrorResponse) { +func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *messagesResponseError) { if antErr == nil { return } @@ -415,7 +413,7 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err out, err := json.Marshal(antErr) if err != nil { - i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", antErr))) + i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", antErr)) // Response has to match expected format. // See https://docs.claude.com/en/api/errors#error-shapes. _, _ = w.Write([]byte(fmt.Sprintf(`{ @@ -487,7 +485,7 @@ func accumulateUsage(dest, src any) { } } -func getErrorResponse(err error) *ErrorResponse { +func getErrorResponse(err error) *messagesResponseError { var apierr *anthropic.Error if !errors.As(err, &apierr) { return nil @@ -505,7 +503,7 @@ func getErrorResponse(err error) *ErrorResponse { typ = string(detail.Type) } - return &ErrorResponse{ + return &messagesResponseError{ ErrorResponse: &anthropic.ErrorResponse{ Error: anthropic.ErrorObjectUnion{ Message: msg, @@ -517,16 +515,16 @@ func getErrorResponse(err error) *ErrorResponse { } } -var _ error = &ErrorResponse{} +var _ error = &messagesResponseError{} -type ErrorResponse struct { +type messagesResponseError struct { *anthropic.ErrorResponse StatusCode int `json:"-"` } -func newErrorResponse(msg error) *ErrorResponse { - return &ErrorResponse{ +func newErrorResponse(msg error) *messagesResponseError { + return &messagesResponseError{ ErrorResponse: &shared.ErrorResponse{ Error: shared.ErrorObjectUnion{ Message: msg.Error(), @@ -536,7 +534,7 @@ func newErrorResponse(msg error) *ErrorResponse { } } -func (a *ErrorResponse) Error() string { +func (a *messagesResponseError) Error() string { if a.ErrorResponse == nil { return "" } diff --git a/aibridge/intercept/messages/base_test.go b/aibridge/intercept/messages/base_test.go index be775966eb..de6444a2b4 100644 --- a/aibridge/intercept/messages/base_test.go +++ b/aibridge/intercept/messages/base_test.go @@ -1,4 +1,4 @@ -package messages +package messages //nolint:testpackage // tests unexported internals import ( "context" @@ -197,6 +197,8 @@ func TestAWSBedrockValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + base := &interceptionBase{} opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg) @@ -212,7 +214,10 @@ func TestAWSBedrockValidation(t *testing.T) { } func TestAccumulateUsage(t *testing.T) { + t.Parallel() + t.Run("Usage to Usage", func(t *testing.T) { + t.Parallel() dest := &anthropic.Usage{ InputTokens: 10, OutputTokens: 20, @@ -253,6 +258,8 @@ func TestAccumulateUsage(t *testing.T) { }) t.Run("MessageDeltaUsage to MessageDeltaUsage", func(t *testing.T) { + t.Parallel() + dest := &anthropic.MessageDeltaUsage{ InputTokens: 10, OutputTokens: 20, @@ -283,6 +290,8 @@ func TestAccumulateUsage(t *testing.T) { }) t.Run("Usage to MessageDeltaUsage", func(t *testing.T) { + t.Parallel() + dest := &anthropic.MessageDeltaUsage{ InputTokens: 10, OutputTokens: 20, @@ -317,6 +326,8 @@ func TestAccumulateUsage(t *testing.T) { }) t.Run("MessageDeltaUsage to Usage", func(t *testing.T) { + t.Parallel() + dest := &anthropic.Usage{ InputTokens: 10, OutputTokens: 20, @@ -354,6 +365,8 @@ func TestAccumulateUsage(t *testing.T) { }) t.Run("Nil or unsupported types", func(t *testing.T) { + t.Parallel() + // Test with nil dest var nilUsage *anthropic.Usage source := anthropic.Usage{InputTokens: 10} @@ -763,10 +776,10 @@ func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) { } } -func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload { +func mustMessagesPayload(t *testing.T, requestBody string) RequestPayload { t.Helper() - payload, err := NewMessagesRequestPayload([]byte(requestBody)) + payload, err := NewRequestPayload([]byte(requestBody)) require.NoError(t, err) return payload @@ -777,11 +790,11 @@ type mockServerProxier struct { tools []*mcp.Tool } -func (m *mockServerProxier) Init(context.Context) error { +func (*mockServerProxier) Init(context.Context) error { return nil } -func (m *mockServerProxier) Shutdown(context.Context) error { +func (*mockServerProxier) Shutdown(context.Context) error { return nil } @@ -798,8 +811,8 @@ func (m *mockServerProxier) GetTool(id string) *mcp.Tool { return nil } -func (m *mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) { - return nil, nil +func (*mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) { + return nil, nil //nolint:nilnil // mock: no-op implementation } func TestFilterBedrockBetaFlags(t *testing.T) { diff --git a/aibridge/intercept/messages/blocking.go b/aibridge/intercept/messages/blocking.go index 667ae00d46..610f934578 100644 --- a/aibridge/intercept/messages/blocking.go +++ b/aibridge/intercept/messages/blocking.go @@ -31,7 +31,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, - reqPayload MessagesRequestPayload, + reqPayload RequestPayload, providerName string, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, @@ -53,15 +53,15 @@ func NewBlockingInterceptor( }} } -func (i *BlockingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { - i.interceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) +func (i *BlockingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) { + i.interceptionBase.Setup(logger.Named("blocking"), rec, mcpProxy) } func (i *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { return i.interceptionBase.baseTraceAttributes(r, false) } -func (s *BlockingInterception) Streaming() bool { +func (*BlockingInterception) Streaming() bool { return false } diff --git a/aibridge/intercept/messages/reqpayload.go b/aibridge/intercept/messages/reqpayload.go index fa5142eadd..dfe52fc80c 100644 --- a/aibridge/intercept/messages/reqpayload.go +++ b/aibridge/intercept/messages/reqpayload.go @@ -82,12 +82,12 @@ var ( } ) -// MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request. +// RequestPayload is raw JSON bytes of an Anthropic Messages API request. // Methods provide package-specific reads and rewrites while preserving the // original body for upstream pass-through. -type MessagesRequestPayload []byte +type RequestPayload []byte -func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) { +func NewRequestPayload(raw []byte) (RequestPayload, error) { if len(bytes.TrimSpace(raw)) == 0 { return nil, xerrors.New("messages empty request body") } @@ -95,10 +95,10 @@ func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) { return nil, xerrors.New("messages invalid JSON request body") } - return MessagesRequestPayload(raw), nil + return RequestPayload(raw), nil } -func (p MessagesRequestPayload) Stream() bool { +func (p RequestPayload) Stream() bool { v := gjson.GetBytes(p, messagesReqPathStream) if !v.IsBool() { return false @@ -106,11 +106,11 @@ func (p MessagesRequestPayload) Stream() bool { return v.Bool() } -func (p MessagesRequestPayload) model() string { +func (p RequestPayload) model() string { return gjson.GetBytes(p, messagesReqPathModel).Str } -func (p MessagesRequestPayload) correlatingToolCallID() *string { +func (p RequestPayload) correlatingToolCallID() *string { messages := gjson.GetBytes(p, messagesReqPathMessages) if !messages.IsArray() { return nil @@ -147,7 +147,7 @@ func (p MessagesRequestPayload) correlatingToolCallID() *string { // lastUserPrompt returns the prompt text from the last user message. If no prompt // is found, it returns empty string, false, nil. Unexpected shapes are treated as // unsupported and do not fail the request path. -func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) { +func (p RequestPayload) lastUserPrompt() (string, bool, error) { messages := gjson.GetBytes(p, messagesReqPathMessages) if !messages.Exists() || messages.Type == gjson.Null { return "", false, nil @@ -195,7 +195,7 @@ func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) { return "", false, nil } -func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) (MessagesRequestPayload, error) { +func (p RequestPayload) injectTools(injected []anthropic.ToolUnionParam) (RequestPayload, error) { if len(injected) == 0 { return p, nil } @@ -221,7 +221,7 @@ func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) return p.set(messagesReqPathTools, allTools) } -func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPayload, error) { +func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) { toolChoice := gjson.GetBytes(p, messagesReqPathToolChoice) // If no tool_choice was defined, assume auto. @@ -258,7 +258,7 @@ func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPaylo } } -func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (MessagesRequestPayload, error) { +func (p RequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (RequestPayload, error) { if len(newMessages) == 0 { return p, nil } @@ -285,11 +285,11 @@ func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.Message return p.set(messagesReqPathMessages, allMessages) } -func (p MessagesRequestPayload) withModel(model string) (MessagesRequestPayload, error) { +func (p RequestPayload) withModel(model string) (RequestPayload, error) { return p.set(messagesReqPathModel, model) } -func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) { +func (p RequestPayload) messages() ([]json.RawMessage, error) { messages := gjson.GetBytes(p, messagesReqPathMessages) if !messages.Exists() || messages.Type == gjson.Null { return nil, nil @@ -301,7 +301,7 @@ func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) { return p.resultToRawMessage(messages.Array()), nil } -func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) { +func (p RequestPayload) tools() ([]json.RawMessage, error) { tools := gjson.GetBytes(p, messagesReqPathTools) if !tools.Exists() || tools.Type == gjson.Null { return nil, nil @@ -313,7 +313,7 @@ func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) { return p.resultToRawMessage(tools.Array()), nil } -func (p MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage { +func (RequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage { // gjson.Result conversion to json.RawMessage is needed because // gjson.Result does not implement json.Marshaler — would // serialize its struct fields instead of the raw JSON it represents. @@ -326,7 +326,7 @@ func (p MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json. // convertAdaptiveThinkingForBedrock converts thinking.type "adaptive" to "enabled" with a calculated budget_tokens // conversion is needed for Bedrock models that does not support the "adaptive" thinking.type -func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesRequestPayload, error) { +func (p RequestPayload) convertAdaptiveThinkingForBedrock() (RequestPayload, error) { thinkingType := gjson.GetBytes(p, messagesReqPathThinkingType) if thinkingType.String() != constAdaptive { return p, nil @@ -377,7 +377,7 @@ func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesReq // removed when the corresponding flag is absent from the Anthropic-Beta header. // Model-specific beta flags must already be filtered from the header before // calling this method (see filterBedrockBetaFlags). -func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Header) (MessagesRequestPayload, error) { +func (p RequestPayload) removeUnsupportedBedrockFields(headers http.Header) (RequestPayload, error) { var payloadMap map[string]any if err := json.Unmarshal(p, &payloadMap); err != nil { return p, xerrors.Errorf("failed to unmarshal request payload when removing unsupported Bedrock fields: %w", err) @@ -400,13 +400,13 @@ func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Head if err != nil { return p, xerrors.Errorf("failed to marshal request payload when removing unsupported Bedrock fields: %w", err) } - return MessagesRequestPayload(result), nil + return RequestPayload(result), nil } -func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) { +func (p RequestPayload) set(path string, value any) (RequestPayload, error) { out, err := sjson.SetBytes(p, path, value) if err != nil { return p, xerrors.Errorf("set %s: %w", path, err) } - return MessagesRequestPayload(out), nil + return RequestPayload(out), nil } diff --git a/aibridge/intercept/messages/reqpayload_test.go b/aibridge/intercept/messages/reqpayload_test.go index 0f16ce5463..d7cf8ba9b1 100644 --- a/aibridge/intercept/messages/reqpayload_test.go +++ b/aibridge/intercept/messages/reqpayload_test.go @@ -1,4 +1,4 @@ -package messages +package messages //nolint:testpackage // tests unexported internals import ( "testing" @@ -11,7 +11,7 @@ import ( "github.com/coder/coder/v2/aibridge/utils" ) -func TestNewMessagesRequestPayload(t *testing.T) { +func TestNewRequestPayload(t *testing.T) { t.Parallel() testCases := []struct { @@ -42,7 +42,7 @@ func TestNewMessagesRequestPayload(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { t.Parallel() - payload, err := NewMessagesRequestPayload(testCase.requestBody) + payload, err := NewRequestPayload(testCase.requestBody) if testCase.expectError { require.Error(t, err) require.Nil(t, payload) @@ -50,12 +50,12 @@ func TestNewMessagesRequestPayload(t *testing.T) { } require.NoError(t, err) - require.Equal(t, MessagesRequestPayload(testCase.requestBody), payload) + require.Equal(t, RequestPayload(testCase.requestBody), payload) }) } } -func TestMessagesRequestPayloadStream(t *testing.T) { +func TestRequestPayloadStream(t *testing.T) { t.Parallel() testCases := []struct { @@ -97,7 +97,7 @@ func TestMessagesRequestPayloadStream(t *testing.T) { } } -func TestMessagesRequestPayloadModel(t *testing.T) { +func TestRequestPayloadModel(t *testing.T) { t.Parallel() testCases := []struct { @@ -132,7 +132,7 @@ func TestMessagesRequestPayloadModel(t *testing.T) { } } -func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) { +func TestRequestPayloadLastUserPrompt(t *testing.T) { t.Parallel() testCases := []struct { @@ -229,7 +229,7 @@ func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) { } } -func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) { +func TestRequestPayloadCorrelatingToolCallID(t *testing.T) { t.Parallel() testCases := []struct { @@ -266,7 +266,7 @@ func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) { } } -func TestMessagesRequestPayloadInjectTools(t *testing.T) { +func TestRequestPayloadInjectTools(t *testing.T) { t.Parallel() payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`) @@ -291,7 +291,7 @@ func TestMessagesRequestPayloadInjectTools(t *testing.T) { require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String()) } -func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) { +func TestRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) { t.Parallel() testCases := []struct { @@ -361,7 +361,7 @@ func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) { } } -func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) { +func TestRequestPayloadDisableParallelToolCalls(t *testing.T) { t.Parallel() testCases := []struct { @@ -451,7 +451,7 @@ func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) { } } -func TestMessagesRequestPayloadAppendedMessages(t *testing.T) { +func TestRequestPayloadAppendedMessages(t *testing.T) { t.Parallel() payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`) diff --git a/aibridge/intercept/messages/streaming.go b/aibridge/intercept/messages/streaming.go index b52b12a475..a6bdc52bee 100644 --- a/aibridge/intercept/messages/streaming.go +++ b/aibridge/intercept/messages/streaming.go @@ -36,7 +36,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, - reqPayload MessagesRequestPayload, + reqPayload RequestPayload, providerName string, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, @@ -58,16 +58,16 @@ func NewStreamingInterceptor( }} } -func (s *StreamingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { - s.interceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy) +func (i *StreamingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) { + i.interceptionBase.Setup(logger.Named("streaming"), rec, mcpProxy) } -func (s *StreamingInterception) Streaming() bool { +func (*StreamingInterception) Streaming() bool { return true } -func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { - return s.interceptionBase.baseTraceAttributes(r, true) +func (i *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue { + return i.interceptionBase.baseTraceAttributes(r, true) } // ProcessRequest handles a request to /v1/messages. @@ -156,7 +156,7 @@ newStream: for { // TODO add outer loop span (https://github.com/coder/aibridge/issues/67) if err := streamCtx.Err(); err != nil { - lastErr = xerrors.Errorf("stream exit: %w", err) + interceptionErr = xerrors.Errorf("stream exit: %w", err) break } @@ -178,8 +178,7 @@ newStream: // Tool-related handling. switch event.Type { case string(constant.ValueOf[constant.ContentBlockStart]()): - switch block := event.AsContentBlockStart().ContentBlock.AsAny().(type) { - case anthropic.ToolUseBlock: + if block, ok := event.AsContentBlockStart().ContentBlock.AsAny().(anthropic.ToolUseBlock); ok { lastToolName = block.Name if i.mcpProxy != nil && i.mcpProxy.GetTool(block.Name) != nil { @@ -306,8 +305,7 @@ newStream: foundTools int ) for _, block := range message.Content { - switch variant := block.AsAny().(type) { - case anthropic.ToolUseBlock: + if variant, ok := block.AsAny().(anthropic.ToolUseBlock); ok { foundTools++ if variant.Name == name { input = variant.Input @@ -430,24 +428,23 @@ newStream: // Causes a new stream to be run with updated messages. isFirst = false continue newStream - } else { - // Find all the non-injected tools and track their uses. - for _, block := range message.Content { - switch variant := block.AsAny().(type) { - case anthropic.ToolUseBlock: - if i.mcpProxy != nil && i.mcpProxy.GetTool(variant.Name) != nil { - continue - } + } - _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ - InterceptionID: i.ID().String(), - MsgID: message.ID, - ToolCallID: variant.ID, - Tool: variant.Name, - Args: variant.Input, - Injected: false, - }) + // Find all the non-injected tools and track their uses. + for _, block := range message.Content { + if variant, ok := block.AsAny().(anthropic.ToolUseBlock); ok { + if i.mcpProxy != nil && i.mcpProxy.GetTool(variant.Name) != nil { + continue } + + _ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{ + InterceptionID: i.ID().String(), + MsgID: message.ID, + ToolCallID: variant.ID, + Tool: variant.Name, + Args: variant.Input, + Injected: false, + }) } } } @@ -463,11 +460,10 @@ newStream: if eventstream.IsUnrecoverableError(err) { logger.Debug(ctx, "processing terminated", slog.Error(err)) break // Stop processing if client disconnected or context canceled. - } else { - logger.Warn(ctx, "failed to relay event", slog.Error(err)) - lastErr = xerrors.Errorf("relay event: %w", err) - break } + logger.Warn(ctx, "failed to relay event", slog.Error(err)) + lastErr = xerrors.Errorf("relay event: %w", err) + break } } @@ -477,8 +473,8 @@ newStream: MsgID: message.ID, Prompt: prompt, }) - prompt = "" - promptFound = false + prompt = "" //nolint:ineffassign // reset to prevent double-recording across newStream iterations + promptFound = false //nolint:ineffassign // reset to prevent double-recording across newStream iterations } if events.IsStreaming() { @@ -491,7 +487,7 @@ newStream: logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr)) interceptionErr = antErr } else { - logger.Warn(ctx, "unknown error", slog.Error(streamErr)) + logger.Warn(ctx, "unknown stream error encountered", slog.Error(streamErr)) // Unfortunately, the Anthropic SDK does not support parsing errors received in the stream // into known types (i.e. [shared.OverloadedError]). // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 @@ -500,14 +496,14 @@ newStream: } } else if lastErr != nil { // Otherwise check if any logical errors occurred during processing. - logger.Warn(ctx, "stream failed", slog.Error(lastErr)) + logger.Warn(ctx, "stream processing failed", slog.Error(lastErr)) interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr)) } if interceptionErr != nil { payload, err := i.marshal(interceptionErr) if err != nil { - logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) + logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error())) } else if err := events.Send(streamCtx, payload); err != nil { logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) } @@ -518,11 +514,11 @@ newStream: } shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30) - defer shutdownCancel() // Give the events stream 30 seconds (TODO: configurable) to gracefully shutdown. if err := events.Shutdown(shutdownCtx); err != nil { logger.Warn(ctx, "event stream shutdown", slog.Error(err)) } + shutdownCancel() // Cancel the stream context, we're now done. if interceptionErr != nil { @@ -537,8 +533,8 @@ newStream: return interceptionErr } -func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) { - sj, err := sjson.Set(event.RawJSON(), "message.id", s.ID().String()) +func (i *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) { + sj, err := sjson.Set(event.RawJSON(), "message.id", i.ID().String()) if err != nil { return nil, xerrors.Errorf("marshal event id failed: %w", err) } @@ -548,10 +544,10 @@ func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventU return nil, xerrors.Errorf("marshal event usage failed: %w", err) } - return s.encodeForStream([]byte(sj), event.Type), nil + return i.encodeForStream([]byte(sj), event.Type), nil } -func (s *StreamingInterception) marshal(payload any) ([]byte, error) { +func (i *StreamingInterception) marshal(payload any) ([]byte, error) { data, err := json.Marshal(payload) if err != nil { return nil, xerrors.Errorf("marshal payload: %w", err) @@ -567,29 +563,30 @@ func (s *StreamingInterception) marshal(payload any) ([]byte, error) { return nil, xerrors.Errorf("could not determine type from payload %q", data) } - return s.encodeForStream(data, eventType), nil + return i.encodeForStream(data, eventType), nil } // https://docs.anthropic.com/en/docs/build-with-claude/streaming#basic-streaming-request -func (s *StreamingInterception) pingPayload() []byte { - return s.encodeForStream([]byte(`{"type": "ping"}`), "ping") +func (i *StreamingInterception) pingPayload() []byte { + return i.encodeForStream([]byte(`{"type": "ping"}`), "ping") } -func (s *StreamingInterception) encodeForStream(payload []byte, typ string) []byte { +func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte { + // bytes.Buffer writes to in-memory storage and never return errors. var buf bytes.Buffer - buf.WriteString("event: ") - buf.WriteString(typ) - buf.WriteString("\n") - buf.WriteString("data: ") - buf.Write(payload) - buf.WriteString("\n\n") + _, _ = buf.WriteString("event: ") + _, _ = buf.WriteString(typ) + _, _ = buf.WriteString("\n") + _, _ = buf.WriteString("data: ") + _, _ = buf.Write(payload) + _, _ = buf.WriteString("\n\n") return buf.Bytes() } // newStream traces svc.NewStreaming() call. -func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] { - _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) +func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] { + _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, s.withBody()) + return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, i.withBody()) } diff --git a/aibridge/intercept/responses/base.go b/aibridge/intercept/responses/base.go index ad8465cb6a..9affc7d3ea 100644 --- a/aibridge/intercept/responses/base.go +++ b/aibridge/intercept/responses/base.go @@ -42,7 +42,7 @@ type responsesInterceptionBase struct { // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header authHeaderName string - reqPayload ResponsesRequestPayload + reqPayload RequestPayload cfg config.OpenAI recorder recorder.Recorder @@ -89,9 +89,9 @@ func (i *responsesInterceptionBase) Credential() intercept.CredentialInfo { return i.credential } -func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { +func (i *responsesInterceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) { i.logger = logger.With(slog.F("model", i.Model())) - i.recorder = recorder + i.recorder = rec i.mcpProxy = mcpProxy } @@ -127,7 +127,13 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http. // sendCustomErr sends custom responses.Error error to the client // it should only be called before any data is sent back to the client func (i *responsesInterceptionBase) sendCustomErr(ctx context.Context, w http.ResponseWriter, code int, err error) { - respErr := responses.Error{ + // Same JSON shape as responses.Error but using a plain struct because + // responses.Error embeds *http.Request whose GetBody func field + // is not JSON-marshalable (SA1026). + respErr := struct { + Code string `json:"code"` + Message string `json:"message"` + }{ Code: strconv.Itoa(code), Message: err.Error(), } @@ -222,15 +228,15 @@ func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Conte func (i *responsesInterceptionBase) parseFunctionCallJSONArgs(ctx context.Context, raw string) recorder.ToolArgs { trimmed := strings.TrimSpace(raw) - if trimmed != "" { - var args recorder.ToolArgs - if err := json.Unmarshal([]byte(trimmed), &args); err != nil { - i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err)) - } else { - return args - } + if trimmed == "" { + return trimmed } - return trimmed + var args recorder.ToolArgs + if err := json.Unmarshal([]byte(trimmed), &args); err != nil { + i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err)) + return trimmed + } + return args } func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, response *responses.Response) { @@ -264,7 +270,7 @@ func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, respon // extractModelThoughts extracts model thoughts from response output items. // It captures both reasoning summary items and commentary messages (message // output items with "phase": "commentary") as model thoughts. -func (i *responsesInterceptionBase) extractModelThoughts(response *responses.Response) []*recorder.ModelThoughtRecord { +func (*responsesInterceptionBase) extractModelThoughts(response *responses.Response) []*recorder.ModelThoughtRecord { if response == nil { return nil } diff --git a/aibridge/intercept/responses/base_test.go b/aibridge/intercept/responses/base_test.go index 27d5df6493..bf1fa198c8 100644 --- a/aibridge/intercept/responses/base_test.go +++ b/aibridge/intercept/responses/base_test.go @@ -1,4 +1,4 @@ -package responses +package responses //nolint:testpackage // tests unexported internals import ( "net/http" @@ -359,13 +359,16 @@ func (mrw *mockResponseWriter) WriteHeader(statusCode int) { } func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) { + t.Parallel() + mrw := mockResponseWriter{} respCopy := responseCopier{} body := "test_body" - respCopy.buff.Write([]byte(body)) + _, _ = respCopy.buff.Write([]byte(body)) // bytes.Buffer.Write never fails - respCopy.forwardResp(&mrw) + err := respCopy.forwardResp(&mrw) + require.NoError(t, err) require.False(t, mrw.headerCalled) require.False(t, mrw.writeCalled) require.False(t, mrw.writeHeaderCalled) @@ -373,7 +376,8 @@ func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) { // after response is received data is forwarded respCopy.responseReceived.Store(true) - respCopy.forwardResp(&mrw) + err = respCopy.forwardResp(&mrw) + require.NoError(t, err) require.True(t, mrw.headerCalled) require.True(t, mrw.writeCalled) require.True(t, mrw.writeHeaderCalled) diff --git a/aibridge/intercept/responses/blocking.go b/aibridge/intercept/responses/blocking.go index 6a3ea74017..ce98219fc3 100644 --- a/aibridge/intercept/responses/blocking.go +++ b/aibridge/intercept/responses/blocking.go @@ -28,7 +28,7 @@ type BlockingResponsesInterceptor struct { func NewBlockingInterceptor( id uuid.UUID, - reqPayload ResponsesRequestPayload, + reqPayload RequestPayload, providerName string, cfg config.OpenAI, clientHeaders http.Header, @@ -50,11 +50,11 @@ func NewBlockingInterceptor( } } -func (i *BlockingResponsesInterceptor) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { - i.responsesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy) +func (i *BlockingResponsesInterceptor) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) { + i.responsesInterceptionBase.Setup(logger.Named("blocking"), rec, mcpProxy) } -func (i *BlockingResponsesInterceptor) Streaming() bool { +func (*BlockingResponsesInterceptor) Streaming() bool { return false } diff --git a/aibridge/intercept/responses/reqpayload.go b/aibridge/intercept/responses/reqpayload.go index 020863552e..600402d0ec 100644 --- a/aibridge/intercept/responses/reqpayload.go +++ b/aibridge/intercept/responses/reqpayload.go @@ -37,13 +37,13 @@ var ( reqPathType = string(constant.ValueOf[constant.Type]()) ) -// ResponsesRequestPayload is raw JSON bytes of a Responses API request. +// RequestPayload is raw JSON bytes of a Responses API request. // Methods provide package-specific reads and rewrites while preserving the // original body for upstream pass-through. // Note: No changes are made on schema error. -type ResponsesRequestPayload []byte +type RequestPayload []byte -func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) { +func NewRequestPayload(raw []byte) (RequestPayload, error) { if len(bytes.TrimSpace(raw)) == 0 { return nil, xerrors.New("empty request body") } @@ -51,22 +51,22 @@ func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) { return nil, xerrors.New("invalid JSON payload") } - return ResponsesRequestPayload(raw), nil + return RequestPayload(raw), nil } -func (p ResponsesRequestPayload) Stream() bool { +func (p RequestPayload) Stream() bool { return gjson.GetBytes(p, reqPathStream).Bool() } -func (p ResponsesRequestPayload) model() string { +func (p RequestPayload) model() string { return gjson.GetBytes(p, reqPathModel).String() } -func (p ResponsesRequestPayload) background() bool { +func (p RequestPayload) background() bool { return gjson.GetBytes(p, reqPathBackground).Bool() } -func (p ResponsesRequestPayload) correlatingToolCallID() *string { +func (p RequestPayload) correlatingToolCallID() *string { items := gjson.GetBytes(p, reqPathInput) if !items.IsArray() { return nil @@ -94,7 +94,7 @@ func (p ResponsesRequestPayload) correlatingToolCallID() *string { // item, or the string input value if present. If no prompt is found, it returns // empty string, false, nil. Unexpected shapes are treated as unsupported and do // not fail the request path. -func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog.Logger) (string, bool, error) { +func (p RequestPayload) lastUserPrompt(ctx context.Context, logger slog.Logger) (string, bool, error) { inputItems := gjson.GetBytes(p, reqPathInput) if !inputItems.Exists() || inputItems.Type == gjson.Null { return "", false, nil @@ -155,10 +155,10 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog } if promptExists { - sb.WriteByte('\n') + _ = sb.WriteByte('\n') // strings.Builder.WriteByte never fails } promptExists = true - sb.WriteString(text.Str) + _, _ = sb.WriteString(text.Str) // strings.Builder.WriteString never fails } if !promptExists { @@ -168,7 +168,7 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog return sb.String(), true, nil } -func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam) (ResponsesRequestPayload, error) { +func (p RequestPayload) injectTools(injected []responses.ToolUnionParam) (RequestPayload, error) { if len(injected) == 0 { return p, nil } @@ -189,11 +189,11 @@ func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam return p.set(reqPathTools, allTools) } -func (p ResponsesRequestPayload) disableParallelToolCalls() (ResponsesRequestPayload, error) { +func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) { return p.set(reqPathParallelToolCalls, false) } -func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInputItemUnionParam) (ResponsesRequestPayload, error) { +func (p RequestPayload) appendInputItems(items []responses.ResponseInputItemUnionParam) (RequestPayload, error) { if len(items) == 0 { return p, nil } @@ -212,7 +212,7 @@ func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInpu return p.set(reqPathInput, allInput) } -func (p ResponsesRequestPayload) inputItems() ([]any, error) { +func (p RequestPayload) inputItems() ([]any, error) { input := gjson.GetBytes(p, reqPathInput) if !input.Exists() || input.Type == gjson.Null { return []any{}, nil @@ -235,7 +235,7 @@ func (p ResponsesRequestPayload) inputItems() ([]any, error) { return existing, nil } -func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) { +func (p RequestPayload) toolItems() ([]json.RawMessage, error) { tools := gjson.GetBytes(p, reqPathTools) if !tools.Exists() { return nil, nil @@ -253,7 +253,7 @@ func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) { return existing, nil } -func (p ResponsesRequestPayload) set(path string, value any) (ResponsesRequestPayload, error) { +func (p RequestPayload) set(path string, value any) (RequestPayload, error) { updated, err := sjson.SetBytes(p, path, value) if err != nil { return p, xerrors.Errorf("failed to set value at path %s: %w", path, err) diff --git a/aibridge/intercept/responses/reqpayload_test.go b/aibridge/intercept/responses/reqpayload_test.go index 75ad7ecd5f..15f84183d3 100644 --- a/aibridge/intercept/responses/reqpayload_test.go +++ b/aibridge/intercept/responses/reqpayload_test.go @@ -1,4 +1,4 @@ -package responses +package responses //nolint:testpackage // tests unexported internals import ( "encoding/json" @@ -16,7 +16,7 @@ import ( "github.com/coder/coder/v2/aibridge/utils" ) -func TestNewResponsesRequestPayload(t *testing.T) { +func TestNewRequestPayload(t *testing.T) { t.Parallel() payloadWithWrongTypes := []byte(`{"model":123,"stream":"yes","input":42,"background":"nope"}`) @@ -42,7 +42,7 @@ func TestNewResponsesRequestPayload(t *testing.T) { err: "invalid JSON payload", }, { - // ResponsesRequestPayload just checks for JSON validity, + // RequestPayload just checks for JSON validity, // schema errors are not surfaced here and // the original body is preserved for upstream handling // similar to how reverse proxy would behave. @@ -59,7 +59,7 @@ func TestNewResponsesRequestPayload(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - payload, err := NewResponsesRequestPayload(tc.raw) + payload, err := NewRequestPayload(tc.raw) if tc.err != "" { require.ErrorContains(t, err, tc.err) @@ -518,10 +518,10 @@ func injectedFunctionTool(name string) responses.ToolUnionParam { } } -func mustPayload(t *testing.T, raw []byte) ResponsesRequestPayload { +func mustPayload(t *testing.T, raw []byte) RequestPayload { t.Helper() - payload, err := NewResponsesRequestPayload(raw) + payload, err := NewRequestPayload(raw) require.NoError(t, err) return payload } diff --git a/aibridge/intercept/responses/streaming.go b/aibridge/intercept/responses/streaming.go index 11d67fadb9..e3f77b21e8 100644 --- a/aibridge/intercept/responses/streaming.go +++ b/aibridge/intercept/responses/streaming.go @@ -35,7 +35,7 @@ type StreamingResponsesInterceptor struct { func NewStreamingInterceptor( id uuid.UUID, - reqPayload ResponsesRequestPayload, + reqPayload RequestPayload, providerName string, cfg config.OpenAI, clientHeaders http.Header, @@ -57,11 +57,11 @@ func NewStreamingInterceptor( } } -func (i *StreamingResponsesInterceptor) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { - i.responsesInterceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy) +func (i *StreamingResponsesInterceptor) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) { + i.responsesInterceptionBase.Setup(logger.Named("streaming"), rec, mcpProxy) } -func (i *StreamingResponsesInterceptor) Streaming() bool { +func (*StreamingResponsesInterceptor) Streaming() bool { return true } diff --git a/aibridge/internal/integrationtest/apidump_test.go b/aibridge/internal/integrationtest/apidump_test.go index bd88650d8e..8ec62297b0 100644 --- a/aibridge/internal/integrationtest/apidump_test.go +++ b/aibridge/internal/integrationtest/apidump_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "bufio" @@ -11,7 +11,6 @@ import ( "path/filepath" "strings" "testing" - "time" "github.com/stretchr/testify/require" @@ -19,6 +18,7 @@ import ( "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/fixtures" "github.com/coder/coder/v2/aibridge/intercept/apidump" + "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/provider" ) @@ -114,23 +114,25 @@ func TestAPIDump(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock upstream server. fix := fixtures.Parse(t, tc.fixture) - srv := newMockUpstream(t, ctx, newFixtureResponse(fix)) + srv := newMockUpstream(ctx, t, newFixtureResponse(fix)) // Create temp dir for API dumps. dumpDir := t.TempDir() - bridgeServer := newBridgeTestServer(t, ctx, srv.URL, + bridgeServer := newBridgeTestServer(ctx, t, srv.URL, withCustomProvider(tc.providerFunc(srv.URL, dumpDir)), ) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) - _, err := io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) require.NoError(t, err) // Verify dump files were created. @@ -187,6 +189,7 @@ func TestAPIDump(t *testing.T) { // Parse the dumped HTTP response. dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil) require.NoError(t, err) + defer dumpResp.Body.Close() require.Equal(t, http.StatusOK, dumpResp.StatusCode) dumpRespBody, err := io.ReadAll(dumpResp.Body) require.NoError(t, err) @@ -241,7 +244,7 @@ func TestAPIDumpPassthrough(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -252,16 +255,18 @@ func TestAPIDumpPassthrough(t *testing.T) { dumpDir := t.TempDir() - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withCustomProvider(tc.providerFunc(upstream.URL, dumpDir)), ) - bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) + resp, err := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) + require.NoError(t, err) + defer resp.Body.Close() // Find dump files in the passthrough directory. passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough") var reqDumpFile, respDumpFile string - err := filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error { + err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error { if err != nil { return err } @@ -299,6 +304,7 @@ func TestAPIDumpPassthrough(t *testing.T) { require.NoError(t, err) dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil) require.NoError(t, err) + defer dumpResp.Body.Close() require.Equal(t, http.StatusOK, dumpResp.StatusCode) dumpRespBody, err := io.ReadAll(dumpResp.Body) require.NoError(t, err) diff --git a/aibridge/internal/integrationtest/bridge_test.go b/aibridge/internal/integrationtest/bridge_test.go index 7247af6093..80ebf49915 100644 --- a/aibridge/internal/integrationtest/bridge_test.go +++ b/aibridge/internal/integrationtest/bridge_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "bytes" @@ -10,7 +10,6 @@ import ( "slices" "strings" "testing" - "time" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" @@ -29,6 +28,7 @@ import ( "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/fixtures" "github.com/coder/coder/v2/aibridge/intercept" + "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/provider" "github.com/coder/coder/v2/aibridge/recorder" @@ -78,18 +78,20 @@ func TestAnthropicMessages(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) // Make API call to aibridge for Anthropic /v1/messages reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) // Response-specific checks. @@ -210,17 +212,19 @@ func TestAnthropicMessagesModelThoughts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) if tc.streaming { @@ -242,7 +246,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run("invalid config", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Invalid bedrock config - missing region & base url @@ -254,11 +258,13 @@ func TestAWSBedrockIntegration(t *testing.T) { SmallFastModel: "test-haiku", } - bridgeServer := newBridgeTestServer(t, ctx, "http://unused", + bridgeServer := newBridgeTestServer(ctx, t, "http://unused", withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)), ) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusInternalServerError, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -272,11 +278,11 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. bedrockCfg := &config.AWSBedrock{ @@ -288,7 +294,7 @@ func TestAWSBedrockIntegration(t *testing.T) { BaseURL: upstream.URL, // Use the mock server. } - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)), ) @@ -296,7 +302,9 @@ func TestAWSBedrockIntegration(t *testing.T) { // We override the AWS Bedrock client to route requests through our mock server. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + require.NoError(t, err) + defer resp.Body.Close() // For streaming responses, consume the body to allow the stream to complete. if streaming { @@ -396,11 +404,11 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSimpleBedrock) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) bCfg := &config.AWSBedrock{ Region: "us-west-2", @@ -411,7 +419,7 @@ func TestAWSBedrockIntegration(t *testing.T) { BaseURL: upstream.URL, } - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bCfg)), ) @@ -419,9 +427,11 @@ func TestAWSBedrockIntegration(t *testing.T) { require.NoError(t, err) // Send with Anthropic-Beta header containing flags that should be filtered. - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, http.Header{ + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, http.Header{ "Anthropic-Beta": {"interleaved-thinking-2025-05-14,effort-2025-11-24,context-management-2025-06-27,prompt-caching-scope-2026-01-05"}, }) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) _, err = io.ReadAll(resp.Body) require.NoError(t, err) @@ -491,18 +501,20 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) // Make API call to aibridge for OpenAI /v1/chat/completions reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) // Response-specific checks. @@ -565,25 +577,27 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix)) // Setup MCP proxies with the tool from the fixture mockMCP := setupMCPForTest(t, defaultTracer) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMCP(mockMCP), ) // Add the stream param to the request. reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) // Verify SSE headers are sent correctly @@ -756,18 +770,20 @@ func TestSimple(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL+tc.basePath) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL+tc.basePath) // When: calling the "API server" with the fixture's request body. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}}) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}}) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) // Then: I expect the upstream request to have the correct path. @@ -861,12 +877,12 @@ func TestSessionIDTracking(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withProvider(config.ProviderAnthropic)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withProvider(config.ProviderAnthropic)) reqBody := fix.Request() if tc.metadataSessionID != "" { @@ -875,11 +891,13 @@ func TestSessionIDTracking(t *testing.T) { require.NoError(t, err) } - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) // Drain the body to let the stream complete. - _, err := io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) require.NoError(t, err) interceptions := bridgeServer.Recorder.RecordedInterceptions() @@ -948,10 +966,12 @@ func TestFallthrough(t *testing.T) { t.Parallel() fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath) + upstream := newMockUpstream(t.Context(), t, newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL+tc.basePath) - resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) + resp, err := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) @@ -984,6 +1004,7 @@ func TestAnthropicInjectedTools(t *testing.T) { // Build the requirements & make the assertions which are common to all providers. bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, pathAnthropicMessages, anthropicToolResultValidator(t)) + defer resp.Body.Close() // Ensure expected tool was invoked with expected input. toolUsages := bridgeServer.Recorder.RecordedToolUsages() @@ -1067,6 +1088,7 @@ func TestOpenAIInjectedTools(t *testing.T) { // Build the requirements & make the assertions which are common to all providers. bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t)) + defer resp.Body.Close() // Ensure expected tool was invoked with expected input. toolUsages := bridgeServer.Recorder.RecordedToolUsages() @@ -1234,6 +1256,8 @@ func TestErrorHandling(t *testing.T) { // Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected. t.Run("non-stream error", func(t *testing.T) { + t.Parallel() + cases := []struct { name string fixture []byte @@ -1276,21 +1300,23 @@ func TestErrorHandling(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock server. Error fixtures contain raw HTTP // responses that may cause the bridge to retry. fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) // Add the stream param to the request. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + require.NoError(t, err) + defer resp.Body.Close() tc.responseHandlerFn(resp) bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) @@ -1347,17 +1373,19 @@ func TestErrorHandling(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock server. fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) upstream.StatusCode = http.StatusInternalServerError - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + require.NoError(t, err) + defer resp.Body.Close() tc.responseHandlerFn(resp) bridgeServer.Recorder.VerifyAllInterceptionsEnded(t) @@ -1394,7 +1422,7 @@ func TestStableRequestEncoding(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup MCP tools. @@ -1408,15 +1436,17 @@ func TestStableRequestEncoding(t *testing.T) { for i := range count { responses[i] = newFixtureResponse(fix) } - upstream := newMockUpstream(t, ctx, responses...) + upstream := newMockUpstream(ctx, t, responses...) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMCP(mockMCP), ) // Make multiple requests and verify they all have identical payloads. for range count { - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) } @@ -1657,7 +1687,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup MCP tools conditionally. @@ -1669,9 +1699,9 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { } fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMCP(mockMCP), ) @@ -1679,7 +1709,9 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) // Verify tool_choice in the upstream request. @@ -1819,17 +1851,17 @@ func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) var opts []bridgeOption if tc.withInjectedTools { opts = append(opts, withMCP(setupMCPForTest(t, defaultTracer))) } - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...) var ( reqBody = fix.Request() @@ -1842,7 +1874,9 @@ func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) { reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody) + require.NoError(t, err) + defer resp.Body.Close() _, err = io.ReadAll(resp.Body) require.NoError(t, err) @@ -1872,13 +1906,13 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Create a mock server that captures the request body sent upstream. - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) // Inject adaptive thinking into the fixture request. reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) @@ -1886,7 +1920,9 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) _, err = io.ReadAll(resp.Body) require.NoError(t, err) @@ -1935,11 +1971,11 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution. - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) // Set environment variables that the SDK would automatically read. // These should NOT leak into upstream requests. @@ -1947,9 +1983,11 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Setenv(key, val) } - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) // Verify that environment values did not leak. @@ -2045,14 +2083,14 @@ func TestActorHeaders(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v/send-headers=%v", tc.name, tc.streaming, send), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) metadataKey := "Username" - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withCustomProvider(tc.createProviderFn(upstream.URL, apiKey, send)), withActor(defaultActorID, recorder.Metadata{ metadataKey: actorUsername, @@ -2063,7 +2101,9 @@ func TestActorHeaders(t *testing.T) { reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + require.NoError(t, err) + defer resp.Body.Close() // Drain the body so streaming responses complete without // a "connection reset" error in the mock upstream. _, err = io.ReadAll(resp.Body) diff --git a/aibridge/internal/integrationtest/circuit_breaker_test.go b/aibridge/internal/integrationtest/circuit_breaker_test.go index 4a367ea15e..3ace843275 100644 --- a/aibridge/internal/integrationtest/circuit_breaker_test.go +++ b/aibridge/internal/integrationtest/circuit_breaker_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "fmt" @@ -28,11 +28,11 @@ const ( ) func anthropicSuccessResponse(model string) string { - return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"%s","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model) + return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":%q,"stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model) } func openAISuccessResponse(model string) string { - return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model) + return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":%q,"choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model) } // TestCircuitBreaker_FullRecoveryCycle tests the complete circuit breaker lifecycle: @@ -130,31 +130,35 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { } ctx := t.Context() - bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL, withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), withMetrics(m), withActor("test-user-id", nil), ) - doRequest := func() *http.Response { - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) - _, err := io.ReadAll(resp.Body) + doRequest := func() int { + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) require.NoError(t, err) - return resp + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + return resp.StatusCode } // Phase 1: Trip the circuit breaker // First FailureThreshold requests hit upstream, get 429 for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := doRequest() - assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + status := doRequest() + assert.Equal(t, http.StatusTooManyRequests, status) } + //nolint:gosec // G115: test constant, no overflow risk assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load()) // Phase 2: Verify circuit is open // Request should be blocked by circuit breaker (no upstream call) - resp := doRequest() - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + status := doRequest() + assert.Equal(t, http.StatusServiceUnavailable, status) + //nolint:gosec // G115: test constant, no overflow risk assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open") // Verify metrics show circuit is open @@ -175,8 +179,8 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { // Phase 4: Recovery - request in half-open state should succeed and close circuit upstreamCallsBefore := upstreamCalls.Load() - resp = doRequest() - assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in half-open state") + status = doRequest() + assert.Equal(t, http.StatusOK, status, "Request should succeed in half-open state") assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") // Verify circuit is now closed @@ -186,8 +190,8 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) { // Phase 5: Verify circuit is fully functional again // Multiple requests should all succeed and reach upstream for i := 0; i < 3; i++ { - resp = doRequest() - assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed after circuit closes") + status = doRequest() + assert.Equal(t, http.StatusOK, status, "Request should succeed after circuit closes") } // All requests should have reached upstream @@ -283,28 +287,30 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { } ctx := t.Context() - bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL, withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), withMetrics(m), withActor("test-user-id", nil), ) - doRequest := func() *http.Response { - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) - _, err := io.ReadAll(resp.Body) + doRequest := func() int { + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) require.NoError(t, err) - return resp + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + return resp.StatusCode } // Phase 1: Trip the circuit for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := doRequest() - assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + status := doRequest() + assert.Equal(t, http.StatusTooManyRequests, status) } // Verify circuit is open - resp := doRequest() - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + status := doRequest() + assert.Equal(t, http.StatusServiceUnavailable, status) trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel)) assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") @@ -314,13 +320,13 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { // Phase 3: Request in half-open state fails, circuit should re-open upstreamCallsBefore := upstreamCalls.Load() - resp = doRequest() - assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "Request should fail in half-open state") + status = doRequest() + assert.Equal(t, http.StatusTooManyRequests, status, "Request should fail in half-open state") assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state") // Circuit should be open again - next request should be rejected immediately - resp = doRequest() - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Circuit should be open again after half-open failure") + status = doRequest() + assert.Equal(t, http.StatusServiceUnavailable, status, "Circuit should be open again after half-open failure") assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens") // Verify metrics: trips should be 2 now (tripped twice) @@ -429,28 +435,30 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { } ctx := t.Context() - bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL, withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)), withMetrics(m), withActor("test-user-id", nil), ) - doRequest := func() *http.Response { - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) - _, err := io.ReadAll(resp.Body) + doRequest := func() int { + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers) require.NoError(t, err) - return resp + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + return resp.StatusCode } // Phase 1: Trip the circuit for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := doRequest() - assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + status := doRequest() + assert.Equal(t, http.StatusTooManyRequests, status) } // Verify circuit is open - resp := doRequest() - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + status := doRequest() + assert.Equal(t, http.StatusServiceUnavailable, status) // Phase 2: Wait for half-open state and switch upstream to success time.Sleep(cbConfig.Timeout + 10*time.Millisecond) @@ -466,8 +474,8 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - resp := doRequest() - responses <- resp.StatusCode + status := doRequest() + responses <- status }() } @@ -544,7 +552,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { MaxRequests: 1, } ctx := t.Context() - bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL, withCustomProvider(provider.NewAnthropic(config.Anthropic{ BaseURL: mockUpstream.URL, Key: "test-key", @@ -554,27 +562,31 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { withActor("test-user-id", nil), ) - doRequest := func(model string) *http.Response { - body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, []byte(body), http.Header{ + doRequest := func(model string) int { + body := fmt.Sprintf(`{"model":%q,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, []byte(body), http.Header{ "x-api-key": {"test"}, "anthropic-version": {"2023-06-01"}, }) - _, err := io.ReadAll(resp.Body) require.NoError(t, err) - return resp + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + return resp.StatusCode } // Phase 1: Trip the circuit for sonnet model for i := uint32(0); i < cbConfig.FailureThreshold; i++ { - resp := doRequest("claude-sonnet-4-20250514") - assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + status := doRequest("claude-sonnet-4-20250514") + assert.Equal(t, http.StatusTooManyRequests, status) } + //nolint:gosec // G115: test constant, no overflow risk assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load()) // Verify sonnet circuit is open - resp := doRequest("claude-sonnet-4-20250514") - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Sonnet circuit should be open") + status := doRequest("claude-sonnet-4-20250514") + assert.Equal(t, http.StatusServiceUnavailable, status, "Sonnet circuit should be open") + //nolint:gosec // G115: test constant, no overflow risk assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load(), "No new sonnet calls when circuit is open") // Verify sonnet metrics show circuit is open @@ -585,14 +597,14 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { assert.Equal(t, 1.0, sonnetState, "Sonnet CircuitBreakerState should be 1 (open)") // Phase 2: Haiku model should still work (independent circuit) - resp = doRequest("claude-3-5-haiku-20241022") - assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should succeed while sonnet circuit is open") + status = doRequest("claude-3-5-haiku-20241022") + assert.Equal(t, http.StatusOK, status, "Haiku should succeed while sonnet circuit is open") assert.Equal(t, int32(1), haikuCalls.Load(), "Haiku call should reach upstream") // Make multiple haiku requests - all should succeed for i := 0; i < 3; i++ { - resp = doRequest("claude-3-5-haiku-20241022") - assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should continue to succeed") + status = doRequest("claude-3-5-haiku-20241022") + assert.Equal(t, http.StatusOK, status, "Haiku should continue to succeed") } assert.Equal(t, int32(4), haikuCalls.Load(), "All haiku calls should reach upstream") @@ -607,8 +619,8 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) { time.Sleep(cbConfig.Timeout + 10*time.Millisecond) sonnetShouldFail.Store(false) - resp = doRequest("claude-sonnet-4-20250514") - assert.Equal(t, http.StatusOK, resp.StatusCode, "Sonnet should recover after timeout") + status = doRequest("claude-sonnet-4-20250514") + assert.Equal(t, http.StatusOK, status, "Sonnet should recover after timeout") // Verify sonnet circuit is now closed sonnetState = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514")) diff --git a/aibridge/internal/integrationtest/metrics_test.go b/aibridge/internal/integrationtest/metrics_test.go index 35ac9f9a32..c3d61ad715 100644 --- a/aibridge/internal/integrationtest/metrics_test.go +++ b/aibridge/internal/integrationtest/metrics_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "bytes" @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "testing" - "time" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" @@ -17,6 +16,7 @@ import ( "github.com/coder/coder/v2/aibridge" "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/fixtures" + "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/metrics" ) @@ -104,7 +104,7 @@ func TestMetrics_Interception(t *testing.T) { }, { name: "oai_responses_blocking_error", - fixture: fixtures.OaiResponsesBlockingHttpErr, + fixture: fixtures.OaiResponsesBlockingHTTPErr, path: pathOpenAIResponses, headers: http.Header{"User-Agent": []string{"codex/1.0.0"}}, expectStatus: metrics.InterceptionCountStatusFailed, @@ -127,7 +127,7 @@ func TestMetrics_Interception(t *testing.T) { }, { name: "oai_responses_streaming_error", - fixture: fixtures.OaiResponsesStreamingHttpErr, + fixture: fixtures.OaiResponsesStreamingHTTPErr, path: pathOpenAIResponses, headers: http.Header{"Originator": []string{"roo-code"}}, expectStatus: metrics.InterceptionCountStatusFailed, @@ -143,20 +143,22 @@ func TestMetrics_Interception(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) upstream.AllowOverflow = tc.allowOverflow m := aibridge.NewMetrics(prometheus.NewRegistry()) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMetrics(m), ) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers) - _, err := io.ReadAll(resp.Body) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers) + require.NoError(t, err) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) require.NoError(t, err) count := promtest.ToFloat64(m.InterceptionCount.WithLabelValues( @@ -173,7 +175,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { fix := fixtures.Parse(t, fixtures.AntSimple) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) blockCh := make(chan struct{}) @@ -185,7 +187,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { t.Cleanup(srv.Close) m := aibridge.NewMetrics(prometheus.NewRegistry()) - bridgeServer := newBridgeTestServer(t, ctx, srv.URL, + bridgeServer := newBridgeTestServer(ctx, t, srv.URL, withMetrics(m), ) @@ -208,7 +210,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { return promtest.ToFloat64( m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 1 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) // Unblock request, await completion. close(blockCh) @@ -223,7 +225,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { return promtest.ToFloat64( m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 0 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) } func TestMetrics_PassthroughCount(t *testing.T) { @@ -233,11 +235,13 @@ func TestMetrics_PassthroughCount(t *testing.T) { t.Cleanup(upstream.Close) m := aibridge.NewMetrics(prometheus.NewRegistry()) - bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL, + bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL, withMetrics(m), ) - resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) + resp, err := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) count := promtest.ToFloat64(m.PassthroughCount.WithLabelValues( @@ -248,20 +252,22 @@ func TestMetrics_PassthroughCount(t *testing.T) { func TestMetrics_PromptCount(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSimple) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMetrics(m), ) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request(), http.Header{"User-Agent": []string{"claude-code/1.0.0"}}) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request(), http.Header{"User-Agent": []string{"claude-code/1.0.0"}}) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) - _, err := io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) require.NoError(t, err) prompts := promtest.ToFloat64(m.PromptCount.WithLabelValues( @@ -336,14 +342,14 @@ func TestMetrics_TokenUseCount(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMetrics(m), ) @@ -353,7 +359,9 @@ func TestMetrics_TokenUseCount(t *testing.T) { reqBody, err = sjson.SetBytes(reqBody, "stream", true) require.NoError(t, err) } - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.reqPath, reqBody, nil) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.reqPath, reqBody, nil) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) _, _ = io.ReadAll(resp.Body) @@ -361,7 +369,7 @@ func TestMetrics_TokenUseCount(t *testing.T) { require.Eventually(t, func() bool { return promtest.ToFloat64(m.TokenUseCount.WithLabelValues( tc.expectProvider, tc.expectModel, "input", defaultActorID, string(aibridge.ClientUnknown))) > 0 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) for label, expected := range tc.expectedLabels { require.Equal(t, expected, promtest.ToFloat64(m.TokenUseCount.WithLabelValues( @@ -375,20 +383,22 @@ func TestMetrics_TokenUseCount(t *testing.T) { func TestMetrics_NonInjectedToolUseCount(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMetrics(m), ) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request()) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) - _, err := io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) require.NoError(t, err) count := promtest.ToFloat64(m.NonInjectedToolUseCount.WithLabelValues( @@ -399,32 +409,34 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { func TestMetrics_InjectedToolUseCount(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // First request returns the tool invocation, the second returns the mocked response to the tool result. fix := fixtures.Parse(t, fixtures.AntSingleInjectedTool) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix)) m := aibridge.NewMetrics(prometheus.NewRegistry()) // Setup mocked MCP server & tools. mockMCP := setupMCPForTest(t, defaultTracer) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMetrics(m), withMCP(mockMCP), ) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request()) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request()) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) - _, err := io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) require.NoError(t, err) // Wait until full roundtrip has completed. require.Eventually(t, func() bool { return upstream.Calls.Load() == 2 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) recorder := bridgeServer.Recorder require.Len(t, recorder.ToolUsages(), 1) diff --git a/aibridge/internal/integrationtest/mockmcp.go b/aibridge/internal/integrationtest/mockmcp.go index faaea788f2..ffbd4fad19 100644 --- a/aibridge/internal/integrationtest/mockmcp.go +++ b/aibridge/internal/integrationtest/mockmcp.go @@ -7,7 +7,6 @@ import ( "net/http/httptest" "sync" "testing" - "time" "github.com/mark3labs/mcp-go/client/transport" mcplib "github.com/mark3labs/mcp-go/mcp" @@ -19,6 +18,7 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/mcp" ) @@ -68,12 +68,12 @@ func setupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *mo mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{proxy.Name(): proxy}, tracer) t.Cleanup(func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() require.NoError(t, mgr.Shutdown(ctx)) }) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) require.NoError(t, mgr.Init(ctx)) require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init") @@ -141,7 +141,7 @@ func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { tool := mcplib.NewTool(name, mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), ) - s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + s.AddTool(tool, func(_ context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { acc.addCall(request.Params.Name, request.Params.Arguments) if errMsg, ok := acc.getToolError(request.Params.Name); ok { return nil, xerrors.New(errMsg) diff --git a/aibridge/internal/integrationtest/mockupstream.go b/aibridge/internal/integrationtest/mockupstream.go index 909f8ac1f8..4d8668e776 100644 --- a/aibridge/internal/integrationtest/mockupstream.go +++ b/aibridge/internal/integrationtest/mockupstream.go @@ -111,9 +111,9 @@ func (ms *mockUpstream) receivedRequests() []receivedRequest { // The test fails if the number of requests doesn't match the number of // responses (when AllowOverflow is not set, default). // -// srv := newMockUpstream(t, ctx, newFixtureResponse(fix)) // simple -// srv := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) // multi-turn -func newMockUpstream(t *testing.T, ctx context.Context, responses ...upstreamResponse) *mockUpstream { +// srv := newMockUpstream(ctx, t, newFixtureResponse(fix)) // simple +// srv := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix)) // multi-turn +func newMockUpstream(ctx context.Context, t *testing.T, responses ...upstreamResponse) *mockUpstream { t.Helper() require.NotEmpty(t, responses, "at least one upstreamResponse required") diff --git a/aibridge/internal/integrationtest/responses_test.go b/aibridge/internal/integrationtest/responses_test.go index 00f52efc46..8498588173 100644 --- a/aibridge/internal/integrationtest/responses_test.go +++ b/aibridge/internal/integrationtest/responses_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "context" @@ -22,6 +22,7 @@ import ( "github.com/coder/coder/v2/aibridge" "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/fixtures" + "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/provider" "github.com/coder/coder/v2/aibridge/recorder" "github.com/coder/coder/v2/aibridge/utils" @@ -335,15 +336,17 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request(), http.Header{"User-Agent": {tc.userAgent}}) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request(), http.Header{"User-Agent": {tc.userAgent}}) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) got, err := io.ReadAll(resp.Body) @@ -416,7 +419,7 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // request with Background mode should be rejected before it reaches upstream @@ -426,11 +429,13 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) { })) t.Cleanup(upstream.Close) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) // Create a request with background mode enabled reqBytes := responsesRequestBytes(t, tc.streaming, keyVal{"background", true}) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, "application/json", resp.Header.Get("Content-Type")) require.Equal(t, http.StatusNotImplemented, resp.StatusCode) @@ -547,17 +552,17 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture[i]) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) var opts []bridgeOption if tc.withInjectedTools { opts = append(opts, withMCP(setupMCPForTest(t, defaultTracer))) } - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...) var ( reqBody = fix.Request() @@ -568,7 +573,9 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { require.NoError(t, err) } - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBody) + require.NoError(t, err) + defer resp.Body.Close() _, err = io.ReadAll(resp.Body) require.NoError(t, err) @@ -631,14 +638,16 @@ func TestClientAndConnectionError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // tc.addr may be an intentionally invalid URL; use withCustomProvider. - bridgeServer := newBridgeTestServer(t, ctx, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey)))) + bridgeServer := newBridgeTestServer(ctx, t, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey)))) reqBytes := responsesRequestBytes(t, tc.streaming) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, "application/json", resp.Header.Get("Content-Type")) require.Equal(t, http.StatusInternalServerError, resp.StatusCode) @@ -701,7 +710,7 @@ func TestUpstreamError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -712,10 +721,12 @@ func TestUpstreamError(t *testing.T) { })) t.Cleanup(upstream.Close) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) reqBytes := responsesRequestBytes(t, tc.streaming) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, tc.statusCode, resp.StatusCode) require.Equal(t, tc.contentType, resp.Header.Get("Content-Type")) @@ -880,13 +891,13 @@ func TestResponsesInjectedTool(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock server for multi-turn interaction. // First request → tool call response, second → tool response. fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix)) // Setup MCP server proxies (with mock tools). mockMCP := setupMCPForTest(t, defaultTracer) @@ -894,9 +905,11 @@ func TestResponsesInjectedTool(t *testing.T) { mockMCP.setToolError(tc.mcpToolName, tc.expectToolError) } - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP)) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMCP(mockMCP)) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request()) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request()) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) @@ -905,7 +918,7 @@ func TestResponsesInjectedTool(t *testing.T) { // Wait for both requests to be made (inner agentic loop). require.Eventually(t, func() bool { return upstream.Calls.Load() == 2 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) // Verify the injected tool was invoked via MCP. invocations := mockMCP.getCallsByTool(tc.mcpToolName) @@ -1025,18 +1038,20 @@ func TestResponsesModelThoughts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request()) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request()) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) - _, err := io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) require.NoError(t, err) bridgeServer.Recorder.VerifyModelThoughtsRecorded(t, tc.expectedThoughts) diff --git a/aibridge/internal/integrationtest/setupbridge.go b/aibridge/internal/integrationtest/setupbridge.go index 581a5e4b42..a77ac6f61a 100644 --- a/aibridge/internal/integrationtest/setupbridge.go +++ b/aibridge/internal/integrationtest/setupbridge.go @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "testing" - "time" "github.com/stretchr/testify/require" "github.com/tidwall/sjson" @@ -63,11 +62,13 @@ type bridgeTestServer struct { // makeRequest builds and executes an HTTP request against this server. // Optional headers are applied after the default Content-Type. -func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, body []byte, header ...http.Header) *http.Response { +func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, body []byte, header ...http.Header) (*http.Response, error) { t.Helper() req, err := http.NewRequestWithContext(t.Context(), method, s.URL+path, bytes.NewReader(body)) - require.NoError(t, err) + if err != nil { + return nil, err + } req.Header.Set("Content-Type", "application/json") for _, h := range header { for k, vals := range h { @@ -76,10 +77,7 @@ func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, } } } - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - t.Cleanup(func() { _ = resp.Body.Close() }) - return resp + return http.DefaultClient.Do(req) } type bridgeOption func(*bridgeConfig) @@ -133,8 +131,8 @@ func withActor(id string, md recorder.Metadata) bridgeOption { // - defaultTracer (unless withTracer) // - defaultActorID (unless withActor) func newBridgeTestServer( - t *testing.T, ctx context.Context, + t *testing.T, upstreamURL string, opts ...bridgeOption, ) *bridgeTestServer { @@ -209,7 +207,7 @@ func setupInjectedToolTest( ) (*bridgeTestServer, *mockMCP, *http.Response) { t.Helper() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixture) @@ -220,7 +218,7 @@ func setupInjectedToolTest( firstResp := newFixtureResponse(fix) toolResp := newFixtureToolResponse(fix) toolResp.OnRequest = toolRequestValidatorFn - upstream := newMockUpstream(t, ctx, firstResp, toolResp) + upstream := newMockUpstream(ctx, t, firstResp, toolResp) mockMCP := setupMCPForTest(t, tracer) @@ -230,19 +228,20 @@ func setupInjectedToolTest( withActor(defaultActorID, nil), } allOpts = append(allOpts, opts...) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, allOpts...) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, allOpts...) // Add the stream param to the request. reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody) + require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) // Wait both requests (initial + tool call result) require.Eventually(t, func() bool { return upstream.Calls.Load() == 2 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) return bridgeServer, mockMCP, resp } diff --git a/aibridge/internal/integrationtest/trace_test.go b/aibridge/internal/integrationtest/trace_test.go index 28c266e9dd..f3e835ca8a 100644 --- a/aibridge/internal/integrationtest/trace_test.go +++ b/aibridge/internal/integrationtest/trace_test.go @@ -1,4 +1,4 @@ -package integrationtest +package integrationtest //nolint:testpackage // tests unexported internals import ( "context" @@ -6,7 +6,6 @@ import ( "slices" "strings" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,6 +19,7 @@ import ( "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/fixtures" + "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/tracing" ) @@ -43,6 +43,8 @@ func setupTracer(t *testing.T) (*tracetest.SpanRecorder, oteltrace.Tracer) { } func TestTraceAnthropic(t *testing.T) { + t.Parallel() + expectNonStreaming := []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -137,13 +139,15 @@ func TestTraceAnthropic(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) opts := []bridgeOption{ withTracer(tracer), @@ -151,11 +155,13 @@ func TestTraceAnthropic(t *testing.T) { if tc.bedrock { opts = append(opts, withProvider(providerBedrock)) } - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) bridgeServer.Close() @@ -189,6 +195,8 @@ func TestTraceAnthropic(t *testing.T) { } func TestTraceAnthropicErr(t *testing.T) { + t.Parallel() + expectNonStream := []expectTrace{ {"Intercept", 1, codes.Error}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -247,13 +255,15 @@ func TestTraceAnthropicErr(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) opts := []bridgeOption{ withTracer(tracer), @@ -261,11 +271,13 @@ func TestTraceAnthropicErr(t *testing.T) { if tc.bedrock { opts = append(opts, withProvider(providerBedrock)) } - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody) + require.NoError(t, err) + defer resp.Body.Close() if tc.streaming { require.Equal(t, http.StatusOK, resp.StatusCode) } else { @@ -385,10 +397,11 @@ func TestInjectedToolsTrace(t *testing.T) { validatorFn = openaiChatToolResultValidator(t) } - bridgeServer, mockMCP, _ := setupInjectedToolTest( + bridgeServer, mockMCP, resp := setupInjectedToolTest( t, tc.fixture, tc.streaming, tracer, tc.path, validatorFn, tc.opts..., ) + defer resp.Body.Close() require.Len(t, bridgeServer.Recorder.RecordedInterceptions(), 1) intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID @@ -417,6 +430,8 @@ func TestInjectedToolsTrace(t *testing.T) { } func TestTraceOpenAI(t *testing.T) { + t.Parallel() + cases := []struct { name string fixture []byte @@ -529,20 +544,24 @@ func TestTraceOpenAI(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) - upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) - bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, + upstream := newMockUpstream(ctx, t, newFixtureResponse(fix)) + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withTracer(tracer), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) bridgeServer.Close() @@ -569,6 +588,8 @@ func TestTraceOpenAI(t *testing.T) { } func TestTraceOpenAIErr(t *testing.T) { + t.Parallel() + cases := []struct { name string fixture []byte @@ -647,7 +668,7 @@ func TestTraceOpenAIErr(t *testing.T) { }, { name: "trace_openai_responses_streaming_http_error", - fixture: fixtures.OaiResponsesStreamingHttpErr, + fixture: fixtures.OaiResponsesStreamingHTTPErr, streaming: true, allowOverflow: true, // 429 error causes retries @@ -664,7 +685,7 @@ func TestTraceOpenAIErr(t *testing.T) { }, { name: "trace_openai_responses_blocking_http_error", - fixture: fixtures.OaiResponsesBlockingHttpErr, + fixture: fixtures.OaiResponsesBlockingHTTPErr, streaming: false, path: pathOpenAIResponses, @@ -682,22 +703,26 @@ func TestTraceOpenAIErr(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) sr, tracer := setupTracer(t) fix := fixtures.Parse(t, tc.fixture) - mockAPI := newMockUpstream(t, ctx, newFixtureResponse(fix)) + mockAPI := newMockUpstream(ctx, t, newFixtureResponse(fix)) mockAPI.AllowOverflow = tc.allowOverflow - bridgeServer := newBridgeTestServer(t, ctx, mockAPI.URL, + bridgeServer := newBridgeTestServer(ctx, t, mockAPI.URL, withTracer(tracer), ) reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) require.NoError(t, err) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, tc.expectCode, resp.StatusCode) bridgeServer.Close() @@ -729,15 +754,17 @@ func TestTracePassthrough(t *testing.T) { fix := fixtures.Parse(t, fixtures.OaiChatFallthrough) - upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix)) + upstream := newMockUpstream(t.Context(), t, newFixtureResponse(fix)) sr, tracer := setupTracer(t) - bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL, + bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL, withTracer(tracer), ) - resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) + resp, err := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil) + require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) bridgeServer.Close() @@ -755,6 +782,8 @@ func TestTracePassthrough(t *testing.T) { } func TestNewServerProxyManagerTraces(t *testing.T) { + t.Parallel() + sr, tracer := setupTracer(t) serverName := "serverName" diff --git a/aibridge/internal/testutil/mock_recorder.go b/aibridge/internal/testutil/mock_recorder.go index fde8c27346..52a86c847d 100644 --- a/aibridge/internal/testutil/mock_recorder.go +++ b/aibridge/internal/testutil/mock_recorder.go @@ -26,14 +26,14 @@ type MockRecorder struct { interceptionsEnd map[string]*recorder.InterceptionRecordEnded } -func (m *MockRecorder) RecordInterception(ctx context.Context, req *recorder.InterceptionRecord) error { +func (m *MockRecorder) RecordInterception(_ context.Context, req *recorder.InterceptionRecord) error { m.mu.Lock() defer m.mu.Unlock() m.interceptions = append(m.interceptions, req) return nil } -func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorder.InterceptionRecordEnded) error { +func (m *MockRecorder) RecordInterceptionEnded(_ context.Context, req *recorder.InterceptionRecordEnded) error { m.mu.Lock() defer m.mu.Unlock() if m.interceptionsEnd == nil { @@ -46,28 +46,28 @@ func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorde return nil } -func (m *MockRecorder) RecordPromptUsage(ctx context.Context, req *recorder.PromptUsageRecord) error { +func (m *MockRecorder) RecordPromptUsage(_ context.Context, req *recorder.PromptUsageRecord) error { m.mu.Lock() defer m.mu.Unlock() m.userPrompts = append(m.userPrompts, req) return nil } -func (m *MockRecorder) RecordTokenUsage(ctx context.Context, req *recorder.TokenUsageRecord) error { +func (m *MockRecorder) RecordTokenUsage(_ context.Context, req *recorder.TokenUsageRecord) error { m.mu.Lock() defer m.mu.Unlock() m.tokenUsages = append(m.tokenUsages, req) return nil } -func (m *MockRecorder) RecordToolUsage(ctx context.Context, req *recorder.ToolUsageRecord) error { +func (m *MockRecorder) RecordToolUsage(_ context.Context, req *recorder.ToolUsageRecord) error { m.mu.Lock() defer m.mu.Unlock() m.toolUsages = append(m.toolUsages, req) return nil } -func (m *MockRecorder) RecordModelThought(ctx context.Context, req *recorder.ModelThoughtRecord) error { +func (m *MockRecorder) RecordModelThought(_ context.Context, req *recorder.ModelThoughtRecord) error { m.mu.Lock() defer m.mu.Unlock() m.modelThoughts = append(m.modelThoughts, req) diff --git a/aibridge/internal/testutil/mockprovider.go b/aibridge/internal/testutil/mockprovider.go index 2333d5fc7a..fd876daaff 100644 --- a/aibridge/internal/testutil/mockprovider.go +++ b/aibridge/internal/testutil/mockprovider.go @@ -11,26 +11,26 @@ import ( ) type MockProvider struct { - Name_ string + NameStr string URL string Bridged []string Passthrough []string InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) } -func (m *MockProvider) Type() string { return m.Name_ } -func (m *MockProvider) Name() string { return m.Name_ } -func (m *MockProvider) BaseURL() string { return m.URL } -func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.Name_) } -func (m *MockProvider) BridgedRoutes() []string { return m.Bridged } -func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough } -func (m *MockProvider) AuthHeader() string { return "Authorization" } -func (m *MockProvider) InjectAuthHeader(h *http.Header) {} -func (m *MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil } -func (m *MockProvider) APIDumpDir() string { return "" } +func (m *MockProvider) Type() string { return m.NameStr } +func (m *MockProvider) Name() string { return m.NameStr } +func (m *MockProvider) BaseURL() string { return m.URL } +func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) } +func (m *MockProvider) BridgedRoutes() []string { return m.Bridged } +func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough } +func (*MockProvider) AuthHeader() string { return "Authorization" } +func (*MockProvider) InjectAuthHeader(_ *http.Header) {} +func (*MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil } +func (*MockProvider) APIDumpDir() string { return "" } func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) { if m.InterceptorFunc != nil { return m.InterceptorFunc(w, r, tracer) } - return nil, nil + return nil, nil //nolint:nilnil // mock: no interceptor configured is not an error } diff --git a/aibridge/internal/testutil/timeout.go b/aibridge/internal/testutil/timeout.go new file mode 100644 index 0000000000..ef8b2b530d --- /dev/null +++ b/aibridge/internal/testutil/timeout.go @@ -0,0 +1,21 @@ +package testutil + +import "time" + +// Shared test timeout and interval constants. +// Using named constants avoids magic numbers and makes timeout policy +// easy to adjust across the entire test suite. +const ( + // WaitLong is the default timeout for test operations that may take a while + // (e.g. integration tests with HTTP round-trips). + WaitLong = 30 * time.Second + + // WaitMedium is a timeout for moderately slow operations. + WaitMedium = 10 * time.Second + + // WaitShort is a timeout for operations expected to complete quickly. + WaitShort = 5 * time.Second + + // IntervalFast is a short polling interval for require.Eventually and similar. + IntervalFast = 50 * time.Millisecond +) diff --git a/aibridge/mcp/client_info_test.go b/aibridge/mcp/client_info_test.go index 9dadc7a461..77f4ee7b0e 100644 --- a/aibridge/mcp/client_info_test.go +++ b/aibridge/mcp/client_info_test.go @@ -9,6 +9,8 @@ import ( ) func TestGetClientInfo(t *testing.T) { + t.Parallel() + info := mcp.GetClientInfo() assert.Equal(t, "coder/aibridge", info.Name) diff --git a/aibridge/mcp/mcp_test.go b/aibridge/mcp/mcp_test.go index 4e5c5f4ec6..aeea86e72d 100644 --- a/aibridge/mcp/mcp_test.go +++ b/aibridge/mcp/mcp_test.go @@ -9,7 +9,6 @@ import ( "slices" "strings" "testing" - "time" mcplib "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -19,6 +18,7 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/mcp" ) @@ -299,7 +299,7 @@ func TestToolInjectionOrder(t *testing.T) { // Setup. logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Given: a MCP mock server offering a set of tools. diff --git a/aibridge/mcp/proxy_streamable_http.go b/aibridge/mcp/proxy_streamable_http.go index f926a8eb69..132c03965a 100644 --- a/aibridge/mcp/proxy_streamable_http.go +++ b/aibridge/mcp/proxy_streamable_http.go @@ -87,7 +87,7 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) (outErr error) { return xerrors.Errorf("MCP version negotiation failed; requested %q, accepts %q, received %q", version, strings.Join(mcp.ValidProtocolVersions, ","), result.ProtocolVersion) } - p.logger.Debug(ctx, "MCP client initialized", slog.F("name", result.ServerInfo.Name), slog.F("server_version", result.ServerInfo.Version)) + p.logger.Debug(ctx, "mcp client initialized", slog.F("name", result.ServerInfo.Name), slog.F("server_version", result.ServerInfo.Version)) tools, err := p.fetchTools(ctx) if err != nil { @@ -161,7 +161,7 @@ func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[strin return out, nil } -func (p *StreamableHTTPServerProxy) Shutdown(ctx context.Context) error { +func (p *StreamableHTTPServerProxy) Shutdown(_ context.Context) error { if p.client == nil { return nil } diff --git a/aibridge/mcp/tool.go b/aibridge/mcp/tool.go index c61a8c8f2b..8fbca9d224 100644 --- a/aibridge/mcp/tool.go +++ b/aibridge/mcp/tool.go @@ -59,15 +59,15 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...)) defer tracing.EndSpanErr(span, &outErr) - inputJson, err := json.Marshal(input) + inputJSON, err := json.Marshal(input) if err != nil { t.Logger.Warn(ctx, "failed to marshal tool input, will be omitted from span attrs", slog.Error(err)) } else { - strJson := string(inputJson) - if len(strJson) > maxSpanInputAttrLen { - strJson = strJson[:maxSpanInputAttrLen] + strJSON := string(inputJSON) + if len(strJSON) > maxSpanInputAttrLen { + strJSON = strJSON[:maxSpanInputAttrLen] } - span.SetAttributes(attribute.String(tracing.MCPInput, strJson)) + span.SetAttributes(attribute.String(tracing.MCPInput, strJSON)) } start := time.Now() @@ -88,7 +88,7 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp logFn(ctx, "injected tool invoked", slog.F("name", t.Name), slog.F("server", t.ServerName), - slog.F("input", inputJson), + slog.F("input", inputJSON), slog.F("duration_sec", time.Since(start).Seconds()), slog.Error(outErr), ) @@ -106,12 +106,13 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp // - https://community.openai.com/t/function-call-description-max-length/529902 // - https://github.com/anthropics/claude-code/issues/2326 func EncodeToolID(server, tool string) string { + // strings.Builder writes to in-memory storage and never return errors. var sb strings.Builder - sb.WriteString(injectedToolPrefix) - sb.WriteString(injectedToolDelimiter) - sb.WriteString(server) - sb.WriteString(injectedToolDelimiter) - sb.WriteString(tool) + _, _ = sb.WriteString(injectedToolPrefix) + _, _ = sb.WriteString(injectedToolDelimiter) + _, _ = sb.WriteString(server) + _, _ = sb.WriteString(injectedToolDelimiter) + _, _ = sb.WriteString(tool) return sb.String() } diff --git a/aibridge/mcpmock/doc.go b/aibridge/mcpmock/doc.go index 1e3a73bbfc..0b615f2d69 100644 --- a/aibridge/mcpmock/doc.go +++ b/aibridge/mcpmock/doc.go @@ -1,3 +1,3 @@ package mcpmock -//go:generate mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/coder/v2/aibridge/mcp ServerProxier +//go:generate mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/aibridge/mcp ServerProxier diff --git a/aibridge/mcpmock/mcpmock.go b/aibridge/mcpmock/mcpmock.go index 81cf41f8a9..2678c73352 100644 --- a/aibridge/mcpmock/mcpmock.go +++ b/aibridge/mcpmock/mcpmock.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/coder/coder/v2/aibridge/mcp (interfaces: ServerProxier) +// Source: github.com/coder/aibridge/mcp (interfaces: ServerProxier) // // Generated by this command: // -// mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/coder/v2/aibridge/mcp ServerProxier +// mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/aibridge/mcp ServerProxier // // Package mcpmock is a generated GoMock package. diff --git a/aibridge/metrics/metrics.go b/aibridge/metrics/metrics.go index 6d14ab9d20..ec2d182fdf 100644 --- a/aibridge/metrics/metrics.go +++ b/aibridge/metrics/metrics.go @@ -5,7 +5,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) -var baseLabels []string = []string{"provider", "model"} +var baseLabels = []string{"provider", "model"} const ( InterceptionCountStatusFailed = "failed" diff --git a/aibridge/passthrough.go b/aibridge/passthrough.go index 284f59a902..2802c66e17 100644 --- a/aibridge/passthrough.go +++ b/aibridge/passthrough.go @@ -21,10 +21,10 @@ import ( // newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically // by a [intercept.Provider]. -func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc { +func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if m != nil { - m.PassthroughCount.WithLabelValues(provider.Name(), r.URL.Path, r.Method).Add(1) + m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1) } ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes( @@ -33,7 +33,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met )) defer span.End() - upURL, err := url.Parse(provider.BaseURL()) + upURL, err := url.Parse(prov.BaseURL()) if err != nil { logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err)) http.Error(w, "request error", http.StatusBadGateway) @@ -44,7 +44,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met // Append the request path to the upstream base path. reqPath, err := url.JoinPath(upURL.Path, r.URL.Path) if err != nil { - logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstreamPath", upURL.Path), slog.F("requestPath", r.URL.Path)) + logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstream_path", upURL.Path), slog.F("request_path", r.URL.Path)) http.Error(w, "failed to join upstream path", http.StatusInternalServerError) span.SetStatus(codes.Error, "failed to join upstream path: "+err.Error()) return @@ -96,7 +96,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met } // Inject provider auth. - provider.InjectAuthHeader(&req.Header) + prov.InjectAuthHeader(&req.Header) }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) { logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path)) @@ -113,7 +113,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } - proxy.Transport = apidump.NewPassthroughMiddleware(t, provider.APIDumpDir(), provider.Name(), logger, quartz.NewReal()) + proxy.Transport = apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()) proxy.ServeHTTP(w, r) } diff --git a/aibridge/passthrough_test.go b/aibridge/passthrough_test.go index 90c7f1155d..d4b664d503 100644 --- a/aibridge/passthrough_test.go +++ b/aibridge/passthrough_test.go @@ -1,4 +1,4 @@ -package aibridge +package aibridge //nolint:testpackage // tests unexported newPassthroughRouter import ( "net/http" diff --git a/aibridge/provider/anthropic.go b/aibridge/provider/anthropic.go index 603b416d70..03b214807d 100644 --- a/aibridge/provider/anthropic.go +++ b/aibridge/provider/anthropic.go @@ -73,7 +73,7 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi } } -func (p *Anthropic) Type() string { +func (*Anthropic) Type() string { return config.ProviderAnthropic } @@ -85,11 +85,11 @@ func (p *Anthropic) RoutePrefix() string { return fmt.Sprintf("/%s", p.Name()) } -func (p *Anthropic) BridgedRoutes() []string { +func (*Anthropic) BridgedRoutes() []string { return []string{routeMessages} } -func (p *Anthropic) PassthroughRoutes() []string { +func (*Anthropic) PassthroughRoutes() []string { return []string{ "/v1/models", "/v1/models/", // See https://pkg.go.dev/net/http#hdr-Trailing_slash_redirection-ServeMux. @@ -98,77 +98,76 @@ func (p *Anthropic) PassthroughRoutes() []string { } } -func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) { +func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) { id := uuid.New() _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") defer tracing.EndSpanErr(span, &outErr) path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix()) - switch path { - case routeMessages: - payload, err := io.ReadAll(r.Body) - if err != nil { - return nil, xerrors.Errorf("read body: %w", err) - } - - reqPayload, err := messages.NewMessagesRequestPayload(payload) - if err != nil { - return nil, xerrors.Errorf("unmarshal request body: %w", err) - } - - cfg := p.cfg - cfg.ExtraHeaders = extractAnthropicHeaders(r) - - // At this point the request contains only LLM provider headers. - // Any Coder-specific authentication has already been stripped. - // - // In centralized mode neither Authorization nor X-Api-Key is - // present, so cfg keeps the centralized key unchanged. - // - // In BYOK mode the user's LLM credentials survive intact. - // If X-Api-Key is present the user has a personal API key; - // overwrite the centralized key with it. If Authorization is - // present the user authenticated directly with provider; - // set BYOKBearerToken and clear the centralized key. - // When both are present, X-Api-Key takes priority to match - // claude-code behavior. - credKind := intercept.CredentialKindCentralized - credSecret := cfg.Key - authHeaderName := p.AuthHeader() - if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" { - cfg.Key = apiKey - authHeaderName = "X-Api-Key" - credKind = intercept.CredentialKindBYOK - credSecret = apiKey - } else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" { - cfg.BYOKBearerToken = token - cfg.Key = "" - authHeaderName = "Authorization" - credKind = intercept.CredentialKindBYOK - credSecret = token - } - - cred := intercept.NewCredentialInfo(credKind, credSecret) - - var interceptor intercept.Interceptor - if reqPayload.Stream() { - interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred) - } else { - interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred) - } - span.SetAttributes(interceptor.TraceAttributes(r)...) - return interceptor, nil + if path != routeMessages { + span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) + return nil, ErrUnknownRoute } - span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) - return nil, UnknownRoute + payload, err := io.ReadAll(r.Body) + if err != nil { + return nil, xerrors.Errorf("read body: %w", err) + } + + reqPayload, err := messages.NewRequestPayload(payload) + if err != nil { + return nil, xerrors.Errorf("unmarshal request body: %w", err) + } + + cfg := p.cfg + cfg.ExtraHeaders = extractAnthropicHeaders(r) + + // At this point the request contains only LLM provider headers. + // Any Coder-specific authentication has already been stripped. + // + // In centralized mode neither Authorization nor X-Api-Key is + // present, so cfg keeps the centralized key unchanged. + // + // In BYOK mode the user's LLM credentials survive intact. + // If X-Api-Key is present the user has a personal API key; + // overwrite the centralized key with it. If Authorization is + // present the user authenticated directly with provider; + // set BYOKBearerToken and clear the centralized key. + // When both are present, X-Api-Key takes priority to match + // claude-code behavior. + credKind := intercept.CredentialKindCentralized + credSecret := cfg.Key + authHeaderName := p.AuthHeader() + if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" { + cfg.Key = apiKey + authHeaderName = "X-Api-Key" + credKind = intercept.CredentialKindBYOK + credSecret = apiKey + } else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" { + cfg.BYOKBearerToken = token + cfg.Key = "" + authHeaderName = "Authorization" + credKind = intercept.CredentialKindBYOK + credSecret = token + } + + cred := intercept.NewCredentialInfo(credKind, credSecret) + + var interceptor intercept.Interceptor + if reqPayload.Stream() { + interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred) + } else { + interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred) + } + span.SetAttributes(interceptor.TraceAttributes(r)...) + return interceptor, nil } func (p *Anthropic) BaseURL() string { return p.cfg.BaseURL } -func (p *Anthropic) AuthHeader() string { +func (*Anthropic) AuthHeader() string { return "X-Api-Key" } diff --git a/aibridge/provider/anthropic_test.go b/aibridge/provider/anthropic_test.go index ca8decfb60..b84be29b1b 100644 --- a/aibridge/provider/anthropic_test.go +++ b/aibridge/provider/anthropic_test.go @@ -1,4 +1,4 @@ -package provider +package provider //nolint:testpackage // tests unexported internals import ( "bytes" @@ -146,7 +146,7 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { assert.Empty(t, receivedHeaders.Get("Authorization"), "client Authorization header must not reach upstream") }) - t.Run("UnknownRoute", func(t *testing.T) { + t.Run("ErrUnknownRoute", func(t *testing.T) { t.Parallel() body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}]}` @@ -155,7 +155,7 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { interceptor, err := provider.CreateInterceptor(w, req, testTracer) - require.ErrorIs(t, err, UnknownRoute) + require.ErrorIs(t, err, ErrUnknownRoute) require.Nil(t, interceptor) }) } diff --git a/aibridge/provider/copilot.go b/aibridge/provider/copilot.go index a311ece415..4943f9cb55 100644 --- a/aibridge/provider/copilot.go +++ b/aibridge/provider/copilot.go @@ -72,7 +72,7 @@ func NewCopilot(cfg config.Copilot) *Copilot { } } -func (p *Copilot) Type() string { +func (*Copilot) Type() string { return config.ProviderCopilot } @@ -88,14 +88,14 @@ func (p *Copilot) RoutePrefix() string { return fmt.Sprintf("/%s", p.Name()) } -func (p *Copilot) BridgedRoutes() []string { +func (*Copilot) BridgedRoutes() []string { return []string{ routeCopilotChatCompletions, routeCopilotResponses, } } -func (p *Copilot) PassthroughRoutes() []string { +func (*Copilot) PassthroughRoutes() []string { return []string{ "/models", "/models/", @@ -105,7 +105,7 @@ func (p *Copilot) PassthroughRoutes() []string { } } -func (p *Copilot) AuthHeader() string { +func (*Copilot) AuthHeader() string { return "Authorization" } @@ -113,7 +113,7 @@ func (p *Copilot) AuthHeader() string { // Copilot uses per-user tokens passed in the original Authorization header, // rather than a global key configured at the provider level. // The original Authorization header flows through untouched from the client. -func (p *Copilot) InjectAuthHeader(_ *http.Header) {} +func (*Copilot) InjectAuthHeader(_ *http.Header) {} func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker { return p.circuitBreaker @@ -170,7 +170,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac if err != nil { return nil, xerrors.Errorf("read body: %w", err) } - reqPayload, err := responses.NewResponsesRequestPayload(payload) + reqPayload, err := responses.NewRequestPayload(payload) if err != nil { return nil, xerrors.Errorf("unmarshal request body: %w", err) } @@ -183,7 +183,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac default: span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) - return nil, UnknownRoute + return nil, ErrUnknownRoute } span.SetAttributes(interceptor.TraceAttributes(r)...) diff --git a/aibridge/provider/copilot_test.go b/aibridge/provider/copilot_test.go index 936fd4217c..cd30a83350 100644 --- a/aibridge/provider/copilot_test.go +++ b/aibridge/provider/copilot_test.go @@ -1,4 +1,4 @@ -package provider +package provider //nolint:testpackage // tests unexported internals import ( "bytes" @@ -300,7 +300,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) { assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream") }) - t.Run("UnknownRoute", func(t *testing.T) { + t.Run("ErrUnknownRoute", func(t *testing.T) { t.Parallel() body := `{"model": "gpt-4.1", "messages": [{"role": "user", "content": "hello"}]}` @@ -310,7 +310,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) { interceptor, err := provider.CreateInterceptor(w, req, testTracer) - require.ErrorIs(t, err, UnknownRoute) + require.ErrorIs(t, err, ErrUnknownRoute) require.Nil(t, interceptor) }) } diff --git a/aibridge/provider/openai.go b/aibridge/provider/openai.go index f1136abf82..8b8527a64f 100644 --- a/aibridge/provider/openai.go +++ b/aibridge/provider/openai.go @@ -61,7 +61,7 @@ func NewOpenAI(cfg config.OpenAI) *OpenAI { } } -func (p *OpenAI) Type() string { +func (*OpenAI) Type() string { return config.ProviderOpenAI } @@ -75,7 +75,7 @@ func (p *OpenAI) RoutePrefix() string { return fmt.Sprintf("/%s/v1", p.Name()) } -func (p *OpenAI) BridgedRoutes() []string { +func (*OpenAI) BridgedRoutes() []string { return []string{ routeChatCompletions, routeResponses, @@ -86,7 +86,7 @@ func (p *OpenAI) BridgedRoutes() []string { // but must be passed through to the upstream. // The /v1/completions legacy API is deprecated and will not be passed through. // See https://platform.openai.com/docs/api-reference/completions. -func (p *OpenAI) PassthroughRoutes() []string { +func (*OpenAI) PassthroughRoutes() []string { return []string{ // See https://pkg.go.dev/net/http#hdr-Trailing_slash_redirection-ServeMux. // but without non trailing slash route requests to `/v1/conversations` are going to catch all @@ -98,7 +98,7 @@ func (p *OpenAI) PassthroughRoutes() []string { } } -func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) { +func (p *OpenAI) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) { id := uuid.New() _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") @@ -141,7 +141,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace if err != nil { return nil, xerrors.Errorf("read body: %w", err) } - reqPayload, err := responses.NewResponsesRequestPayload(payload) + reqPayload, err := responses.NewRequestPayload(payload) if err != nil { return nil, xerrors.Errorf("unmarshal request body: %w", err) } @@ -153,7 +153,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace default: span.SetStatus(codes.Error, "unknown route: "+r.URL.Path) - return nil, UnknownRoute + return nil, ErrUnknownRoute } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil @@ -163,7 +163,7 @@ func (p *OpenAI) BaseURL() string { return p.cfg.BaseURL } -func (p *OpenAI) AuthHeader() string { +func (*OpenAI) AuthHeader() string { return "Authorization" } diff --git a/aibridge/provider/openai_test.go b/aibridge/provider/openai_test.go index 96695a2c2f..d739a2dc20 100644 --- a/aibridge/provider/openai_test.go +++ b/aibridge/provider/openai_test.go @@ -1,4 +1,4 @@ -package provider +package provider //nolint:testpackage // tests unexported internals import ( "bytes" diff --git a/aibridge/provider/provider.go b/aibridge/provider/provider.go index a05fe6e24d..cd09b6fc31 100644 --- a/aibridge/provider/provider.go +++ b/aibridge/provider/provider.go @@ -10,7 +10,7 @@ import ( "github.com/coder/coder/v2/aibridge/intercept" ) -var UnknownRoute = xerrors.New("unknown route") +var ErrUnknownRoute = xerrors.New("unknown route") // Provider defines routes (bridged and passed through) for given provider. // Bridged routes are processed by dedicated interceptors. diff --git a/aibridge/recorder/recorder.go b/aibridge/recorder/recorder.go index 3e79bdadb8..26a9f24b5d 100644 --- a/aibridge/recorder/recorder.go +++ b/aibridge/recorder/recorder.go @@ -14,19 +14,19 @@ import ( ) var ( - _ Recorder = &RecorderWrapper{} + _ Recorder = &WrappedRecorder{} _ Recorder = &AsyncRecorder{} ) -// RecorderWrapper is a convenience struct which implements RecorderClient and resolves a client before calling each method. +// WrappedRecorder is a convenience struct which implements RecorderClient and resolves a client before calling each method. // It also sets the start/creation time of each record. -type RecorderWrapper struct { +type WrappedRecorder struct { logger slog.Logger tracer trace.Tracer clientFn func() (Recorder, error) } -func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) { +func (r *WrappedRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -44,7 +44,7 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept return err } -func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) { +func (r *WrappedRecorder) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -62,7 +62,7 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte return err } -func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) { +func (r *WrappedRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -80,7 +80,7 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag return err } -func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) { +func (r *WrappedRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -98,7 +98,7 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR return err } -func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) { +func (r *WrappedRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -116,7 +116,7 @@ func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRec return err } -func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) (outErr error) { +func (r *WrappedRecorder) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) (outErr error) { ctx, span := r.tracer.Start(ctx, "Intercept.RecordModelThought", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) @@ -134,8 +134,8 @@ func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThou return err } -func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *RecorderWrapper { - return &RecorderWrapper{ +func NewWrappedRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *WrappedRecorder { + return &WrappedRecorder{ logger: logger, tracer: tracer, clientFn: clientFn, @@ -185,7 +185,7 @@ func (a *AsyncRecorder) WithClient(client string) { // RecordInterception must NOT be called asynchronously. // If an interception cannot be recorded, the whole request should fail. -func (a *AsyncRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) error { +func (*AsyncRecorder) RecordInterception(context.Context, *InterceptionRecord) error { panic("RecordInterception must not be called asynchronously") } diff --git a/aibridge/session.go b/aibridge/session.go index bc7b80bbf8..34c45d2158 100644 --- a/aibridge/session.go +++ b/aibridge/session.go @@ -14,10 +14,10 @@ import ( var claudeCodePattern = regexp.MustCompile(`_session_(.+)$`) // Legacy format: save compilation on each call. -// guessSessionID attempts to retrieve a session ID which may have been sent by +// GuessSessionID attempts to retrieve a session ID which may have been sent by // the client. We only attempt to retrieve sessions using methods recognized for // the given client. -func guessSessionID(client Client, r *http.Request) *string { +func GuessSessionID(client Client, r *http.Request) *string { switch client { case ClientClaudeCode: // Prefer the dedicated header (added in Claude Code v2.1.86+). diff --git a/aibridge/session_test.go b/aibridge/session_test.go index a73c195b49..7592dc5c54 100644 --- a/aibridge/session_test.go +++ b/aibridge/session_test.go @@ -1,4 +1,4 @@ -package aibridge +package aibridge_test import ( "io" @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/coder/coder/v2/aibridge" "github.com/coder/coder/v2/aibridge/utils" ) @@ -16,7 +17,7 @@ func TestGuessSessionID(t *testing.T) { cases := []struct { name string - client Client + client aibridge.Client body string headers map[string]string sessionID *string @@ -24,177 +25,177 @@ func TestGuessSessionID(t *testing.T) { // Claude Code. { name: "claude_code_header_takes_precedence", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, headers: map[string]string{"X-Claude-Code-Session-Id": "header-session-id"}, body: `{"metadata":{"user_id":"user_abc123_account_456_session_body-session-id"}}`, sessionID: utils.PtrTo("header-session-id"), }, { name: "claude_code_header_only", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, headers: map[string]string{"X-Claude-Code-Session-Id": "aabb-ccdd"}, body: `{"model":"claude-3"}`, sessionID: utils.PtrTo("aabb-ccdd"), }, { name: "claude_code_empty_header_falls_back_to_body", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, headers: map[string]string{"X-Claude-Code-Session-Id": ""}, body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`, sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"), }, { name: "claude_code_whitespace_header_falls_back_to_body", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, headers: map[string]string{"X-Claude-Code-Session-Id": " "}, body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`, sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"), }, { name: "claude_code_with_valid_session", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`, sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"), }, { name: "claude_code_with_valid_session_new_format", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"{\"device_id\":\"45aa15c8c244ea2582f8144dde91a50ec3815851f6f648abef4ee15b173cc927\",\"account_uuid\":\"\",\"session_id\":\"54c1eb09-bc4c-4d2f-98eb-6d2ab2d5e2fe\"}"}}`, sessionID: utils.PtrTo("54c1eb09-bc4c-4d2f-98eb-6d2ab2d5e2fe"), }, { name: "claude_code_new_format_empty_session_id", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"\"}"}}`, }, { name: "claude_code_new_format_no_session_id_field", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\"}"}}`, }, { name: "claude_code_missing_metadata", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"model":"claude-3"}`, }, { name: "claude_code_missing_user_id", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{}}`, }, { name: "claude_code_user_id_without_session", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `{"metadata":{"user_id":"user_abc123_account_456"}}`, }, { name: "claude_code_empty_body", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: ``, }, { name: "claude_code_invalid_json", - client: ClientClaudeCode, + client: aibridge.ClientClaudeCode, body: `not json at all`, }, // Codex. { name: "codex_with_session_header", - client: ClientCodex, + client: aibridge.ClientCodex, headers: map[string]string{"session_id": "codex-session-123"}, sessionID: utils.PtrTo("codex-session-123"), }, { name: "codex_with_whitespace_in_header", - client: ClientCodex, + client: aibridge.ClientCodex, headers: map[string]string{"session_id": " codex-session-123 "}, sessionID: utils.PtrTo("codex-session-123"), }, { name: "codex_without_session_header", - client: ClientCodex, + client: aibridge.ClientCodex, }, // Other clients shouldn't use others' logic. { name: "unknown_client_returns_empty", - client: ClientUnknown, + client: aibridge.ClientUnknown, body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`, }, { name: "zed_returns_empty", - client: ClientZed, + client: aibridge.ClientZed, headers: map[string]string{"session_id": "zed-session"}, body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`, }, // Mux. { name: "mux_with_workspace_header", - client: ClientMux, + client: aibridge.ClientMux, headers: map[string]string{"X-Mux-Workspace-Id": "ws-abc-123"}, sessionID: utils.PtrTo("ws-abc-123"), }, { name: "mux_without_workspace_header", - client: ClientMux, + client: aibridge.ClientMux, }, // Copilot VS Code. { name: "copilot_vsc_with_interaction_id", - client: ClientCopilotVSC, + client: aibridge.ClientCopilotVSC, headers: map[string]string{"x-interaction-id": "interaction-xyz"}, sessionID: utils.PtrTo("interaction-xyz"), }, { name: "copilot_vsc_without_interaction_id", - client: ClientCopilotVSC, + client: aibridge.ClientCopilotVSC, }, // Copilot CLI. { name: "copilot_cli_with_session_header", - client: ClientCopilotCLI, + client: aibridge.ClientCopilotCLI, headers: map[string]string{"X-Client-Session-Id": "cli-sess-456"}, sessionID: utils.PtrTo("cli-sess-456"), }, { name: "copilot_cli_without_session_header", - client: ClientCopilotCLI, + client: aibridge.ClientCopilotCLI, }, // Kilo. { name: "kilo_with_task_id", - client: ClientKilo, + client: aibridge.ClientKilo, headers: map[string]string{"X-KILOCODE-TASKID": "task-789"}, sessionID: utils.PtrTo("task-789"), }, { name: "kilo_without_task_id", - client: ClientKilo, + client: aibridge.ClientKilo, }, // Coder Agents. { name: "coder_agents_with_chat_id", - client: ClientCoderAgents, + client: aibridge.ClientCoderAgents, headers: map[string]string{"X-Coder-Chat-Id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890"}, sessionID: utils.PtrTo("a1b2c3d4-e5f6-7890-abcd-ef1234567890"), }, { name: "coder_agents_without_chat_id", - client: ClientCoderAgents, + client: aibridge.ClientCoderAgents, }, // Roo. { name: "roo_returns_empty", - client: ClientRoo, + client: aibridge.ClientRoo, }, // Cursor. { name: "cursor_returns_empty", - client: ClientCursor, + client: aibridge.ClientCursor, }, // Other cases. { name: "empty session ID value", - client: ClientKilo, + client: aibridge.ClientKilo, headers: map[string]string{"X-KILOCODE-TASKID": " "}, sessionID: nil, }, @@ -205,14 +206,14 @@ func TestGuessSessionID(t *testing.T) { t.Parallel() body := tc.body - req, err := http.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(body)) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://localhost", strings.NewReader(body)) require.NoError(t, err) for key, value := range tc.headers { req.Header.Set(key, value) } - got := guessSessionID(tc.client, req) + got := aibridge.GuessSessionID(tc.client, req) require.Equal(t, tc.sessionID, got) // Verify the body was restored and can be read again. @@ -226,16 +227,16 @@ func TestGuessSessionID(t *testing.T) { func TestUnreadableBody(t *testing.T) { t.Parallel() - req, err := http.NewRequest(http.MethodPost, "http://localhost", &errReader{}) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://localhost", &errReader{}) require.NoError(t, err) - got := guessSessionID(ClientClaudeCode, req) + got := aibridge.GuessSessionID(aibridge.ClientClaudeCode, req) require.Nil(t, got) } // errReader is an io.Reader that always returns an error. type errReader struct{} -func (e *errReader) Read([]byte) (int, error) { +func (*errReader) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } diff --git a/aibridge/utils/concurrent_group_test.go b/aibridge/utils/concurrent_group_test.go index f6a67f7550..22b0cb93d7 100644 --- a/aibridge/utils/concurrent_group_test.go +++ b/aibridge/utils/concurrent_group_test.go @@ -18,11 +18,15 @@ func TestConcurrentGroup(t *testing.T) { t.Parallel() t.Run("no goroutines", func(t *testing.T) { + t.Parallel() + cg := utils.NewConcurrentGroup() require.NoError(t, cg.Wait()) }) t.Run("multiple goroutines, all ok", func(t *testing.T) { + t.Parallel() + cg := utils.NewConcurrentGroup() cg.Go(func() error { return nil @@ -34,6 +38,8 @@ func TestConcurrentGroup(t *testing.T) { }) t.Run("multiple goroutines, one err", func(t *testing.T) { + t.Parallel() + cg := utils.NewConcurrentGroup() oops := xerrors.New("oops") cg.Go(func() error { @@ -46,6 +52,8 @@ func TestConcurrentGroup(t *testing.T) { }) t.Run("multiple goroutines, multiple errs", func(t *testing.T) { + t.Parallel() + cg := utils.NewConcurrentGroup() oops := xerrors.New("oops") eek := xerrors.New("eek")