update aibridge fies with lint changes

This commit is contained in:
Paweł Banaszewski
2026-04-10 22:11:26 +00:00
parent 8b69f05cc0
commit e3ca654ae4
70 changed files with 1025 additions and 879 deletions
+1 -1
View File
@@ -62,5 +62,5 @@ func NewMetrics(reg prometheus.Registerer) *metrics.Metrics {
}
func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) Recorder {
return recorder.NewRecorder(logger, tracer, clientFn)
return recorder.NewWrappedRecorder(logger, tracer, clientFn)
}
+3 -3
View File
@@ -180,8 +180,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
// We execute this before CreateInterceptor since the interceptors
// read the request body and don't reset them.
client := guessClient(r)
sessionID := guessSessionID(client, r)
client := GuessClient(r)
sessionID := GuessSessionID(client, r)
interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer)
if err != nil {
@@ -276,7 +276,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
log.Debug(ctx, "interception ended")
}
asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()})
_ = asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()})
// Ensure all recording have completed before completing request.
asyncRecorder.Wait()
+38 -57
View File
@@ -1,4 +1,4 @@
package aibridge
package aibridge_test
import (
"net/http"
@@ -7,16 +7,22 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/provider"
)
func TestValidateProvider_Names(t *testing.T) {
var bridgeTestTracer = otel.Tracer("bridge_test")
func TestValidateProviders(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
tests := []struct {
name string
providers []provider.Provider
@@ -25,94 +31,69 @@ func TestValidateProvider_Names(t *testing.T) {
{
name: "all_supported_providers",
providers: []provider.Provider{
NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
aibridge.NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
aibridge.NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
},
},
{
name: "default_names_and_base_urls",
providers: []provider.Provider{
NewOpenAIProvider(config.OpenAI{}),
NewAnthropicProvider(config.Anthropic{}, nil),
NewCopilotProvider(config.Copilot{}),
aibridge.NewOpenAIProvider(config.OpenAI{}),
aibridge.NewAnthropicProvider(config.Anthropic{}, nil),
aibridge.NewCopilotProvider(config.Copilot{}),
},
},
{
name: "multiple_copilot_instances",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
},
},
{
name: "name_with_slashes",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
{
name: "name_with_spaces",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
{
name: "name_with_uppercase",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := validateProviders(tc.providers)
if tc.expectErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectErr)
} else {
require.NoError(t, err)
}
})
}
}
func TestValidateProvider_DuplicateNames(t *testing.T) {
t.Parallel()
tests := []struct {
name string
providers []provider.Provider
expectErr string
}{
{
name: "unique_names",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
},
},
{
name: "duplicate_base_url_different_names",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
},
},
{
name: "duplicate_name",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "duplicate provider name",
},
@@ -122,7 +103,7 @@ func TestValidateProvider_DuplicateNames(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := validateProviders(tc.providers)
_, err := aibridge.NewRequestBridge(t.Context(), tc.providers, nil, nil, logger, nil, bridgeTestTracer)
if tc.expectErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectErr)
@@ -148,7 +129,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
name: "openAI_no_base_path",
requestPath: "/openai/v1/conversations",
provider: func(baseURL string) provider.Provider {
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
},
expectPath: "/conversations",
},
@@ -157,7 +138,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
baseURLPath: "/v1",
requestPath: "/openai/v1/conversations",
provider: func(baseURL string) provider.Provider {
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
},
expectPath: "/v1/conversations",
},
@@ -165,7 +146,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
name: "anthropic_no_base_path",
requestPath: "/anthropic/v1/models",
provider: func(baseURL string) provider.Provider {
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
},
expectPath: "/v1/models",
},
@@ -174,7 +155,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
baseURLPath: "/v1",
requestPath: "/anthropic/v1/models",
provider: func(baseURL string) provider.Provider {
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
},
expectPath: "/v1/v1/models",
},
@@ -182,7 +163,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
name: "copilot_no_base_path",
requestPath: "/copilot/models",
provider: func(baseURL string) provider.Provider {
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL})
},
expectPath: "/models",
},
@@ -191,7 +172,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
baseURLPath: "/v1",
requestPath: "/copilot/models",
provider: func(baseURL string) provider.Provider {
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL})
},
expectPath: "/v1/models",
},
@@ -210,14 +191,14 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
}))
t.Cleanup(upstream.Close)
recorder := testutil.MockRecorder{}
rec := testutil.MockRecorder{}
prov := tc.provider(upstream.URL + tc.baseURLPath)
bridge, err := NewRequestBridge(t.Context(), []provider.Provider{prov}, &recorder, nil, logger, nil, testTracer)
bridge, err := aibridge.NewRequestBridge(t.Context(), []provider.Provider{prov}, &rec, nil, logger, nil, bridgeTestTracer)
require.NoError(t, err)
req := httptest.NewRequest("", tc.requestPath, nil)
resp := httptest.NewRecorder()
bridge.mux.ServeHTTP(resp, req)
bridge.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Contains(t, resp.Body.String(), upstreamRespBody)
+18 -7
View File
@@ -1,8 +1,10 @@
package circuitbreaker
import (
"bufio"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
@@ -65,8 +67,8 @@ func (p *ProviderCircuitBreakers) isFailure(statusCode int) bool {
return DefaultIsFailure(statusCode)
}
// openErrorResponse returns the error response body when the circuit is open.
func (p *ProviderCircuitBreakers) openErrorResponse() []byte {
// openErrBody returns the error response body when the circuit is open.
func (p *ProviderCircuitBreakers) openErrBody() []byte {
if p.config.OpenErrorResponse != nil {
return p.config.OpenErrorResponse()
}
@@ -77,7 +79,7 @@ func (p *ProviderCircuitBreakers) openErrorResponse() []byte {
func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.CircuitBreaker[struct{}] {
key := endpoint + ":" + model
if v, ok := p.breakers.Load(key); ok {
return v.(*gobreaker.CircuitBreaker[struct{}])
return v.(*gobreaker.CircuitBreaker[struct{}]) //nolint:forcetypeassert // sync.Map always stores this type
}
settings := gobreaker.Settings{
@@ -97,11 +99,12 @@ func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.Circuit
cb := gobreaker.NewCircuitBreaker[struct{}](settings)
actual, _ := p.breakers.LoadOrStore(key, cb)
return actual.(*gobreaker.CircuitBreaker[struct{}])
return actual.(*gobreaker.CircuitBreaker[struct{}]) //nolint:forcetypeassert // sync.Map always stores this type
}
// statusCapturingWriter wraps http.ResponseWriter to capture the status code.
// It also implements http.Flusher to support streaming responses.
// It implements http.Flusher to support streaming and http.Hijacker to
// satisfy the FullResponseWriter lint rule.
type statusCapturingWriter struct {
http.ResponseWriter
statusCode int
@@ -130,6 +133,14 @@ func (w *statusCapturingWriter) Flush() {
}
}
func (w *statusCapturingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h, ok := w.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, xerrors.New("upstream ResponseWriter does not support hijacking")
}
return h.Hijack()
}
// Unwrap returns the underlying ResponseWriter for interface checks.
func (w *statusCapturingWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
@@ -167,7 +178,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds())))
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write(p.openErrorResponse())
_, _ = w.Write(p.openErrBody())
return ErrCircuitOpen
}
@@ -187,7 +198,7 @@ func (p *ProviderCircuitBreakers) Provider() string {
// OpenErrorResponse returns the error response body when the circuit is open.
// This is exposed for handlers to use when responding to rejected requests.
func (p *ProviderCircuitBreakers) OpenErrorResponse() []byte {
return p.openErrorResponse()
return p.openErrBody()
}
// StateToGaugeValue converts gobreaker.State to a gauge value.
+15 -13
View File
@@ -1,4 +1,4 @@
package circuitbreaker
package circuitbreaker_test
import (
"errors"
@@ -11,6 +11,7 @@ import (
"github.com/sony/gobreaker/v2"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/aibridge/circuitbreaker"
"github.com/coder/coder/v2/aibridge/config"
)
@@ -20,7 +21,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {
sonnetCalls := atomic.Int32{}
haikuCalls := atomic.Int32{}
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
@@ -48,7 +49,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {
rw.WriteHeader(http.StatusOK)
return nil
})
assert.True(t, errors.Is(err, ErrCircuitOpen))
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
@@ -69,7 +70,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
messagesCalls := atomic.Int32{}
completionsCalls := atomic.Int32{}
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
@@ -95,7 +96,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
rw.WriteHeader(http.StatusOK)
return nil
})
assert.True(t, errors.Is(err, ErrCircuitOpen))
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
assert.Equal(t, int32(1), messagesCalls.Load()) // No new call
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
@@ -116,7 +117,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {
var calls atomic.Int32
// Custom IsFailure that treats 502 as failure
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
@@ -143,7 +144,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {
rw.WriteHeader(http.StatusOK)
return nil
})
assert.True(t, errors.Is(err, ErrCircuitOpen))
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
assert.Equal(t, int32(1), calls.Load()) // No new call
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
}
@@ -158,7 +159,7 @@ func TestExecute_OnStateChange(t *testing.T) {
to gobreaker.State
}
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
@@ -177,10 +178,11 @@ func TestExecute_OnStateChange(t *testing.T) {
// Trip circuit
w := httptest.NewRecorder()
cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error {
err := cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error {
rw.WriteHeader(http.StatusTooManyRequests)
return nil
})
assert.NoError(t, err)
// Verify state change callback was called with correct parameters
assert.Len(t, stateChanges, 1)
@@ -208,14 +210,14 @@ func TestDefaultIsFailure(t *testing.T) {
}
for _, tt := range tests {
assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
assert.Equal(t, tt.isFailure, circuitbreaker.DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
}
}
func TestStateToGaugeValue(t *testing.T) {
t.Parallel()
assert.Equal(t, float64(0), StateToGaugeValue(gobreaker.StateClosed))
assert.Equal(t, float64(0.5), StateToGaugeValue(gobreaker.StateHalfOpen))
assert.Equal(t, float64(1), StateToGaugeValue(gobreaker.StateOpen))
assert.Equal(t, float64(0), circuitbreaker.StateToGaugeValue(gobreaker.StateClosed))
assert.Equal(t, float64(0.5), circuitbreaker.StateToGaugeValue(gobreaker.StateHalfOpen))
assert.Equal(t, float64(1), circuitbreaker.StateToGaugeValue(gobreaker.StateOpen))
}
+2 -2
View File
@@ -24,10 +24,10 @@ const (
ClientUnknown Client = "Unknown"
)
// guessClient attempts to guess the client application from the request headers.
// GuessClient attempts to guess the client application from the request headers.
// Not all clients set proper user agent headers, so this is a best-effort approach.
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
func guessClient(r *http.Request) Client {
func GuessClient(r *http.Request) Client {
userAgent := strings.ToLower(r.UserAgent())
originator := r.Header.Get("originator")
+23 -21
View File
@@ -1,10 +1,12 @@
package aibridge
package aibridge_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/aibridge"
)
func TestGuessClient(t *testing.T) {
@@ -14,93 +16,93 @@ func TestGuessClient(t *testing.T) {
name string
userAgent string
headers map[string]string
wantClient Client
wantClient aibridge.Client
}{
{
name: "mux",
userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22",
wantClient: ClientMux,
wantClient: aibridge.ClientMux,
},
{
name: "claude_code",
userAgent: "claude-cli/2.0.67 (external, cli)",
wantClient: ClientClaudeCode,
wantClient: aibridge.ClientClaudeCode,
},
{
name: "codex_cli",
userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef",
wantClient: ClientCodex,
wantClient: aibridge.ClientCodex,
},
{
name: "zed",
userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)",
wantClient: ClientZed,
wantClient: aibridge.ClientZed,
},
{
name: "github_copilot_vsc",
userAgent: "GitHubCopilotChat/0.37.2026011603",
wantClient: ClientCopilotVSC,
wantClient: aibridge.ClientCopilotVSC,
},
{
name: "github_copilot_cli",
userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)",
wantClient: ClientCopilotCLI,
wantClient: aibridge.ClientCopilotCLI,
},
{
name: "kilo_code_user_agent",
userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1",
wantClient: ClientKilo,
wantClient: aibridge.ClientKilo,
},
{
name: "kilo_code_originator",
headers: map[string]string{"Originator": "kilo-code"},
wantClient: ClientKilo,
wantClient: aibridge.ClientKilo,
},
{
name: "roo_code_user_agent",
userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1",
wantClient: ClientRoo,
wantClient: aibridge.ClientRoo,
},
{
name: "roo_code_originator",
headers: map[string]string{"Originator": "roo-code"},
wantClient: ClientRoo,
wantClient: aibridge.ClientRoo,
},
{
name: "coder_agents",
userAgent: "coder-agents/v2.24.0 (linux/amd64)",
wantClient: ClientCoderAgents,
wantClient: aibridge.ClientCoderAgents,
},
{
name: "coder_agents_dev",
userAgent: "coder-agents/v0.0.0-devel (darwin/arm64)",
wantClient: ClientCoderAgents,
wantClient: aibridge.ClientCoderAgents,
},
{
name: "charm_crush",
userAgent: "Charm Crush/0.1.11",
wantClient: ClientCrush,
wantClient: aibridge.ClientCrush,
},
{
name: "cursor_x_cursor_client_version",
userAgent: "connect-es/1.6.1",
headers: map[string]string{"X-Cursor-client-version": "0.50.0"},
wantClient: ClientCursor,
wantClient: aibridge.ClientCursor,
},
{
name: "cursor_x_cursor_some_other_header",
headers: map[string]string{"x-cursor-client-version": "abc123"},
wantClient: ClientCursor,
wantClient: aibridge.ClientCursor,
},
{
name: "unknown_client",
userAgent: "ccclaude-cli/calude-with-wrong-prefix",
wantClient: ClientUnknown,
wantClient: aibridge.ClientUnknown,
},
{
name: "empty_user_agent",
userAgent: "",
wantClient: ClientUnknown,
wantClient: aibridge.ClientUnknown,
},
}
@@ -108,7 +110,7 @@ func TestGuessClient(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, "", nil)
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "", nil)
require.NoError(t, err)
req.Header.Set("User-Agent", tt.userAgent)
@@ -116,7 +118,7 @@ func TestGuessClient(t *testing.T) {
req.Header.Set(key, value)
}
got := guessClient(req)
got := aibridge.GuessClient(req)
require.Equal(t, tt.wantClient, got)
})
}
+10 -9
View File
@@ -1,4 +1,4 @@
package context
package context_test
import (
"context"
@@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
aibcontext "github.com/coder/coder/v2/aibridge/context"
"github.com/coder/coder/v2/aibridge/recorder"
)
@@ -17,10 +18,10 @@ func TestAsActor(t *testing.T) {
metadata := recorder.Metadata{"key": "value"}
// When: storing an actor in the context
ctx := AsActor(context.Background(), "actor-123", metadata)
ctx := aibcontext.AsActor(context.Background(), "actor-123", metadata)
// Then: the actor should be retrievable with correct ID and metadata
actor := ActorFromContext(ctx)
actor := aibcontext.ActorFromContext(ctx)
require.NotNil(t, actor)
assert.Equal(t, "actor-123", actor.ID)
assert.Equal(t, "value", actor.Metadata["key"])
@@ -33,10 +34,10 @@ func TestActorFromContext(t *testing.T) {
t.Parallel()
// Given: a context with an actor
ctx := AsActor(context.Background(), "test-id", recorder.Metadata{})
ctx := aibcontext.AsActor(context.Background(), "test-id", recorder.Metadata{})
// When: extracting the actor from context
actor := ActorFromContext(ctx)
actor := aibcontext.ActorFromContext(ctx)
// Then: the actor should be returned with correct ID
require.NotNil(t, actor)
@@ -50,7 +51,7 @@ func TestActorFromContext(t *testing.T) {
ctx := context.Background()
// When: extracting the actor from context
actor := ActorFromContext(ctx)
actor := aibcontext.ActorFromContext(ctx)
// Then: nil should be returned
assert.Nil(t, actor)
@@ -64,10 +65,10 @@ func TestActorIDFromContext(t *testing.T) {
t.Parallel()
// Given: a context with an actor
ctx := AsActor(context.Background(), "test-actor-id", recorder.Metadata{})
ctx := aibcontext.AsActor(context.Background(), "test-actor-id", recorder.Metadata{})
// When: extracting the actor ID from context
got := ActorIDFromContext(ctx)
got := aibcontext.ActorIDFromContext(ctx)
// Then: the actor ID should be returned
assert.Equal(t, "test-actor-id", got)
@@ -80,7 +81,7 @@ func TestActorIDFromContext(t *testing.T) {
ctx := context.Background()
// When: extracting the actor ID from context
got := ActorIDFromContext(ctx)
got := aibcontext.ActorIDFromContext(ctx)
// Then: an empty string should be returned
assert.Empty(t, got)
-25
View File
@@ -1,25 +0,0 @@
These fixtures were created by adding logging middleware to API calls to view the raw requests/responses.
```go
...
opts = append(opts, option.WithMiddleware(LoggingMiddleware))
...
func LoggingMiddleware(req *http.Request, next option.MiddlewareNext) (res *http.Response, err error) {
reqOut, _ := httputil.DumpRequest(req, true)
// Forward the request to the next handler
res, err = next(req)
fmt.Printf("[req] %s\n", reqOut)
// Handle stuff after the request
if err != nil {
return res, err
}
respOut, _ := httputil.DumpResponse(res, true)
fmt.Printf("[resp] %s\n", respOut)
return res, err
}
```
+2 -2
View File
@@ -92,7 +92,7 @@ var (
OaiResponsesBlockingConversation []byte
//go:embed openai/responses/blocking/http_error.txtar
OaiResponsesBlockingHttpErr []byte
OaiResponsesBlockingHTTPErr []byte
//go:embed openai/responses/blocking/prev_response_id.txtar
OaiResponsesBlockingPrevResponseID []byte
@@ -139,7 +139,7 @@ var (
OaiResponsesStreamingConversation []byte
//go:embed openai/responses/streaming/http_error.txtar
OaiResponsesStreamingHttpErr []byte
OaiResponsesStreamingHTTPErr []byte
//go:embed openai/responses/streaming/prev_response_id.txtar
OaiResponsesStreamingPrevResponseID []byte
+8 -7
View File
@@ -1,4 +1,4 @@
package intercept
package intercept_test
import (
"testing"
@@ -7,14 +7,15 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/aibridge/context"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/recorder"
)
func TestNilActor(t *testing.T) {
t.Parallel()
require.Nil(t, ActorHeadersAsOpenAIOpts(nil))
require.Nil(t, ActorHeadersAsAnthropicOpts(nil))
require.Nil(t, intercept.ActorHeadersAsOpenAIOpts(nil))
require.Nil(t, intercept.ActorHeadersAsAnthropicOpts(nil))
}
func TestBasic(t *testing.T) {
@@ -28,9 +29,9 @@ func TestBasic(t *testing.T) {
// We can't peek inside since these opts require an internal type to apply onto.
// All we can do is check the length.
// See TestActorHeaders for an integration test.
oaiOpts := ActorHeadersAsOpenAIOpts(actor)
oaiOpts := intercept.ActorHeadersAsOpenAIOpts(actor)
require.Len(t, oaiOpts, 1)
antOpts := ActorHeadersAsAnthropicOpts(actor)
antOpts := intercept.ActorHeadersAsAnthropicOpts(actor)
require.Len(t, antOpts, 1)
}
@@ -49,8 +50,8 @@ func TestBasicAndMetadata(t *testing.T) {
// We can't peek inside since these opts require an internal type to apply onto.
// All we can do is check the length.
// See TestActorHeaders for an integration test.
oaiOpts := ActorHeadersAsOpenAIOpts(actor)
oaiOpts := intercept.ActorHeadersAsOpenAIOpts(actor)
require.Len(t, oaiOpts, 1+len(actor.Metadata))
antOpts := ActorHeadersAsAnthropicOpts(actor)
antOpts := intercept.ActorHeadersAsAnthropicOpts(actor)
require.Len(t, antOpts, 1+len(actor.Metadata))
}
+17 -16
View File
@@ -105,10 +105,11 @@ func (d *dumper) dumpRequest(req *http.Request) error {
if err != nil {
return xerrors.Errorf("write request header terminator: %w", err)
}
buf.Write(prettyBody)
buf.WriteByte('\n')
// bytes.Buffer writes to in-memory storage and never return errors.
_, _ = buf.Write(prettyBody)
_ = buf.WriteByte('\n')
return os.WriteFile(dumpPath, buf.Bytes(), 0o644)
return os.WriteFile(dumpPath, buf.Bytes(), 0o600)
}
func (d *dumper) dumpResponse(resp *http.Response) error {
@@ -129,19 +130,19 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
return xerrors.Errorf("write response header terminator: %w", err)
}
// Wrap the response body to capture it as it streams
if resp.Body != nil {
resp.Body = &streamingBodyDumper{
body: resp.Body,
dumpPath: dumpPath,
headerData: headerBuf.Bytes(),
logger: func(err error) {
d.logger.Named("apidump").Warn(context.Background(), "failed to initialize response dump", slog.Error(err))
},
}
} else {
if resp.Body == nil {
// No body, just write headers
return os.WriteFile(dumpPath, headerBuf.Bytes(), 0o644)
return os.WriteFile(dumpPath, headerBuf.Bytes(), 0o600)
}
// Wrap the response body to capture it as it streams
resp.Body = &streamingBodyDumper{
body: resp.Body,
dumpPath: dumpPath,
headerData: headerBuf.Bytes(),
logger: func(err error) {
d.logger.Named("apidump").Warn(context.Background(), "failed to initialize response dump", slog.Error(err))
},
}
return nil
@@ -152,7 +153,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
// for deterministic output.
// `sensitive` and `overrides` must both supply keys in canoncialized form.
// See [textproto.MIMEHeader].
func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error {
func (*dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error {
// Collect all header keys including overrides.
headerKeys := make([]string, 0, len(headers)+len(overrides))
seen := make(map[string]struct{}, len(headers)+len(overrides))
+18 -14
View File
@@ -1,4 +1,4 @@
package apidump
package apidump //nolint:testpackage // tests unexported internals
import (
"bytes"
@@ -39,7 +39,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`)))
require.NoError(t, err)
// Add sensitive headers that should be redacted
@@ -52,7 +52,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) {
req.Header.Set("User-Agent", "test-client")
// Call middleware with a mock next function
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
resp, err := middleware(req, func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
@@ -62,6 +62,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp.Body.Close()
// Read the request dump file
modelDir := filepath.Join(tmpDir, "openai", "gpt-4")
@@ -96,7 +97,7 @@ func TestBridgedMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)
// Call middleware with a response containing sensitive headers
@@ -166,11 +167,11 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) {
require.NotNil(t, middleware)
originalBody := `{"messages": [{"role": "user", "content": "hello"}]}`
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody)))
require.NoError(t, err)
var capturedBody []byte
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
resp2, err := middleware(req, func(r *http.Request) (*http.Response, error) {
// Read the body in the next handler to verify it's still available
capturedBody, _ = io.ReadAll(r.Body)
return &http.Response{
@@ -182,6 +183,7 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp2.Body.Close()
// Verify the body was preserved for the next handler
require.Equal(t, originalBody, string(capturedBody))
@@ -199,10 +201,10 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk)
require.NotNil(t, middleware)
req, err := http.NewRequest(http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
resp3, err := middleware(req, func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
@@ -212,6 +214,7 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp3.Body.Close()
// Verify files are created with sanitized model name
modelDir := filepath.Join(tmpDir, "google", "gemini-1.5-pro")
@@ -278,7 +281,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)
// Set all sensitive headers
@@ -290,7 +293,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) {
req.Header.Set("Proxy-Authorization", "Basic proxy-creds")
req.Header.Set("X-Amz-Security-Token", "aws-security-token")
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
resp4, err := middleware(req, func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
@@ -300,6 +303,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp4.Body.Close()
modelDir := filepath.Join(tmpDir, "openai", "gpt-4")
reqDumpPath := findDumpFile(t, modelDir, SuffixRequest)
@@ -355,10 +359,10 @@ func TestPassthroughMiddleware(t *testing.T) {
rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk)
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil)
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "https://api.openai.com/v1/models", nil)
require.NoError(t, err)
resp, err := rt.RoundTrip(req)
resp, err := rt.RoundTrip(req) //nolint:bodyclose // resp is nil on error
require.ErrorIs(t, err, innerErr)
require.Nil(t, resp)
})
@@ -399,7 +403,7 @@ func TestPassthroughMiddleware(t *testing.T) {
rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk)
req, err := http.NewRequest(http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body)))
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer sk-secret-key-12345")
resp, err := rt.RoundTrip(req)
@@ -409,7 +413,7 @@ func TestPassthroughMiddleware(t *testing.T) {
require.NoError(t, resp.Body.Close())
// Second request should create new req/resp files
req2, err := http.NewRequest(http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body)))
req2, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body)))
require.NoError(t, err)
resp2, err := rt.RoundTrip(req2)
require.NoError(t, err)
+1 -1
View File
@@ -1,4 +1,4 @@
package apidump
package apidump //nolint:testpackage // tests unexported internals
import (
"bytes"
+1 -1
View File
@@ -37,7 +37,7 @@ func (s *streamingBodyDumper) init() {
// Write headers first.
if _, err := s.file.Write(s.headerData); err != nil {
s.initErr = xerrors.Errorf("write headers: %w", err)
s.file.Close()
_ = s.file.Close() // best-effort cleanup on header write failure
s.file = nil
}
})
+9 -6
View File
@@ -1,4 +1,4 @@
package apidump
package apidump //nolint:testpackage // shares test helpers with apidump_test.go
import (
"bytes"
@@ -28,7 +28,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)
// Simulate a streaming response with multiple chunks
@@ -42,10 +42,12 @@ func TestMiddleware_StreamingResponse(t *testing.T) {
// Create a pipe to simulate streaming
pr, pw := io.Pipe()
go func() {
defer pw.Close() //nolint:revive // error handled via pipe read side
for _, chunk := range chunks {
pw.Write([]byte(chunk))
if _, err := pw.Write([]byte(chunk)); err != nil {
return
}
}
pw.Close()
}()
resp, err := middleware(req, func(r *http.Request) (*http.Response, error) {
@@ -65,7 +67,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) {
for {
n, err := resp.Body.Read(buf)
if n > 0 {
receivedData.Write(buf[:n])
_, _ = receivedData.Write(buf[:n]) // bytes.Buffer.Write never fails
}
if err == io.EOF {
break
@@ -104,7 +106,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)
originalRespBody := `{"choices": [{"message": {"content": "hi"}}]}`
@@ -118,6 +120,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) {
}, nil
})
require.NoError(t, err)
defer resp.Body.Close()
// Verify the response body is still readable after middleware
capturedBody, err := io.ReadAll(resp.Body)
+17 -17
View File
@@ -79,9 +79,9 @@ func (i *interceptionBase) Credential() intercept.CredentialInfo {
return i.credential
}
func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
func (i *interceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger
i.recorder = recorder
i.recorder = rec
i.mcpProxy = mcpProxy
}
@@ -98,13 +98,13 @@ func (i *interceptionBase) CorrelatingToolCallID() *string {
return &msg.OfTool.ToolCallID
}
func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InterceptionID, i.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, s.providerName),
attribute.String(tracing.Model, s.Model()),
attribute.String(tracing.Provider, i.providerName),
attribute.String(tracing.Model, i.Model()),
attribute.Bool(tracing.Streaming, streaming),
}
}
@@ -114,10 +114,10 @@ func (i *interceptionBase) Model() string {
return "coder-aibridge-unknown"
}
return string(i.req.Model)
return i.req.Model
}
func (i *interceptionBase) newErrorResponse(err error) map[string]any {
func (*interceptionBase) newErrorResponse(err error) map[string]any {
return map[string]any{
"error": true,
"message": err.Error(),
@@ -172,7 +172,7 @@ func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) {
}
// writeUpstreamError marshals and writes a given error.
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *errorResponse) {
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *chatCompletionResponseError) {
if oaiErr == nil {
return
}
@@ -182,7 +182,7 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *err
out, err := json.Marshal(oaiErr)
if err != nil {
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", oaiErr)))
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", oaiErr))
// Response has to match expected format.
_, _ = w.Write([]byte(`{
"error": {
@@ -227,13 +227,13 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */
}
func getErrorResponse(err error) *errorResponse {
func getErrorResponse(err error) *chatCompletionResponseError {
var apiErr *openai.Error
if !errors.As(err, &apiErr) {
return nil
}
return &errorResponse{
return &chatCompletionResponseError{
ErrorObject: &shared.ErrorObject{
Code: apiErr.Code,
Message: apiErr.Message,
@@ -243,15 +243,15 @@ func getErrorResponse(err error) *errorResponse {
}
}
var _ error = &errorResponse{}
var _ error = &chatCompletionResponseError{}
type errorResponse struct {
type chatCompletionResponseError struct {
ErrorObject *shared.ErrorObject `json:"error"`
StatusCode int `json:"-"`
}
func newErrorResponse(msg error) *errorResponse {
return &errorResponse{
func newErrorResponse(msg error) *chatCompletionResponseError {
return &chatCompletionResponseError{
ErrorObject: &shared.ErrorObject{
Code: "error",
Message: msg.Error(),
@@ -260,7 +260,7 @@ func newErrorResponse(msg error) *errorResponse {
}
}
func (a *errorResponse) Error() string {
func (a *chatCompletionResponseError) Error() string {
if a.ErrorObject == nil {
return ""
}
@@ -1,4 +1,4 @@
package chatcompletions
package chatcompletions //nolint:testpackage // tests unexported internals
import (
"testing"
@@ -50,16 +50,16 @@ func NewBlockingInterceptor(
}}
}
func (s *BlockingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
s.interceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
func (i *BlockingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.interceptionBase.Setup(logger.Named("blocking"), rec, mcpProxy)
}
func (s *BlockingInterception) Streaming() bool {
func (*BlockingInterception) Streaming() bool {
return false
}
func (s *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return s.interceptionBase.baseTraceAttributes(r, false)
func (i *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return i.interceptionBase.baseTraceAttributes(r, false)
}
func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) {
@@ -52,7 +52,7 @@ func (c *ChatCompletionNewParamsWrapper) lastUserPrompt() (*string, error) {
// We only care if the last message was issued by a user.
msg := c.Messages[len(c.Messages)-1]
if msg.OfUser == nil {
return nil, nil
return nil, nil //nolint:nilnil // no user prompt found is not an error
}
if msg.OfUser.Content.OfString.String() != "" {
@@ -69,5 +69,5 @@ func (c *ChatCompletionNewParamsWrapper) lastUserPrompt() (*string, error) {
}
}
return nil, nil
return nil, nil //nolint:nilnil // no text content found is not an error
}
@@ -1,4 +1,4 @@
package chatcompletions
package chatcompletions //nolint:testpackage // tests unexported internals
import (
"fmt"
@@ -114,6 +114,8 @@ func TestOpenAILastUserPrompt(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := tt.wrapper.lastUserPrompt()
if tt.expectError {
@@ -144,7 +146,7 @@ func generatePayload(messageCount int) []byte {
}
// Use realistic message content size
content := fmt.Sprintf("This is message number %d with some realistic content that might appear in a conversation.", i+1)
messages = append(messages, fmt.Sprintf(`{"role": "%s", "content": "%s"}`, role, content))
messages = append(messages, fmt.Sprintf(`{"role": %q, "content": %q}`, role, content))
}
return []byte(fmt.Sprintf(`{
+44 -45
View File
@@ -54,16 +54,16 @@ func NewStreamingInterceptor(
}}
}
func (i *StreamingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.interceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
func (i *StreamingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.interceptionBase.Setup(logger.Named("streaming"), rec, mcpProxy)
}
func (i *StreamingInterception) Streaming() bool {
func (*StreamingInterception) Streaming() bool {
return true
}
func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return s.interceptionBase.baseTraceAttributes(r, true)
func (i *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return i.interceptionBase.baseTraceAttributes(r, true)
}
// ProcessRequest handles a request to /v1/chat/completions.
@@ -189,16 +189,14 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
})
toolCall = nil
} else {
} else if stream.Err() == nil {
// When the provider responds with only tool calls (no text content),
// no chunks are relayed to the client, so the stream is not yet
// initiated. Initiate it here so the SSE headers are sent and the
// ping ticker is started, preventing client timeout during tool invocation.
// Only initiate if no stream error, if there's an error, we'll return
// an HTTP error response instead of starting an SSE stream.
if stream.Err() == nil {
events.InitiateStream(w)
}
events.InitiateStream(w)
}
}
@@ -231,43 +229,43 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
})
}
if events.IsStreaming() {
// Check if the stream encountered any errors.
if streamErr := stream.Err(); streamErr != nil {
if eventstream.IsUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
// We can't reflect an error back if there's a connection error or the request context was canceled.
} else if oaiErr := getErrorResponse(streamErr); oaiErr != nil {
logger.Warn(ctx, "openai stream error", slog.Error(streamErr))
interceptionErr = oaiErr
} else {
logger.Warn(ctx, "unknown error", slog.Error(streamErr))
// Unfortunately, the OpenAI SDK does not support parsing errors received in the stream
// into known types (i.e. [shared.OverloadedError]).
// See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171
// All it does is wrap the payload in an error - which is all we can return, currently.
interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr))
}
} else if lastErr != nil {
// Otherwise check if any logical errors occurred during processing.
logger.Warn(ctx, "stream failed", slog.Error(lastErr))
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
}
if interceptionErr != nil {
payload, err := i.marshalErr(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
}
}
} else {
if !events.IsStreaming() {
// response/downstream Stream has not started yet; write error response and exit.
i.writeUpstreamError(w, getErrorResponse(stream.Err()))
return stream.Err()
}
// Check if the stream encountered any errors.
if streamErr := stream.Err(); streamErr != nil {
if eventstream.IsUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
// We can't reflect an error back if there's a connection error or the request context was canceled.
} else if oaiErr := getErrorResponse(streamErr); oaiErr != nil {
logger.Warn(ctx, "openai stream error", slog.Error(streamErr))
interceptionErr = oaiErr
} else {
logger.Warn(ctx, "unknown stream error encountered", slog.Error(streamErr))
// Unfortunately, the OpenAI SDK does not support parsing errors received in the stream
// into known types (i.e. [shared.OverloadedError]).
// See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171
// All it does is wrap the payload in an error - which is all we can return, currently.
interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr))
}
} else if lastErr != nil {
// Otherwise check if any logical errors occurred during processing.
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
}
if interceptionErr != nil {
payload, err := i.marshalErr(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error()))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
}
}
// No tool call, nothing more to do.
if toolCall == nil {
break
@@ -390,11 +388,12 @@ func (i *StreamingInterception) marshalErr(err error) ([]byte, error) {
return i.encodeForStream(data), nil
}
func (i *StreamingInterception) encodeForStream(payload []byte) []byte {
func (*StreamingInterception) encodeForStream(payload []byte) []byte {
// bytes.Buffer writes to in-memory storage and never return errors.
var buf bytes.Buffer
buf.WriteString("data: ")
buf.Write(payload)
buf.WriteString("\n\n")
_, _ = buf.WriteString("data: ")
_, _ = buf.Write(payload)
_, _ = buf.WriteString("\n\n")
return buf.Bytes()
}
@@ -1,4 +1,4 @@
package chatcompletions
package chatcompletions_test
import (
"net/http"
@@ -16,6 +16,7 @@ import (
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/chatcompletions"
"github.com/coder/coder/v2/aibridge/internal/testutil"
)
@@ -73,7 +74,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
Key: "test-key",
}
req := &ChatCompletionNewParamsWrapper{
req := &chatcompletions.ChatCompletionNewParamsWrapper{
ChatCompletionNewParams: openai.ChatCompletionNewParams{
Model: "gpt-4",
Messages: []openai.ChatCompletionMessageParamUnion{
@@ -88,7 +89,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)
tracer := otel.Tracer("test")
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{})
interceptor := chatcompletions.NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{})
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
+15 -13
View File
@@ -1,4 +1,4 @@
package intercept
package intercept_test
import (
"net/http"
@@ -6,6 +6,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/aibridge/intercept"
)
func TestPrepareClientHeaders(t *testing.T) {
@@ -14,7 +16,7 @@ func TestPrepareClientHeaders(t *testing.T) {
t.Run("nil input returns empty header", func(t *testing.T) {
t.Parallel()
result := PrepareClientHeaders(nil)
result := intercept.PrepareClientHeaders(nil)
require.Empty(t, result)
})
@@ -29,7 +31,7 @@ func TestPrepareClientHeaders(t *testing.T) {
"X-Custom": {"preserved"},
}
result := PrepareClientHeaders(input)
result := intercept.PrepareClientHeaders(input)
assert.Empty(t, result.Get("Connection"))
assert.Empty(t, result.Get("Keep-Alive"))
@@ -48,7 +50,7 @@ func TestPrepareClientHeaders(t *testing.T) {
"X-Custom": {"preserved"},
}
result := PrepareClientHeaders(input)
result := intercept.PrepareClientHeaders(input)
assert.Empty(t, result.Get("Host"))
assert.Empty(t, result.Get("Accept-Encoding"))
@@ -65,7 +67,7 @@ func TestPrepareClientHeaders(t *testing.T) {
"X-Custom": {"preserved"},
}
result := PrepareClientHeaders(input)
result := intercept.PrepareClientHeaders(input)
assert.Empty(t, result.Get("Authorization"))
assert.Empty(t, result.Get("X-Api-Key"))
@@ -79,7 +81,7 @@ func TestPrepareClientHeaders(t *testing.T) {
"X-Custom": {"value-1", "value-2"},
}
result := PrepareClientHeaders(input)
result := intercept.PrepareClientHeaders(input)
require.Equal(t, []string{"value-1", "value-2"}, result["X-Custom"])
})
@@ -93,7 +95,7 @@ func TestPrepareClientHeaders(t *testing.T) {
}
originalCopy := input.Clone()
_ = PrepareClientHeaders(input)
_ = intercept.PrepareClientHeaders(input)
require.Equal(t, originalCopy, input)
})
@@ -113,7 +115,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
"User-Agent": {"claude-code/1.0"},
}
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
assert.Equal(t, "Bearer sk-provider-key", result.Get("Authorization"))
assert.Equal(t, "claude-code/1.0", result.Get("User-Agent"))
@@ -131,7 +133,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
"Anthropic-Beta": {"prompt-caching-2024-07-31"},
}
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key")
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key")
assert.Equal(t, "sk-ant-provider-key", result.Get("X-Api-Key"))
assert.Empty(t, result.Get("Authorization"))
@@ -151,7 +153,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
"User-Agent": {"claude-code/1.0"},
}
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
assert.Equal(t, "Bearer sk-key", result.Get("Authorization"))
assert.Equal(t, "user-123", result.Get("X-Ai-Bridge-Actor-Id"))
@@ -174,7 +176,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
"User-Agent": {"claude-code/1.0"},
}
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
assert.Empty(t, result.Get("Connection"))
assert.Empty(t, result.Get("Host"))
@@ -192,7 +194,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
"User-Agent": {"claude-code/1.0"},
}
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
assert.Empty(t, result.Get("Authorization"))
assert.Equal(t, "claude-code/1.0", result.Get("User-Agent"))
@@ -211,7 +213,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
sdkCopy := sdkHeader.Clone()
clientCopy := clientHeaders.Clone()
_ = BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
_ = intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
require.Equal(t, sdkCopy, sdkHeader)
require.Equal(t, clientCopy, clientHeaders)
@@ -32,7 +32,6 @@ type EventStream struct {
initiated atomic.Bool
initiateOnce sync.Once
closeOnce sync.Once
shutdownOnce sync.Once
eventsCh chan event
@@ -133,7 +132,7 @@ func (s *EventStream) Start(w http.ResponseWriter, r *http.Request) {
return
}
if err := flush(w); err != nil {
s.logger.Warn(ctx, "failed to flush", slog.Error(err))
s.logger.Warn(ctx, "failed to flush event stream", slog.Error(err))
return
}
@@ -240,8 +239,7 @@ func flush(w http.ResponseWriter) (err error) {
}
defer func() {
if r := recover(); r != nil {
// Likely a broken connection, don't spam the logs.
if r := recover(); r != nil { //nolint:revive,staticcheck // Intentionally swallowed; likely a broken connection.
}
}()
+1 -1
View File
@@ -17,7 +17,7 @@ type Interceptor interface {
ID() uuid.UUID
// Setup injects some required dependencies. This MUST be called before using the interceptor
// to process requests.
Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier)
Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier)
// Model returns the model in use for this [Interceptor].
Model() string
// ProcessRequest handles the HTTP request.
+27 -29
View File
@@ -65,7 +65,7 @@ var bedrockSupportedBetaFlags = map[string]bool{
type interceptionBase struct {
id uuid.UUID
providerName string
reqPayload MessagesRequestPayload
reqPayload RequestPayload
cfg aibconfig.Anthropic
bedrockCfg *aibconfig.AWSBedrock
@@ -90,9 +90,9 @@ func (i *interceptionBase) Credential() intercept.CredentialInfo {
return i.credential
}
func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
func (i *interceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger
i.recorder = recorder
i.recorder = rec
i.mcpProxy = mcpProxy
}
@@ -116,15 +116,15 @@ func (i *interceptionBase) Model() string {
return i.reqPayload.model()
}
func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
return []attribute.KeyValue{
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InterceptionID, i.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, s.providerName),
attribute.String(tracing.Model, s.Model()),
attribute.String(tracing.Provider, i.providerName),
attribute.String(tracing.Model, i.Model()),
attribute.Bool(tracing.Streaming, streaming),
attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil),
attribute.Bool(tracing.IsBedrock, i.bedrockCfg != nil),
}
}
@@ -174,24 +174,22 @@ func (i *interceptionBase) disableParallelToolCalls() {
}
// extractModelThoughts returns any thinking blocks that were returned in the response.
func (i *interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recorder.ModelThoughtRecord {
func (*interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recorder.ModelThoughtRecord {
if msg == nil {
return nil
}
var thoughtRecords []*recorder.ModelThoughtRecord
for _, block := range msg.Content {
switch variant := block.AsAny().(type) {
case anthropic.ThinkingBlock:
if variant.Thinking == "" {
continue
}
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
Content: variant.Thinking,
Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking},
})
}
// anthropic.RedactedThinkingBlock also exists, but there's nothing useful we can capture.
variant, ok := block.AsAny().(anthropic.ThinkingBlock)
if !ok || variant.Thinking == "" {
continue
}
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
Content: variant.Thinking,
Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking},
})
}
return thoughtRecords
}
@@ -264,7 +262,7 @@ func (i *interceptionBase) withBody() option.RequestOption {
return option.WithRequestBody("application/json", []byte(i.reqPayload))
}
func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
func (*interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
if cfg == nil {
return nil, xerrors.New("nil config given")
}
@@ -405,7 +403,7 @@ func filterBedrockBetaFlags(headers http.Header, model string) {
}
// writeUpstreamError marshals and writes a given error.
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *ErrorResponse) {
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *messagesResponseError) {
if antErr == nil {
return
}
@@ -415,7 +413,7 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err
out, err := json.Marshal(antErr)
if err != nil {
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", antErr)))
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", antErr))
// Response has to match expected format.
// See https://docs.claude.com/en/api/errors#error-shapes.
_, _ = w.Write([]byte(fmt.Sprintf(`{
@@ -487,7 +485,7 @@ func accumulateUsage(dest, src any) {
}
}
func getErrorResponse(err error) *ErrorResponse {
func getErrorResponse(err error) *messagesResponseError {
var apierr *anthropic.Error
if !errors.As(err, &apierr) {
return nil
@@ -505,7 +503,7 @@ func getErrorResponse(err error) *ErrorResponse {
typ = string(detail.Type)
}
return &ErrorResponse{
return &messagesResponseError{
ErrorResponse: &anthropic.ErrorResponse{
Error: anthropic.ErrorObjectUnion{
Message: msg,
@@ -517,16 +515,16 @@ func getErrorResponse(err error) *ErrorResponse {
}
}
var _ error = &ErrorResponse{}
var _ error = &messagesResponseError{}
type ErrorResponse struct {
type messagesResponseError struct {
*anthropic.ErrorResponse
StatusCode int `json:"-"`
}
func newErrorResponse(msg error) *ErrorResponse {
return &ErrorResponse{
func newErrorResponse(msg error) *messagesResponseError {
return &messagesResponseError{
ErrorResponse: &shared.ErrorResponse{
Error: shared.ErrorObjectUnion{
Message: msg.Error(),
@@ -536,7 +534,7 @@ func newErrorResponse(msg error) *ErrorResponse {
}
}
func (a *ErrorResponse) Error() string {
func (a *messagesResponseError) Error() string {
if a.ErrorResponse == nil {
return ""
}
+20 -7
View File
@@ -1,4 +1,4 @@
package messages
package messages //nolint:testpackage // tests unexported internals
import (
"context"
@@ -197,6 +197,8 @@ func TestAWSBedrockValidation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
base := &interceptionBase{}
opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)
@@ -212,7 +214,10 @@ func TestAWSBedrockValidation(t *testing.T) {
}
func TestAccumulateUsage(t *testing.T) {
t.Parallel()
t.Run("Usage to Usage", func(t *testing.T) {
t.Parallel()
dest := &anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
@@ -253,6 +258,8 @@ func TestAccumulateUsage(t *testing.T) {
})
t.Run("MessageDeltaUsage to MessageDeltaUsage", func(t *testing.T) {
t.Parallel()
dest := &anthropic.MessageDeltaUsage{
InputTokens: 10,
OutputTokens: 20,
@@ -283,6 +290,8 @@ func TestAccumulateUsage(t *testing.T) {
})
t.Run("Usage to MessageDeltaUsage", func(t *testing.T) {
t.Parallel()
dest := &anthropic.MessageDeltaUsage{
InputTokens: 10,
OutputTokens: 20,
@@ -317,6 +326,8 @@ func TestAccumulateUsage(t *testing.T) {
})
t.Run("MessageDeltaUsage to Usage", func(t *testing.T) {
t.Parallel()
dest := &anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
@@ -354,6 +365,8 @@ func TestAccumulateUsage(t *testing.T) {
})
t.Run("Nil or unsupported types", func(t *testing.T) {
t.Parallel()
// Test with nil dest
var nilUsage *anthropic.Usage
source := anthropic.Usage{InputTokens: 10}
@@ -763,10 +776,10 @@ func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) {
}
}
func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload {
func mustMessagesPayload(t *testing.T, requestBody string) RequestPayload {
t.Helper()
payload, err := NewMessagesRequestPayload([]byte(requestBody))
payload, err := NewRequestPayload([]byte(requestBody))
require.NoError(t, err)
return payload
@@ -777,11 +790,11 @@ type mockServerProxier struct {
tools []*mcp.Tool
}
func (m *mockServerProxier) Init(context.Context) error {
func (*mockServerProxier) Init(context.Context) error {
return nil
}
func (m *mockServerProxier) Shutdown(context.Context) error {
func (*mockServerProxier) Shutdown(context.Context) error {
return nil
}
@@ -798,8 +811,8 @@ func (m *mockServerProxier) GetTool(id string) *mcp.Tool {
return nil
}
func (m *mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) {
return nil, nil
func (*mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) {
return nil, nil //nolint:nilnil // mock: no-op implementation
}
func TestFilterBedrockBetaFlags(t *testing.T) {
+4 -4
View File
@@ -31,7 +31,7 @@ type BlockingInterception struct {
func NewBlockingInterceptor(
id uuid.UUID,
reqPayload MessagesRequestPayload,
reqPayload RequestPayload,
providerName string,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
@@ -53,15 +53,15 @@ func NewBlockingInterceptor(
}}
}
func (i *BlockingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.interceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
func (i *BlockingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.interceptionBase.Setup(logger.Named("blocking"), rec, mcpProxy)
}
func (i *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return i.interceptionBase.baseTraceAttributes(r, false)
}
func (s *BlockingInterception) Streaming() bool {
func (*BlockingInterception) Streaming() bool {
return false
}
+20 -20
View File
@@ -82,12 +82,12 @@ var (
}
)
// MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request.
// RequestPayload is raw JSON bytes of an Anthropic Messages API request.
// Methods provide package-specific reads and rewrites while preserving the
// original body for upstream pass-through.
type MessagesRequestPayload []byte
type RequestPayload []byte
func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) {
func NewRequestPayload(raw []byte) (RequestPayload, error) {
if len(bytes.TrimSpace(raw)) == 0 {
return nil, xerrors.New("messages empty request body")
}
@@ -95,10 +95,10 @@ func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) {
return nil, xerrors.New("messages invalid JSON request body")
}
return MessagesRequestPayload(raw), nil
return RequestPayload(raw), nil
}
func (p MessagesRequestPayload) Stream() bool {
func (p RequestPayload) Stream() bool {
v := gjson.GetBytes(p, messagesReqPathStream)
if !v.IsBool() {
return false
@@ -106,11 +106,11 @@ func (p MessagesRequestPayload) Stream() bool {
return v.Bool()
}
func (p MessagesRequestPayload) model() string {
func (p RequestPayload) model() string {
return gjson.GetBytes(p, messagesReqPathModel).Str
}
func (p MessagesRequestPayload) correlatingToolCallID() *string {
func (p RequestPayload) correlatingToolCallID() *string {
messages := gjson.GetBytes(p, messagesReqPathMessages)
if !messages.IsArray() {
return nil
@@ -147,7 +147,7 @@ func (p MessagesRequestPayload) correlatingToolCallID() *string {
// lastUserPrompt returns the prompt text from the last user message. If no prompt
// is found, it returns empty string, false, nil. Unexpected shapes are treated as
// unsupported and do not fail the request path.
func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) {
func (p RequestPayload) lastUserPrompt() (string, bool, error) {
messages := gjson.GetBytes(p, messagesReqPathMessages)
if !messages.Exists() || messages.Type == gjson.Null {
return "", false, nil
@@ -195,7 +195,7 @@ func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) {
return "", false, nil
}
func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) (MessagesRequestPayload, error) {
func (p RequestPayload) injectTools(injected []anthropic.ToolUnionParam) (RequestPayload, error) {
if len(injected) == 0 {
return p, nil
}
@@ -221,7 +221,7 @@ func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam)
return p.set(messagesReqPathTools, allTools)
}
func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPayload, error) {
func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) {
toolChoice := gjson.GetBytes(p, messagesReqPathToolChoice)
// If no tool_choice was defined, assume auto.
@@ -258,7 +258,7 @@ func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPaylo
}
}
func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (MessagesRequestPayload, error) {
func (p RequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (RequestPayload, error) {
if len(newMessages) == 0 {
return p, nil
}
@@ -285,11 +285,11 @@ func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.Message
return p.set(messagesReqPathMessages, allMessages)
}
func (p MessagesRequestPayload) withModel(model string) (MessagesRequestPayload, error) {
func (p RequestPayload) withModel(model string) (RequestPayload, error) {
return p.set(messagesReqPathModel, model)
}
func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) {
func (p RequestPayload) messages() ([]json.RawMessage, error) {
messages := gjson.GetBytes(p, messagesReqPathMessages)
if !messages.Exists() || messages.Type == gjson.Null {
return nil, nil
@@ -301,7 +301,7 @@ func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) {
return p.resultToRawMessage(messages.Array()), nil
}
func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) {
func (p RequestPayload) tools() ([]json.RawMessage, error) {
tools := gjson.GetBytes(p, messagesReqPathTools)
if !tools.Exists() || tools.Type == gjson.Null {
return nil, nil
@@ -313,7 +313,7 @@ func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) {
return p.resultToRawMessage(tools.Array()), nil
}
func (p MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage {
func (RequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage {
// gjson.Result conversion to json.RawMessage is needed because
// gjson.Result does not implement json.Marshaler — would
// serialize its struct fields instead of the raw JSON it represents.
@@ -326,7 +326,7 @@ func (p MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.
// convertAdaptiveThinkingForBedrock converts thinking.type "adaptive" to "enabled" with a calculated budget_tokens
// conversion is needed for Bedrock models that does not support the "adaptive" thinking.type
func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesRequestPayload, error) {
func (p RequestPayload) convertAdaptiveThinkingForBedrock() (RequestPayload, error) {
thinkingType := gjson.GetBytes(p, messagesReqPathThinkingType)
if thinkingType.String() != constAdaptive {
return p, nil
@@ -377,7 +377,7 @@ func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesReq
// removed when the corresponding flag is absent from the Anthropic-Beta header.
// Model-specific beta flags must already be filtered from the header before
// calling this method (see filterBedrockBetaFlags).
func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Header) (MessagesRequestPayload, error) {
func (p RequestPayload) removeUnsupportedBedrockFields(headers http.Header) (RequestPayload, error) {
var payloadMap map[string]any
if err := json.Unmarshal(p, &payloadMap); err != nil {
return p, xerrors.Errorf("failed to unmarshal request payload when removing unsupported Bedrock fields: %w", err)
@@ -400,13 +400,13 @@ func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Head
if err != nil {
return p, xerrors.Errorf("failed to marshal request payload when removing unsupported Bedrock fields: %w", err)
}
return MessagesRequestPayload(result), nil
return RequestPayload(result), nil
}
func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) {
func (p RequestPayload) set(path string, value any) (RequestPayload, error) {
out, err := sjson.SetBytes(p, path, value)
if err != nil {
return p, xerrors.Errorf("set %s: %w", path, err)
}
return MessagesRequestPayload(out), nil
return RequestPayload(out), nil
}
+12 -12
View File
@@ -1,4 +1,4 @@
package messages
package messages //nolint:testpackage // tests unexported internals
import (
"testing"
@@ -11,7 +11,7 @@ import (
"github.com/coder/coder/v2/aibridge/utils"
)
func TestNewMessagesRequestPayload(t *testing.T) {
func TestNewRequestPayload(t *testing.T) {
t.Parallel()
testCases := []struct {
@@ -42,7 +42,7 @@ func TestNewMessagesRequestPayload(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
payload, err := NewMessagesRequestPayload(testCase.requestBody)
payload, err := NewRequestPayload(testCase.requestBody)
if testCase.expectError {
require.Error(t, err)
require.Nil(t, payload)
@@ -50,12 +50,12 @@ func TestNewMessagesRequestPayload(t *testing.T) {
}
require.NoError(t, err)
require.Equal(t, MessagesRequestPayload(testCase.requestBody), payload)
require.Equal(t, RequestPayload(testCase.requestBody), payload)
})
}
}
func TestMessagesRequestPayloadStream(t *testing.T) {
func TestRequestPayloadStream(t *testing.T) {
t.Parallel()
testCases := []struct {
@@ -97,7 +97,7 @@ func TestMessagesRequestPayloadStream(t *testing.T) {
}
}
func TestMessagesRequestPayloadModel(t *testing.T) {
func TestRequestPayloadModel(t *testing.T) {
t.Parallel()
testCases := []struct {
@@ -132,7 +132,7 @@ func TestMessagesRequestPayloadModel(t *testing.T) {
}
}
func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) {
func TestRequestPayloadLastUserPrompt(t *testing.T) {
t.Parallel()
testCases := []struct {
@@ -229,7 +229,7 @@ func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) {
}
}
func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) {
func TestRequestPayloadCorrelatingToolCallID(t *testing.T) {
t.Parallel()
testCases := []struct {
@@ -266,7 +266,7 @@ func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) {
}
}
func TestMessagesRequestPayloadInjectTools(t *testing.T) {
func TestRequestPayloadInjectTools(t *testing.T) {
t.Parallel()
payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`)
@@ -291,7 +291,7 @@ func TestMessagesRequestPayloadInjectTools(t *testing.T) {
require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String())
}
func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
func TestRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
t.Parallel()
testCases := []struct {
@@ -361,7 +361,7 @@ func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
}
}
func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) {
func TestRequestPayloadDisableParallelToolCalls(t *testing.T) {
t.Parallel()
testCases := []struct {
@@ -451,7 +451,7 @@ func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) {
}
}
func TestMessagesRequestPayloadAppendedMessages(t *testing.T) {
func TestRequestPayloadAppendedMessages(t *testing.T) {
t.Parallel()
payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`)
+51 -54
View File
@@ -36,7 +36,7 @@ type StreamingInterception struct {
func NewStreamingInterceptor(
id uuid.UUID,
reqPayload MessagesRequestPayload,
reqPayload RequestPayload,
providerName string,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
@@ -58,16 +58,16 @@ func NewStreamingInterceptor(
}}
}
func (s *StreamingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
s.interceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
func (i *StreamingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.interceptionBase.Setup(logger.Named("streaming"), rec, mcpProxy)
}
func (s *StreamingInterception) Streaming() bool {
func (*StreamingInterception) Streaming() bool {
return true
}
func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return s.interceptionBase.baseTraceAttributes(r, true)
func (i *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
return i.interceptionBase.baseTraceAttributes(r, true)
}
// ProcessRequest handles a request to /v1/messages.
@@ -156,7 +156,7 @@ newStream:
for {
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
if err := streamCtx.Err(); err != nil {
lastErr = xerrors.Errorf("stream exit: %w", err)
interceptionErr = xerrors.Errorf("stream exit: %w", err)
break
}
@@ -178,8 +178,7 @@ newStream:
// Tool-related handling.
switch event.Type {
case string(constant.ValueOf[constant.ContentBlockStart]()):
switch block := event.AsContentBlockStart().ContentBlock.AsAny().(type) {
case anthropic.ToolUseBlock:
if block, ok := event.AsContentBlockStart().ContentBlock.AsAny().(anthropic.ToolUseBlock); ok {
lastToolName = block.Name
if i.mcpProxy != nil && i.mcpProxy.GetTool(block.Name) != nil {
@@ -306,8 +305,7 @@ newStream:
foundTools int
)
for _, block := range message.Content {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
if variant, ok := block.AsAny().(anthropic.ToolUseBlock); ok {
foundTools++
if variant.Name == name {
input = variant.Input
@@ -430,24 +428,23 @@ newStream:
// Causes a new stream to be run with updated messages.
isFirst = false
continue newStream
} else {
// Find all the non-injected tools and track their uses.
for _, block := range message.Content {
switch variant := block.AsAny().(type) {
case anthropic.ToolUseBlock:
if i.mcpProxy != nil && i.mcpProxy.GetTool(variant.Name) != nil {
continue
}
}
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: message.ID,
ToolCallID: variant.ID,
Tool: variant.Name,
Args: variant.Input,
Injected: false,
})
// Find all the non-injected tools and track their uses.
for _, block := range message.Content {
if variant, ok := block.AsAny().(anthropic.ToolUseBlock); ok {
if i.mcpProxy != nil && i.mcpProxy.GetTool(variant.Name) != nil {
continue
}
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
InterceptionID: i.ID().String(),
MsgID: message.ID,
ToolCallID: variant.ID,
Tool: variant.Name,
Args: variant.Input,
Injected: false,
})
}
}
}
@@ -463,11 +460,10 @@ newStream:
if eventstream.IsUnrecoverableError(err) {
logger.Debug(ctx, "processing terminated", slog.Error(err))
break // Stop processing if client disconnected or context canceled.
} else {
logger.Warn(ctx, "failed to relay event", slog.Error(err))
lastErr = xerrors.Errorf("relay event: %w", err)
break
}
logger.Warn(ctx, "failed to relay event", slog.Error(err))
lastErr = xerrors.Errorf("relay event: %w", err)
break
}
}
@@ -477,8 +473,8 @@ newStream:
MsgID: message.ID,
Prompt: prompt,
})
prompt = ""
promptFound = false
prompt = "" //nolint:ineffassign // reset to prevent double-recording across newStream iterations
promptFound = false //nolint:ineffassign // reset to prevent double-recording across newStream iterations
}
if events.IsStreaming() {
@@ -491,7 +487,7 @@ newStream:
logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr))
interceptionErr = antErr
} else {
logger.Warn(ctx, "unknown error", slog.Error(streamErr))
logger.Warn(ctx, "unknown stream error encountered", slog.Error(streamErr))
// Unfortunately, the Anthropic SDK does not support parsing errors received in the stream
// into known types (i.e. [shared.OverloadedError]).
// See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174
@@ -500,14 +496,14 @@ newStream:
}
} else if lastErr != nil {
// Otherwise check if any logical errors occurred during processing.
logger.Warn(ctx, "stream failed", slog.Error(lastErr))
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
}
if interceptionErr != nil {
payload, err := i.marshal(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error()))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
}
@@ -518,11 +514,11 @@ newStream:
}
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30)
defer shutdownCancel()
// Give the events stream 30 seconds (TODO: configurable) to gracefully shutdown.
if err := events.Shutdown(shutdownCtx); err != nil {
logger.Warn(ctx, "event stream shutdown", slog.Error(err))
}
shutdownCancel()
// Cancel the stream context, we're now done.
if interceptionErr != nil {
@@ -537,8 +533,8 @@ newStream:
return interceptionErr
}
func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) {
sj, err := sjson.Set(event.RawJSON(), "message.id", s.ID().String())
func (i *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) {
sj, err := sjson.Set(event.RawJSON(), "message.id", i.ID().String())
if err != nil {
return nil, xerrors.Errorf("marshal event id failed: %w", err)
}
@@ -548,10 +544,10 @@ func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventU
return nil, xerrors.Errorf("marshal event usage failed: %w", err)
}
return s.encodeForStream([]byte(sj), event.Type), nil
return i.encodeForStream([]byte(sj), event.Type), nil
}
func (s *StreamingInterception) marshal(payload any) ([]byte, error) {
func (i *StreamingInterception) marshal(payload any) ([]byte, error) {
data, err := json.Marshal(payload)
if err != nil {
return nil, xerrors.Errorf("marshal payload: %w", err)
@@ -567,29 +563,30 @@ func (s *StreamingInterception) marshal(payload any) ([]byte, error) {
return nil, xerrors.Errorf("could not determine type from payload %q", data)
}
return s.encodeForStream(data, eventType), nil
return i.encodeForStream(data, eventType), nil
}
// https://docs.anthropic.com/en/docs/build-with-claude/streaming#basic-streaming-request
func (s *StreamingInterception) pingPayload() []byte {
return s.encodeForStream([]byte(`{"type": "ping"}`), "ping")
func (i *StreamingInterception) pingPayload() []byte {
return i.encodeForStream([]byte(`{"type": "ping"}`), "ping")
}
func (s *StreamingInterception) encodeForStream(payload []byte, typ string) []byte {
func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte {
// bytes.Buffer writes to in-memory storage and never return errors.
var buf bytes.Buffer
buf.WriteString("event: ")
buf.WriteString(typ)
buf.WriteString("\n")
buf.WriteString("data: ")
buf.Write(payload)
buf.WriteString("\n\n")
_, _ = buf.WriteString("event: ")
_, _ = buf.WriteString(typ)
_, _ = buf.WriteString("\n")
_, _ = buf.WriteString("data: ")
_, _ = buf.Write(payload)
_, _ = buf.WriteString("\n\n")
return buf.Bytes()
}
// newStream traces svc.NewStreaming() call.
func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
_, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer span.End()
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, s.withBody())
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, i.withBody())
}
+19 -13
View File
@@ -42,7 +42,7 @@ type responsesInterceptionBase struct {
// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
authHeaderName string
reqPayload ResponsesRequestPayload
reqPayload RequestPayload
cfg config.OpenAI
recorder recorder.Recorder
@@ -89,9 +89,9 @@ func (i *responsesInterceptionBase) Credential() intercept.CredentialInfo {
return i.credential
}
func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
func (i *responsesInterceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger.With(slog.F("model", i.Model()))
i.recorder = recorder
i.recorder = rec
i.mcpProxy = mcpProxy
}
@@ -127,7 +127,13 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.
// sendCustomErr sends custom responses.Error error to the client
// it should only be called before any data is sent back to the client
func (i *responsesInterceptionBase) sendCustomErr(ctx context.Context, w http.ResponseWriter, code int, err error) {
respErr := responses.Error{
// Same JSON shape as responses.Error but using a plain struct because
// responses.Error embeds *http.Request whose GetBody func field
// is not JSON-marshalable (SA1026).
respErr := struct {
Code string `json:"code"`
Message string `json:"message"`
}{
Code: strconv.Itoa(code),
Message: err.Error(),
}
@@ -222,15 +228,15 @@ func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Conte
func (i *responsesInterceptionBase) parseFunctionCallJSONArgs(ctx context.Context, raw string) recorder.ToolArgs {
trimmed := strings.TrimSpace(raw)
if trimmed != "" {
var args recorder.ToolArgs
if err := json.Unmarshal([]byte(trimmed), &args); err != nil {
i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err))
} else {
return args
}
if trimmed == "" {
return trimmed
}
return trimmed
var args recorder.ToolArgs
if err := json.Unmarshal([]byte(trimmed), &args); err != nil {
i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err))
return trimmed
}
return args
}
func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, response *responses.Response) {
@@ -264,7 +270,7 @@ func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, respon
// extractModelThoughts extracts model thoughts from response output items.
// It captures both reasoning summary items and commentary messages (message
// output items with "phase": "commentary") as model thoughts.
func (i *responsesInterceptionBase) extractModelThoughts(response *responses.Response) []*recorder.ModelThoughtRecord {
func (*responsesInterceptionBase) extractModelThoughts(response *responses.Response) []*recorder.ModelThoughtRecord {
if response == nil {
return nil
}
+8 -4
View File
@@ -1,4 +1,4 @@
package responses
package responses //nolint:testpackage // tests unexported internals
import (
"net/http"
@@ -359,13 +359,16 @@ func (mrw *mockResponseWriter) WriteHeader(statusCode int) {
}
func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) {
t.Parallel()
mrw := mockResponseWriter{}
respCopy := responseCopier{}
body := "test_body"
respCopy.buff.Write([]byte(body))
_, _ = respCopy.buff.Write([]byte(body)) // bytes.Buffer.Write never fails
respCopy.forwardResp(&mrw)
err := respCopy.forwardResp(&mrw)
require.NoError(t, err)
require.False(t, mrw.headerCalled)
require.False(t, mrw.writeCalled)
require.False(t, mrw.writeHeaderCalled)
@@ -373,7 +376,8 @@ func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) {
// after response is received data is forwarded
respCopy.responseReceived.Store(true)
respCopy.forwardResp(&mrw)
err = respCopy.forwardResp(&mrw)
require.NoError(t, err)
require.True(t, mrw.headerCalled)
require.True(t, mrw.writeCalled)
require.True(t, mrw.writeHeaderCalled)
+4 -4
View File
@@ -28,7 +28,7 @@ type BlockingResponsesInterceptor struct {
func NewBlockingInterceptor(
id uuid.UUID,
reqPayload ResponsesRequestPayload,
reqPayload RequestPayload,
providerName string,
cfg config.OpenAI,
clientHeaders http.Header,
@@ -50,11 +50,11 @@ func NewBlockingInterceptor(
}
}
func (i *BlockingResponsesInterceptor) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.responsesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
func (i *BlockingResponsesInterceptor) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.responsesInterceptionBase.Setup(logger.Named("blocking"), rec, mcpProxy)
}
func (i *BlockingResponsesInterceptor) Streaming() bool {
func (*BlockingResponsesInterceptor) Streaming() bool {
return false
}
+17 -17
View File
@@ -37,13 +37,13 @@ var (
reqPathType = string(constant.ValueOf[constant.Type]())
)
// ResponsesRequestPayload is raw JSON bytes of a Responses API request.
// RequestPayload is raw JSON bytes of a Responses API request.
// Methods provide package-specific reads and rewrites while preserving the
// original body for upstream pass-through.
// Note: No changes are made on schema error.
type ResponsesRequestPayload []byte
type RequestPayload []byte
func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) {
func NewRequestPayload(raw []byte) (RequestPayload, error) {
if len(bytes.TrimSpace(raw)) == 0 {
return nil, xerrors.New("empty request body")
}
@@ -51,22 +51,22 @@ func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) {
return nil, xerrors.New("invalid JSON payload")
}
return ResponsesRequestPayload(raw), nil
return RequestPayload(raw), nil
}
func (p ResponsesRequestPayload) Stream() bool {
func (p RequestPayload) Stream() bool {
return gjson.GetBytes(p, reqPathStream).Bool()
}
func (p ResponsesRequestPayload) model() string {
func (p RequestPayload) model() string {
return gjson.GetBytes(p, reqPathModel).String()
}
func (p ResponsesRequestPayload) background() bool {
func (p RequestPayload) background() bool {
return gjson.GetBytes(p, reqPathBackground).Bool()
}
func (p ResponsesRequestPayload) correlatingToolCallID() *string {
func (p RequestPayload) correlatingToolCallID() *string {
items := gjson.GetBytes(p, reqPathInput)
if !items.IsArray() {
return nil
@@ -94,7 +94,7 @@ func (p ResponsesRequestPayload) correlatingToolCallID() *string {
// item, or the string input value if present. If no prompt is found, it returns
// empty string, false, nil. Unexpected shapes are treated as unsupported and do
// not fail the request path.
func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog.Logger) (string, bool, error) {
func (p RequestPayload) lastUserPrompt(ctx context.Context, logger slog.Logger) (string, bool, error) {
inputItems := gjson.GetBytes(p, reqPathInput)
if !inputItems.Exists() || inputItems.Type == gjson.Null {
return "", false, nil
@@ -155,10 +155,10 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog
}
if promptExists {
sb.WriteByte('\n')
_ = sb.WriteByte('\n') // strings.Builder.WriteByte never fails
}
promptExists = true
sb.WriteString(text.Str)
_, _ = sb.WriteString(text.Str) // strings.Builder.WriteString never fails
}
if !promptExists {
@@ -168,7 +168,7 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog
return sb.String(), true, nil
}
func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam) (ResponsesRequestPayload, error) {
func (p RequestPayload) injectTools(injected []responses.ToolUnionParam) (RequestPayload, error) {
if len(injected) == 0 {
return p, nil
}
@@ -189,11 +189,11 @@ func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam
return p.set(reqPathTools, allTools)
}
func (p ResponsesRequestPayload) disableParallelToolCalls() (ResponsesRequestPayload, error) {
func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) {
return p.set(reqPathParallelToolCalls, false)
}
func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInputItemUnionParam) (ResponsesRequestPayload, error) {
func (p RequestPayload) appendInputItems(items []responses.ResponseInputItemUnionParam) (RequestPayload, error) {
if len(items) == 0 {
return p, nil
}
@@ -212,7 +212,7 @@ func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInpu
return p.set(reqPathInput, allInput)
}
func (p ResponsesRequestPayload) inputItems() ([]any, error) {
func (p RequestPayload) inputItems() ([]any, error) {
input := gjson.GetBytes(p, reqPathInput)
if !input.Exists() || input.Type == gjson.Null {
return []any{}, nil
@@ -235,7 +235,7 @@ func (p ResponsesRequestPayload) inputItems() ([]any, error) {
return existing, nil
}
func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) {
func (p RequestPayload) toolItems() ([]json.RawMessage, error) {
tools := gjson.GetBytes(p, reqPathTools)
if !tools.Exists() {
return nil, nil
@@ -253,7 +253,7 @@ func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) {
return existing, nil
}
func (p ResponsesRequestPayload) set(path string, value any) (ResponsesRequestPayload, error) {
func (p RequestPayload) set(path string, value any) (RequestPayload, error) {
updated, err := sjson.SetBytes(p, path, value)
if err != nil {
return p, xerrors.Errorf("failed to set value at path %s: %w", path, err)
@@ -1,4 +1,4 @@
package responses
package responses //nolint:testpackage // tests unexported internals
import (
"encoding/json"
@@ -16,7 +16,7 @@ import (
"github.com/coder/coder/v2/aibridge/utils"
)
func TestNewResponsesRequestPayload(t *testing.T) {
func TestNewRequestPayload(t *testing.T) {
t.Parallel()
payloadWithWrongTypes := []byte(`{"model":123,"stream":"yes","input":42,"background":"nope"}`)
@@ -42,7 +42,7 @@ func TestNewResponsesRequestPayload(t *testing.T) {
err: "invalid JSON payload",
},
{
// ResponsesRequestPayload just checks for JSON validity,
// RequestPayload just checks for JSON validity,
// schema errors are not surfaced here and
// the original body is preserved for upstream handling
// similar to how reverse proxy would behave.
@@ -59,7 +59,7 @@ func TestNewResponsesRequestPayload(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
payload, err := NewResponsesRequestPayload(tc.raw)
payload, err := NewRequestPayload(tc.raw)
if tc.err != "" {
require.ErrorContains(t, err, tc.err)
@@ -518,10 +518,10 @@ func injectedFunctionTool(name string) responses.ToolUnionParam {
}
}
func mustPayload(t *testing.T, raw []byte) ResponsesRequestPayload {
func mustPayload(t *testing.T, raw []byte) RequestPayload {
t.Helper()
payload, err := NewResponsesRequestPayload(raw)
payload, err := NewRequestPayload(raw)
require.NoError(t, err)
return payload
}
+4 -4
View File
@@ -35,7 +35,7 @@ type StreamingResponsesInterceptor struct {
func NewStreamingInterceptor(
id uuid.UUID,
reqPayload ResponsesRequestPayload,
reqPayload RequestPayload,
providerName string,
cfg config.OpenAI,
clientHeaders http.Header,
@@ -57,11 +57,11 @@ func NewStreamingInterceptor(
}
}
func (i *StreamingResponsesInterceptor) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.responsesInterceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
func (i *StreamingResponsesInterceptor) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.responsesInterceptionBase.Setup(logger.Named("streaming"), rec, mcpProxy)
}
func (i *StreamingResponsesInterceptor) Streaming() bool {
func (*StreamingResponsesInterceptor) Streaming() bool {
return true
}
@@ -1,4 +1,4 @@
package integrationtest
package integrationtest //nolint:testpackage // tests unexported internals
import (
"bufio"
@@ -11,7 +11,6 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
@@ -19,6 +18,7 @@ import (
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/fixtures"
"github.com/coder/coder/v2/aibridge/intercept/apidump"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/provider"
)
@@ -114,23 +114,25 @@ func TestAPIDump(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Setup mock upstream server.
fix := fixtures.Parse(t, tc.fixture)
srv := newMockUpstream(t, ctx, newFixtureResponse(fix))
srv := newMockUpstream(ctx, t, newFixtureResponse(fix))
// Create temp dir for API dumps.
dumpDir := t.TempDir()
bridgeServer := newBridgeTestServer(t, ctx, srv.URL,
bridgeServer := newBridgeTestServer(ctx, t, srv.URL,
withCustomProvider(tc.providerFunc(srv.URL, dumpDir)),
)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err := io.ReadAll(resp.Body)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// Verify dump files were created.
@@ -187,6 +189,7 @@ func TestAPIDump(t *testing.T) {
// Parse the dumped HTTP response.
dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil)
require.NoError(t, err)
defer dumpResp.Body.Close()
require.Equal(t, http.StatusOK, dumpResp.StatusCode)
dumpRespBody, err := io.ReadAll(dumpResp.Body)
require.NoError(t, err)
@@ -241,7 +244,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -252,16 +255,18 @@ func TestAPIDumpPassthrough(t *testing.T) {
dumpDir := t.TempDir()
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withCustomProvider(tc.providerFunc(upstream.URL, dumpDir)),
)
bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
resp, err := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
require.NoError(t, err)
defer resp.Body.Close()
// Find dump files in the passthrough directory.
passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough")
var reqDumpFile, respDumpFile string
err := filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error {
err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
@@ -299,6 +304,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
require.NoError(t, err)
dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil)
require.NoError(t, err)
defer dumpResp.Body.Close()
require.Equal(t, http.StatusOK, dumpResp.StatusCode)
dumpRespBody, err := io.ReadAll(dumpResp.Body)
require.NoError(t, err)
+113 -73
View File
@@ -1,4 +1,4 @@
package integrationtest
package integrationtest //nolint:testpackage // tests unexported internals
import (
"bytes"
@@ -10,7 +10,6 @@ import (
"slices"
"strings"
"testing"
"time"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
@@ -29,6 +28,7 @@ import (
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/fixtures"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/provider"
"github.com/coder/coder/v2/aibridge/recorder"
@@ -78,18 +78,20 @@ func TestAnthropicMessages(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
// Make API call to aibridge for Anthropic /v1/messages
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
// Response-specific checks.
@@ -210,17 +212,19 @@ func TestAnthropicMessagesModelThoughts(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
if tc.streaming {
@@ -242,7 +246,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
t.Run("invalid config", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Invalid bedrock config - missing region & base url
@@ -254,11 +258,13 @@ func TestAWSBedrockIntegration(t *testing.T) {
SmallFastModel: "test-haiku",
}
bridgeServer := newBridgeTestServer(t, ctx, "http://unused",
bridgeServer := newBridgeTestServer(ctx, t, "http://unused",
withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)),
)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool))
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool))
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
@@ -272,11 +278,11 @@ func TestAWSBedrockIntegration(t *testing.T) {
t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
// We define region here to validate that with Region & BaseURL defined, the latter takes precedence.
bedrockCfg := &config.AWSBedrock{
@@ -288,7 +294,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
BaseURL: upstream.URL, // Use the mock server.
}
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)),
)
@@ -296,7 +302,9 @@ func TestAWSBedrockIntegration(t *testing.T) {
// We override the AWS Bedrock client to route requests through our mock server.
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
// For streaming responses, consume the body to allow the stream to complete.
if streaming {
@@ -396,11 +404,11 @@ func TestAWSBedrockIntegration(t *testing.T) {
t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, fixtures.AntSimpleBedrock)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bCfg := &config.AWSBedrock{
Region: "us-west-2",
@@ -411,7 +419,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
BaseURL: upstream.URL,
}
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bCfg)),
)
@@ -419,9 +427,11 @@ func TestAWSBedrockIntegration(t *testing.T) {
require.NoError(t, err)
// Send with Anthropic-Beta header containing flags that should be filtered.
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, http.Header{
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, http.Header{
"Anthropic-Beta": {"interleaved-thinking-2025-05-14,effort-2025-11-24,context-management-2025-06-27,prompt-caching-scope-2026-01-05"},
})
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
@@ -491,18 +501,20 @@ func TestOpenAIChatCompletions(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
// Make API call to aibridge for OpenAI /v1/chat/completions
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
// Response-specific checks.
@@ -565,25 +577,27 @@ func TestOpenAIChatCompletions(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Setup mock server for multi-turn interaction.
// First request → tool call response, second → tool response.
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix))
// Setup MCP proxies with the tool from the fixture
mockMCP := setupMCPForTest(t, defaultTracer)
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withMCP(mockMCP),
)
// Add the stream param to the request.
reqBody, err := sjson.SetBytes(fix.Request(), "stream", true)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
// Verify SSE headers are sent correctly
@@ -756,18 +770,20 @@ func TestSimple(t *testing.T) {
t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL+tc.basePath)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL+tc.basePath)
// When: calling the "API server" with the fixture's request body.
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}})
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}})
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
// Then: I expect the upstream request to have the correct path.
@@ -861,12 +877,12 @@ func TestSessionIDTracking(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withProvider(config.ProviderAnthropic))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withProvider(config.ProviderAnthropic))
reqBody := fix.Request()
if tc.metadataSessionID != "" {
@@ -875,11 +891,13 @@ func TestSessionIDTracking(t *testing.T) {
require.NoError(t, err)
}
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
// Drain the body to let the stream complete.
_, err := io.ReadAll(resp.Body)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
interceptions := bridgeServer.Recorder.RecordedInterceptions()
@@ -948,10 +966,12 @@ func TestFallthrough(t *testing.T) {
t.Parallel()
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath)
upstream := newMockUpstream(t.Context(), t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL+tc.basePath)
resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
resp, err := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
@@ -984,6 +1004,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
// Build the requirements & make the assertions which are common to all providers.
bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, pathAnthropicMessages, anthropicToolResultValidator(t))
defer resp.Body.Close()
// Ensure expected tool was invoked with expected input.
toolUsages := bridgeServer.Recorder.RecordedToolUsages()
@@ -1067,6 +1088,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
// Build the requirements & make the assertions which are common to all providers.
bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t))
defer resp.Body.Close()
// Ensure expected tool was invoked with expected input.
toolUsages := bridgeServer.Recorder.RecordedToolUsages()
@@ -1234,6 +1256,8 @@ func TestErrorHandling(t *testing.T) {
// Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected.
t.Run("non-stream error", func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
fixture []byte
@@ -1276,21 +1300,23 @@ func TestErrorHandling(t *testing.T) {
t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Setup mock server. Error fixtures contain raw HTTP
// responses that may cause the bridge to retry.
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
// Add the stream param to the request.
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
tc.responseHandlerFn(resp)
bridgeServer.Recorder.VerifyAllInterceptionsEnded(t)
@@ -1347,17 +1373,19 @@ func TestErrorHandling(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Setup mock server.
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
upstream.StatusCode = http.StatusInternalServerError
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()
tc.responseHandlerFn(resp)
bridgeServer.Recorder.VerifyAllInterceptionsEnded(t)
@@ -1394,7 +1422,7 @@ func TestStableRequestEncoding(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Setup MCP tools.
@@ -1408,15 +1436,17 @@ func TestStableRequestEncoding(t *testing.T) {
for i := range count {
responses[i] = newFixtureResponse(fix)
}
upstream := newMockUpstream(t, ctx, responses...)
upstream := newMockUpstream(ctx, t, responses...)
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withMCP(mockMCP),
)
// Make multiple requests and verify they all have identical payloads.
for range count {
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
}
@@ -1657,7 +1687,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Setup MCP tools conditionally.
@@ -1669,9 +1699,9 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
}
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withMCP(mockMCP),
)
@@ -1679,7 +1709,9 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
// Verify tool_choice in the upstream request.
@@ -1819,17 +1851,17 @@ func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) {
t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
var opts []bridgeOption
if tc.withInjectedTools {
opts = append(opts, withMCP(setupMCPForTest(t, defaultTracer)))
}
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...)
var (
reqBody = fix.Request()
@@ -1842,7 +1874,9 @@ func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) {
reqBody, err = sjson.SetBytes(reqBody, "stream", streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
@@ -1872,13 +1906,13 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) {
t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Create a mock server that captures the request body sent upstream.
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
// Inject adaptive thinking into the fixture request.
reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"})
@@ -1886,7 +1920,9 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) {
reqBody, err = sjson.SetBytes(reqBody, "stream", streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
@@ -1935,11 +1971,11 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution.
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
// Set environment variables that the SDK would automatically read.
// These should NOT leak into upstream requests.
@@ -1947,9 +1983,11 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
t.Setenv(key, val)
}
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
// Verify that environment values did not leak.
@@ -2045,14 +2083,14 @@ func TestActorHeaders(t *testing.T) {
t.Run(fmt.Sprintf("%s/streaming=%v/send-headers=%v", tc.name, tc.streaming, send), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
metadataKey := "Username"
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withCustomProvider(tc.createProviderFn(upstream.URL, apiKey, send)),
withActor(defaultActorID, recorder.Metadata{
metadataKey: actorUsername,
@@ -2063,7 +2101,9 @@ func TestActorHeaders(t *testing.T) {
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
// Drain the body so streaming responses complete without
// a "connection reset" error in the mock upstream.
_, err = io.ReadAll(resp.Body)
@@ -1,4 +1,4 @@
package integrationtest
package integrationtest //nolint:testpackage // tests unexported internals
import (
"fmt"
@@ -28,11 +28,11 @@ const (
)
func anthropicSuccessResponse(model string) string {
return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"%s","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model)
return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":%q,"stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model)
}
func openAISuccessResponse(model string) string {
return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model)
return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":%q,"choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model)
}
// TestCircuitBreaker_FullRecoveryCycle tests the complete circuit breaker lifecycle:
@@ -130,31 +130,35 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
}
ctx := t.Context()
bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL,
withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)),
withMetrics(m),
withActor("test-user-id", nil),
)
doRequest := func() *http.Response {
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
_, err := io.ReadAll(resp.Body)
doRequest := func() int {
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
require.NoError(t, err)
return resp
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
return resp.StatusCode
}
// Phase 1: Trip the circuit breaker
// First FailureThreshold requests hit upstream, get 429
for i := uint32(0); i < cbConfig.FailureThreshold; i++ {
resp := doRequest()
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
status := doRequest()
assert.Equal(t, http.StatusTooManyRequests, status)
}
//nolint:gosec // G115: test constant, no overflow risk
assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load())
// Phase 2: Verify circuit is open
// Request should be blocked by circuit breaker (no upstream call)
resp := doRequest()
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
status := doRequest()
assert.Equal(t, http.StatusServiceUnavailable, status)
//nolint:gosec // G115: test constant, no overflow risk
assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open")
// Verify metrics show circuit is open
@@ -175,8 +179,8 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
// Phase 4: Recovery - request in half-open state should succeed and close circuit
upstreamCallsBefore := upstreamCalls.Load()
resp = doRequest()
assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in half-open state")
status = doRequest()
assert.Equal(t, http.StatusOK, status, "Request should succeed in half-open state")
assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state")
// Verify circuit is now closed
@@ -186,8 +190,8 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
// Phase 5: Verify circuit is fully functional again
// Multiple requests should all succeed and reach upstream
for i := 0; i < 3; i++ {
resp = doRequest()
assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed after circuit closes")
status = doRequest()
assert.Equal(t, http.StatusOK, status, "Request should succeed after circuit closes")
}
// All requests should have reached upstream
@@ -283,28 +287,30 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) {
}
ctx := t.Context()
bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL,
withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)),
withMetrics(m),
withActor("test-user-id", nil),
)
doRequest := func() *http.Response {
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
_, err := io.ReadAll(resp.Body)
doRequest := func() int {
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
require.NoError(t, err)
return resp
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
return resp.StatusCode
}
// Phase 1: Trip the circuit
for i := uint32(0); i < cbConfig.FailureThreshold; i++ {
resp := doRequest()
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
status := doRequest()
assert.Equal(t, http.StatusTooManyRequests, status)
}
// Verify circuit is open
resp := doRequest()
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
status := doRequest()
assert.Equal(t, http.StatusServiceUnavailable, status)
trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel))
assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1")
@@ -314,13 +320,13 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) {
// Phase 3: Request in half-open state fails, circuit should re-open
upstreamCallsBefore := upstreamCalls.Load()
resp = doRequest()
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "Request should fail in half-open state")
status = doRequest()
assert.Equal(t, http.StatusTooManyRequests, status, "Request should fail in half-open state")
assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state")
// Circuit should be open again - next request should be rejected immediately
resp = doRequest()
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Circuit should be open again after half-open failure")
status = doRequest()
assert.Equal(t, http.StatusServiceUnavailable, status, "Circuit should be open again after half-open failure")
assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens")
// Verify metrics: trips should be 2 now (tripped twice)
@@ -429,28 +435,30 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
}
ctx := t.Context()
bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL,
withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)),
withMetrics(m),
withActor("test-user-id", nil),
)
doRequest := func() *http.Response {
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
_, err := io.ReadAll(resp.Body)
doRequest := func() int {
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
require.NoError(t, err)
return resp
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
return resp.StatusCode
}
// Phase 1: Trip the circuit
for i := uint32(0); i < cbConfig.FailureThreshold; i++ {
resp := doRequest()
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
status := doRequest()
assert.Equal(t, http.StatusTooManyRequests, status)
}
// Verify circuit is open
resp := doRequest()
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
status := doRequest()
assert.Equal(t, http.StatusServiceUnavailable, status)
// Phase 2: Wait for half-open state and switch upstream to success
time.Sleep(cbConfig.Timeout + 10*time.Millisecond)
@@ -466,8 +474,8 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
resp := doRequest()
responses <- resp.StatusCode
status := doRequest()
responses <- status
}()
}
@@ -544,7 +552,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
MaxRequests: 1,
}
ctx := t.Context()
bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL,
withCustomProvider(provider.NewAnthropic(config.Anthropic{
BaseURL: mockUpstream.URL,
Key: "test-key",
@@ -554,27 +562,31 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
withActor("test-user-id", nil),
)
doRequest := func(model string) *http.Response {
body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, []byte(body), http.Header{
doRequest := func(model string) int {
body := fmt.Sprintf(`{"model":%q,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, []byte(body), http.Header{
"x-api-key": {"test"},
"anthropic-version": {"2023-06-01"},
})
_, err := io.ReadAll(resp.Body)
require.NoError(t, err)
return resp
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
return resp.StatusCode
}
// Phase 1: Trip the circuit for sonnet model
for i := uint32(0); i < cbConfig.FailureThreshold; i++ {
resp := doRequest("claude-sonnet-4-20250514")
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
status := doRequest("claude-sonnet-4-20250514")
assert.Equal(t, http.StatusTooManyRequests, status)
}
//nolint:gosec // G115: test constant, no overflow risk
assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load())
// Verify sonnet circuit is open
resp := doRequest("claude-sonnet-4-20250514")
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Sonnet circuit should be open")
status := doRequest("claude-sonnet-4-20250514")
assert.Equal(t, http.StatusServiceUnavailable, status, "Sonnet circuit should be open")
//nolint:gosec // G115: test constant, no overflow risk
assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load(), "No new sonnet calls when circuit is open")
// Verify sonnet metrics show circuit is open
@@ -585,14 +597,14 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
assert.Equal(t, 1.0, sonnetState, "Sonnet CircuitBreakerState should be 1 (open)")
// Phase 2: Haiku model should still work (independent circuit)
resp = doRequest("claude-3-5-haiku-20241022")
assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should succeed while sonnet circuit is open")
status = doRequest("claude-3-5-haiku-20241022")
assert.Equal(t, http.StatusOK, status, "Haiku should succeed while sonnet circuit is open")
assert.Equal(t, int32(1), haikuCalls.Load(), "Haiku call should reach upstream")
// Make multiple haiku requests - all should succeed
for i := 0; i < 3; i++ {
resp = doRequest("claude-3-5-haiku-20241022")
assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should continue to succeed")
status = doRequest("claude-3-5-haiku-20241022")
assert.Equal(t, http.StatusOK, status, "Haiku should continue to succeed")
}
assert.Equal(t, int32(4), haikuCalls.Load(), "All haiku calls should reach upstream")
@@ -607,8 +619,8 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
time.Sleep(cbConfig.Timeout + 10*time.Millisecond)
sonnetShouldFail.Store(false)
resp = doRequest("claude-sonnet-4-20250514")
assert.Equal(t, http.StatusOK, resp.StatusCode, "Sonnet should recover after timeout")
status = doRequest("claude-sonnet-4-20250514")
assert.Equal(t, http.StatusOK, status, "Sonnet should recover after timeout")
// Verify sonnet circuit is now closed
sonnetState = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514"))
@@ -1,4 +1,4 @@
package integrationtest
package integrationtest //nolint:testpackage // tests unexported internals
import (
"bytes"
@@ -7,7 +7,6 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
promtest "github.com/prometheus/client_golang/prometheus/testutil"
@@ -17,6 +16,7 @@ import (
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/fixtures"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/metrics"
)
@@ -104,7 +104,7 @@ func TestMetrics_Interception(t *testing.T) {
},
{
name: "oai_responses_blocking_error",
fixture: fixtures.OaiResponsesBlockingHttpErr,
fixture: fixtures.OaiResponsesBlockingHTTPErr,
path: pathOpenAIResponses,
headers: http.Header{"User-Agent": []string{"codex/1.0.0"}},
expectStatus: metrics.InterceptionCountStatusFailed,
@@ -127,7 +127,7 @@ func TestMetrics_Interception(t *testing.T) {
},
{
name: "oai_responses_streaming_error",
fixture: fixtures.OaiResponsesStreamingHttpErr,
fixture: fixtures.OaiResponsesStreamingHTTPErr,
path: pathOpenAIResponses,
headers: http.Header{"Originator": []string{"roo-code"}},
expectStatus: metrics.InterceptionCountStatusFailed,
@@ -143,20 +143,22 @@ func TestMetrics_Interception(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
upstream.AllowOverflow = tc.allowOverflow
m := aibridge.NewMetrics(prometheus.NewRegistry())
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withMetrics(m),
)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
_, err := io.ReadAll(resp.Body)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
require.NoError(t, err)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
count := promtest.ToFloat64(m.InterceptionCount.WithLabelValues(
@@ -173,7 +175,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) {
fix := fixtures.Parse(t, fixtures.AntSimple)
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
blockCh := make(chan struct{})
@@ -185,7 +187,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) {
t.Cleanup(srv.Close)
m := aibridge.NewMetrics(prometheus.NewRegistry())
bridgeServer := newBridgeTestServer(t, ctx, srv.URL,
bridgeServer := newBridgeTestServer(ctx, t, srv.URL,
withMetrics(m),
)
@@ -208,7 +210,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) {
return promtest.ToFloat64(
m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"),
) == 1
}, time.Second*10, time.Millisecond*50)
}, testutil.WaitMedium, testutil.IntervalFast)
// Unblock request, await completion.
close(blockCh)
@@ -223,7 +225,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) {
return promtest.ToFloat64(
m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"),
) == 0
}, time.Second*10, time.Millisecond*50)
}, testutil.WaitMedium, testutil.IntervalFast)
}
func TestMetrics_PassthroughCount(t *testing.T) {
@@ -233,11 +235,13 @@ func TestMetrics_PassthroughCount(t *testing.T) {
t.Cleanup(upstream.Close)
m := aibridge.NewMetrics(prometheus.NewRegistry())
bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL,
bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL,
withMetrics(m),
)
resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil)
resp, err := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
count := promtest.ToFloat64(m.PassthroughCount.WithLabelValues(
@@ -248,20 +252,22 @@ func TestMetrics_PassthroughCount(t *testing.T) {
func TestMetrics_PromptCount(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, fixtures.OaiChatSimple)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
m := aibridge.NewMetrics(prometheus.NewRegistry())
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withMetrics(m),
)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request(), http.Header{"User-Agent": []string{"claude-code/1.0.0"}})
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request(), http.Header{"User-Agent": []string{"claude-code/1.0.0"}})
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err := io.ReadAll(resp.Body)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
prompts := promtest.ToFloat64(m.PromptCount.WithLabelValues(
@@ -336,14 +342,14 @@ func TestMetrics_TokenUseCount(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
m := aibridge.NewMetrics(prometheus.NewRegistry())
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withMetrics(m),
)
@@ -353,7 +359,9 @@ func TestMetrics_TokenUseCount(t *testing.T) {
reqBody, err = sjson.SetBytes(reqBody, "stream", true)
require.NoError(t, err)
}
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.reqPath, reqBody, nil)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.reqPath, reqBody, nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, _ = io.ReadAll(resp.Body)
@@ -361,7 +369,7 @@ func TestMetrics_TokenUseCount(t *testing.T) {
require.Eventually(t, func() bool {
return promtest.ToFloat64(m.TokenUseCount.WithLabelValues(
tc.expectProvider, tc.expectModel, "input", defaultActorID, string(aibridge.ClientUnknown))) > 0
}, time.Second*10, time.Millisecond*50)
}, testutil.WaitMedium, testutil.IntervalFast)
for label, expected := range tc.expectedLabels {
require.Equal(t, expected, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(
@@ -375,20 +383,22 @@ func TestMetrics_TokenUseCount(t *testing.T) {
func TestMetrics_NonInjectedToolUseCount(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
m := aibridge.NewMetrics(prometheus.NewRegistry())
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withMetrics(m),
)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err := io.ReadAll(resp.Body)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
count := promtest.ToFloat64(m.NonInjectedToolUseCount.WithLabelValues(
@@ -399,32 +409,34 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) {
func TestMetrics_InjectedToolUseCount(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// First request returns the tool invocation, the second returns the mocked response to the tool result.
fix := fixtures.Parse(t, fixtures.AntSingleInjectedTool)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix))
m := aibridge.NewMetrics(prometheus.NewRegistry())
// Setup mocked MCP server & tools.
mockMCP := setupMCPForTest(t, defaultTracer)
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withMetrics(m),
withMCP(mockMCP),
)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err := io.ReadAll(resp.Body)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// Wait until full roundtrip has completed.
require.Eventually(t, func() bool {
return upstream.Calls.Load() == 2
}, time.Second*10, time.Millisecond*50)
}, testutil.WaitMedium, testutil.IntervalFast)
recorder := bridgeServer.Recorder
require.Len(t, recorder.ToolUsages(), 1)
+4 -4
View File
@@ -7,7 +7,6 @@ import (
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/mark3labs/mcp-go/client/transport"
mcplib "github.com/mark3labs/mcp-go/mcp"
@@ -19,6 +18,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/mcp"
)
@@ -68,12 +68,12 @@ func setupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *mo
mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{proxy.Name(): proxy}, tracer)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
require.NoError(t, mgr.Shutdown(ctx))
})
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
require.NoError(t, mgr.Init(ctx))
require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init")
@@ -141,7 +141,7 @@ func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
s.AddTool(tool, func(_ context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
acc.addCall(request.Params.Name, request.Params.Arguments)
if errMsg, ok := acc.getToolError(request.Params.Name); ok {
return nil, xerrors.New(errMsg)
@@ -111,9 +111,9 @@ func (ms *mockUpstream) receivedRequests() []receivedRequest {
// The test fails if the number of requests doesn't match the number of
// responses (when AllowOverflow is not set, default).
//
// srv := newMockUpstream(t, ctx, newFixtureResponse(fix)) // simple
// srv := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) // multi-turn
func newMockUpstream(t *testing.T, ctx context.Context, responses ...upstreamResponse) *mockUpstream {
// srv := newMockUpstream(ctx, t, newFixtureResponse(fix)) // simple
// srv := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix)) // multi-turn
func newMockUpstream(ctx context.Context, t *testing.T, responses ...upstreamResponse) *mockUpstream {
t.Helper()
require.NotEmpty(t, responses, "at least one upstreamResponse required")
@@ -1,4 +1,4 @@
package integrationtest
package integrationtest //nolint:testpackage // tests unexported internals
import (
"context"
@@ -22,6 +22,7 @@ import (
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/fixtures"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/provider"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/utils"
@@ -335,15 +336,17 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request(), http.Header{"User-Agent": {tc.userAgent}})
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request(), http.Header{"User-Agent": {tc.userAgent}})
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
got, err := io.ReadAll(resp.Body)
@@ -416,7 +419,7 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// request with Background mode should be rejected before it reaches upstream
@@ -426,11 +429,13 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) {
}))
t.Cleanup(upstream.Close)
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
// Create a request with background mode enabled
reqBytes := responsesRequestBytes(t, tc.streaming, keyVal{"background", true})
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
require.Equal(t, http.StatusNotImplemented, resp.StatusCode)
@@ -547,17 +552,17 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) {
t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture[i])
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
var opts []bridgeOption
if tc.withInjectedTools {
opts = append(opts, withMCP(setupMCPForTest(t, defaultTracer)))
}
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...)
var (
reqBody = fix.Request()
@@ -568,7 +573,9 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) {
require.NoError(t, err)
}
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
@@ -631,14 +638,16 @@ func TestClientAndConnectionError(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// tc.addr may be an intentionally invalid URL; use withCustomProvider.
bridgeServer := newBridgeTestServer(t, ctx, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey))))
bridgeServer := newBridgeTestServer(ctx, t, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey))))
reqBytes := responsesRequestBytes(t, tc.streaming)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
@@ -701,7 +710,7 @@ func TestUpstreamError(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -712,10 +721,12 @@ func TestUpstreamError(t *testing.T) {
}))
t.Cleanup(upstream.Close)
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
reqBytes := responsesRequestBytes(t, tc.streaming)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, tc.statusCode, resp.StatusCode)
require.Equal(t, tc.contentType, resp.Header.Get("Content-Type"))
@@ -880,13 +891,13 @@ func TestResponsesInjectedTool(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Setup mock server for multi-turn interaction.
// First request → tool call response, second → tool response.
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix))
// Setup MCP server proxies (with mock tools).
mockMCP := setupMCPForTest(t, defaultTracer)
@@ -894,9 +905,11 @@ func TestResponsesInjectedTool(t *testing.T) {
mockMCP.setToolError(tc.mcpToolName, tc.expectToolError)
}
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP))
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMCP(mockMCP))
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
@@ -905,7 +918,7 @@ func TestResponsesInjectedTool(t *testing.T) {
// Wait for both requests to be made (inner agentic loop).
require.Eventually(t, func() bool {
return upstream.Calls.Load() == 2
}, time.Second*10, time.Millisecond*50)
}, testutil.WaitMedium, testutil.IntervalFast)
// Verify the injected tool was invoked via MCP.
invocations := mockMCP.getCallsByTool(tc.mcpToolName)
@@ -1025,18 +1038,20 @@ func TestResponsesModelThoughts(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request())
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request())
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err := io.ReadAll(resp.Body)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
bridgeServer.Recorder.VerifyModelThoughtsRecorded(t, tc.expectedThoughts)
@@ -7,7 +7,6 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"
@@ -63,11 +62,13 @@ type bridgeTestServer struct {
// makeRequest builds and executes an HTTP request against this server.
// Optional headers are applied after the default Content-Type.
func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, body []byte, header ...http.Header) *http.Response {
func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, body []byte, header ...http.Header) (*http.Response, error) {
t.Helper()
req, err := http.NewRequestWithContext(t.Context(), method, s.URL+path, bytes.NewReader(body))
require.NoError(t, err)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
for _, h := range header {
for k, vals := range h {
@@ -76,10 +77,7 @@ func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string,
}
}
}
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
t.Cleanup(func() { _ = resp.Body.Close() })
return resp
return http.DefaultClient.Do(req)
}
type bridgeOption func(*bridgeConfig)
@@ -133,8 +131,8 @@ func withActor(id string, md recorder.Metadata) bridgeOption {
// - defaultTracer (unless withTracer)
// - defaultActorID (unless withActor)
func newBridgeTestServer(
t *testing.T,
ctx context.Context,
t *testing.T,
upstreamURL string,
opts ...bridgeOption,
) *bridgeTestServer {
@@ -209,7 +207,7 @@ func setupInjectedToolTest(
) (*bridgeTestServer, *mockMCP, *http.Response) {
t.Helper()
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
fix := fixtures.Parse(t, fixture)
@@ -220,7 +218,7 @@ func setupInjectedToolTest(
firstResp := newFixtureResponse(fix)
toolResp := newFixtureToolResponse(fix)
toolResp.OnRequest = toolRequestValidatorFn
upstream := newMockUpstream(t, ctx, firstResp, toolResp)
upstream := newMockUpstream(ctx, t, firstResp, toolResp)
mockMCP := setupMCPForTest(t, tracer)
@@ -230,19 +228,20 @@ func setupInjectedToolTest(
withActor(defaultActorID, nil),
}
allOpts = append(allOpts, opts...)
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, allOpts...)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, allOpts...)
// Add the stream param to the request.
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
// Wait both requests (initial + tool call result)
require.Eventually(t, func() bool {
return upstream.Calls.Load() == 2
}, time.Second*10, time.Millisecond*50)
}, testutil.WaitMedium, testutil.IntervalFast)
return bridgeServer, mockMCP, resp
}
+53 -24
View File
@@ -1,4 +1,4 @@
package integrationtest
package integrationtest //nolint:testpackage // tests unexported internals
import (
"context"
@@ -6,7 +6,6 @@ import (
"slices"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -20,6 +19,7 @@ import (
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/fixtures"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/tracing"
)
@@ -43,6 +43,8 @@ func setupTracer(t *testing.T) (*tracetest.SpanRecorder, oteltrace.Tracer) {
}
func TestTraceAnthropic(t *testing.T) {
t.Parallel()
expectNonStreaming := []expectTrace{
{"Intercept", 1, codes.Unset},
{"Intercept.CreateInterceptor", 1, codes.Unset},
@@ -137,13 +139,15 @@ func TestTraceAnthropic(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
sr, tracer := setupTracer(t)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
opts := []bridgeOption{
withTracer(tracer),
@@ -151,11 +155,13 @@ func TestTraceAnthropic(t *testing.T) {
if tc.bedrock {
opts = append(opts, withProvider(providerBedrock))
}
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...)
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
bridgeServer.Close()
@@ -189,6 +195,8 @@ func TestTraceAnthropic(t *testing.T) {
}
func TestTraceAnthropicErr(t *testing.T) {
t.Parallel()
expectNonStream := []expectTrace{
{"Intercept", 1, codes.Error},
{"Intercept.CreateInterceptor", 1, codes.Unset},
@@ -247,13 +255,15 @@ func TestTraceAnthropicErr(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
sr, tracer := setupTracer(t)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
opts := []bridgeOption{
withTracer(tracer),
@@ -261,11 +271,13 @@ func TestTraceAnthropicErr(t *testing.T) {
if tc.bedrock {
opts = append(opts, withProvider(providerBedrock))
}
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...)
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...)
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
if tc.streaming {
require.Equal(t, http.StatusOK, resp.StatusCode)
} else {
@@ -385,10 +397,11 @@ func TestInjectedToolsTrace(t *testing.T) {
validatorFn = openaiChatToolResultValidator(t)
}
bridgeServer, mockMCP, _ := setupInjectedToolTest(
bridgeServer, mockMCP, resp := setupInjectedToolTest(
t, tc.fixture, tc.streaming, tracer,
tc.path, validatorFn, tc.opts...,
)
defer resp.Body.Close()
require.Len(t, bridgeServer.Recorder.RecordedInterceptions(), 1)
intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID
@@ -417,6 +430,8 @@ func TestInjectedToolsTrace(t *testing.T) {
}
func TestTraceOpenAI(t *testing.T) {
t.Parallel()
cases := []struct {
name string
fixture []byte
@@ -529,20 +544,24 @@ func TestTraceOpenAI(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
sr, tracer := setupTracer(t)
fix := fixtures.Parse(t, tc.fixture)
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
withTracer(tracer),
)
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
bridgeServer.Close()
@@ -569,6 +588,8 @@ func TestTraceOpenAI(t *testing.T) {
}
func TestTraceOpenAIErr(t *testing.T) {
t.Parallel()
cases := []struct {
name string
fixture []byte
@@ -647,7 +668,7 @@ func TestTraceOpenAIErr(t *testing.T) {
},
{
name: "trace_openai_responses_streaming_http_error",
fixture: fixtures.OaiResponsesStreamingHttpErr,
fixture: fixtures.OaiResponsesStreamingHTTPErr,
streaming: true,
allowOverflow: true, // 429 error causes retries
@@ -664,7 +685,7 @@ func TestTraceOpenAIErr(t *testing.T) {
},
{
name: "trace_openai_responses_blocking_http_error",
fixture: fixtures.OaiResponsesBlockingHttpErr,
fixture: fixtures.OaiResponsesBlockingHTTPErr,
streaming: false,
path: pathOpenAIResponses,
@@ -682,22 +703,26 @@ func TestTraceOpenAIErr(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Parallel()
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
sr, tracer := setupTracer(t)
fix := fixtures.Parse(t, tc.fixture)
mockAPI := newMockUpstream(t, ctx, newFixtureResponse(fix))
mockAPI := newMockUpstream(ctx, t, newFixtureResponse(fix))
mockAPI.AllowOverflow = tc.allowOverflow
bridgeServer := newBridgeTestServer(t, ctx, mockAPI.URL,
bridgeServer := newBridgeTestServer(ctx, t, mockAPI.URL,
withTracer(tracer),
)
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
require.NoError(t, err)
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, tc.expectCode, resp.StatusCode)
bridgeServer.Close()
@@ -729,15 +754,17 @@ func TestTracePassthrough(t *testing.T) {
fix := fixtures.Parse(t, fixtures.OaiChatFallthrough)
upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix))
upstream := newMockUpstream(t.Context(), t, newFixtureResponse(fix))
sr, tracer := setupTracer(t)
bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL,
bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL,
withTracer(tracer),
)
resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil)
resp, err := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
bridgeServer.Close()
@@ -755,6 +782,8 @@ func TestTracePassthrough(t *testing.T) {
}
func TestNewServerProxyManagerTraces(t *testing.T) {
t.Parallel()
sr, tracer := setupTracer(t)
serverName := "serverName"
+6 -6
View File
@@ -26,14 +26,14 @@ type MockRecorder struct {
interceptionsEnd map[string]*recorder.InterceptionRecordEnded
}
func (m *MockRecorder) RecordInterception(ctx context.Context, req *recorder.InterceptionRecord) error {
func (m *MockRecorder) RecordInterception(_ context.Context, req *recorder.InterceptionRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.interceptions = append(m.interceptions, req)
return nil
}
func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorder.InterceptionRecordEnded) error {
func (m *MockRecorder) RecordInterceptionEnded(_ context.Context, req *recorder.InterceptionRecordEnded) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.interceptionsEnd == nil {
@@ -46,28 +46,28 @@ func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorde
return nil
}
func (m *MockRecorder) RecordPromptUsage(ctx context.Context, req *recorder.PromptUsageRecord) error {
func (m *MockRecorder) RecordPromptUsage(_ context.Context, req *recorder.PromptUsageRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.userPrompts = append(m.userPrompts, req)
return nil
}
func (m *MockRecorder) RecordTokenUsage(ctx context.Context, req *recorder.TokenUsageRecord) error {
func (m *MockRecorder) RecordTokenUsage(_ context.Context, req *recorder.TokenUsageRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.tokenUsages = append(m.tokenUsages, req)
return nil
}
func (m *MockRecorder) RecordToolUsage(ctx context.Context, req *recorder.ToolUsageRecord) error {
func (m *MockRecorder) RecordToolUsage(_ context.Context, req *recorder.ToolUsageRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.toolUsages = append(m.toolUsages, req)
return nil
}
func (m *MockRecorder) RecordModelThought(ctx context.Context, req *recorder.ModelThoughtRecord) error {
func (m *MockRecorder) RecordModelThought(_ context.Context, req *recorder.ModelThoughtRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
m.modelThoughts = append(m.modelThoughts, req)
+12 -12
View File
@@ -11,26 +11,26 @@ import (
)
type MockProvider struct {
Name_ string
NameStr string
URL string
Bridged []string
Passthrough []string
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
}
func (m *MockProvider) Type() string { return m.Name_ }
func (m *MockProvider) Name() string { return m.Name_ }
func (m *MockProvider) BaseURL() string { return m.URL }
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.Name_) }
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
func (m *MockProvider) AuthHeader() string { return "Authorization" }
func (m *MockProvider) InjectAuthHeader(h *http.Header) {}
func (m *MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
func (m *MockProvider) APIDumpDir() string { return "" }
func (m *MockProvider) Type() string { return m.NameStr }
func (m *MockProvider) Name() string { return m.NameStr }
func (m *MockProvider) BaseURL() string { return m.URL }
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
func (*MockProvider) AuthHeader() string { return "Authorization" }
func (*MockProvider) InjectAuthHeader(_ *http.Header) {}
func (*MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
func (*MockProvider) APIDumpDir() string { return "" }
func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) {
if m.InterceptorFunc != nil {
return m.InterceptorFunc(w, r, tracer)
}
return nil, nil
return nil, nil //nolint:nilnil // mock: no interceptor configured is not an error
}
+21
View File
@@ -0,0 +1,21 @@
package testutil
import "time"
// Shared test timeout and interval constants.
// Using named constants avoids magic numbers and makes timeout policy
// easy to adjust across the entire test suite.
const (
// WaitLong is the default timeout for test operations that may take a while
// (e.g. integration tests with HTTP round-trips).
WaitLong = 30 * time.Second
// WaitMedium is a timeout for moderately slow operations.
WaitMedium = 10 * time.Second
// WaitShort is a timeout for operations expected to complete quickly.
WaitShort = 5 * time.Second
// IntervalFast is a short polling interval for require.Eventually and similar.
IntervalFast = 50 * time.Millisecond
)
+2
View File
@@ -9,6 +9,8 @@ import (
)
func TestGetClientInfo(t *testing.T) {
t.Parallel()
info := mcp.GetClientInfo()
assert.Equal(t, "coder/aibridge", info.Name)
+2 -2
View File
@@ -9,7 +9,6 @@ import (
"slices"
"strings"
"testing"
"time"
mcplib "github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -19,6 +18,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/mcp"
)
@@ -299,7 +299,7 @@ func TestToolInjectionOrder(t *testing.T) {
// Setup.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)
// Given: a MCP mock server offering a set of tools.
+2 -2
View File
@@ -87,7 +87,7 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) (outErr error) {
return xerrors.Errorf("MCP version negotiation failed; requested %q, accepts %q, received %q", version, strings.Join(mcp.ValidProtocolVersions, ","), result.ProtocolVersion)
}
p.logger.Debug(ctx, "MCP client initialized", slog.F("name", result.ServerInfo.Name), slog.F("server_version", result.ServerInfo.Version))
p.logger.Debug(ctx, "mcp client initialized", slog.F("name", result.ServerInfo.Name), slog.F("server_version", result.ServerInfo.Version))
tools, err := p.fetchTools(ctx)
if err != nil {
@@ -161,7 +161,7 @@ func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[strin
return out, nil
}
func (p *StreamableHTTPServerProxy) Shutdown(ctx context.Context) error {
func (p *StreamableHTTPServerProxy) Shutdown(_ context.Context) error {
if p.client == nil {
return nil
}
+12 -11
View File
@@ -59,15 +59,15 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp
ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...))
defer tracing.EndSpanErr(span, &outErr)
inputJson, err := json.Marshal(input)
inputJSON, err := json.Marshal(input)
if err != nil {
t.Logger.Warn(ctx, "failed to marshal tool input, will be omitted from span attrs", slog.Error(err))
} else {
strJson := string(inputJson)
if len(strJson) > maxSpanInputAttrLen {
strJson = strJson[:maxSpanInputAttrLen]
strJSON := string(inputJSON)
if len(strJSON) > maxSpanInputAttrLen {
strJSON = strJSON[:maxSpanInputAttrLen]
}
span.SetAttributes(attribute.String(tracing.MCPInput, strJson))
span.SetAttributes(attribute.String(tracing.MCPInput, strJSON))
}
start := time.Now()
@@ -88,7 +88,7 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp
logFn(ctx, "injected tool invoked",
slog.F("name", t.Name),
slog.F("server", t.ServerName),
slog.F("input", inputJson),
slog.F("input", inputJSON),
slog.F("duration_sec", time.Since(start).Seconds()),
slog.Error(outErr),
)
@@ -106,12 +106,13 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp
// - https://community.openai.com/t/function-call-description-max-length/529902
// - https://github.com/anthropics/claude-code/issues/2326
func EncodeToolID(server, tool string) string {
// strings.Builder writes to in-memory storage and never return errors.
var sb strings.Builder
sb.WriteString(injectedToolPrefix)
sb.WriteString(injectedToolDelimiter)
sb.WriteString(server)
sb.WriteString(injectedToolDelimiter)
sb.WriteString(tool)
_, _ = sb.WriteString(injectedToolPrefix)
_, _ = sb.WriteString(injectedToolDelimiter)
_, _ = sb.WriteString(server)
_, _ = sb.WriteString(injectedToolDelimiter)
_, _ = sb.WriteString(tool)
return sb.String()
}
+1 -1
View File
@@ -1,3 +1,3 @@
package mcpmock
//go:generate mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/coder/v2/aibridge/mcp ServerProxier
//go:generate mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/aibridge/mcp ServerProxier
+2 -2
View File
@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/aibridge/mcp (interfaces: ServerProxier)
// Source: github.com/coder/aibridge/mcp (interfaces: ServerProxier)
//
// Generated by this command:
//
// mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/coder/v2/aibridge/mcp ServerProxier
// mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/aibridge/mcp ServerProxier
//
// Package mcpmock is a generated GoMock package.
+1 -1
View File
@@ -5,7 +5,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto"
)
var baseLabels []string = []string{"provider", "model"}
var baseLabels = []string{"provider", "model"}
const (
InterceptionCountStatusFailed = "failed"
+6 -6
View File
@@ -21,10 +21,10 @@ import (
// newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically
// by a [intercept.Provider].
func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if m != nil {
m.PassthroughCount.WithLabelValues(provider.Name(), r.URL.Path, r.Method).Add(1)
m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1)
}
ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
@@ -33,7 +33,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met
))
defer span.End()
upURL, err := url.Parse(provider.BaseURL())
upURL, err := url.Parse(prov.BaseURL())
if err != nil {
logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err))
http.Error(w, "request error", http.StatusBadGateway)
@@ -44,7 +44,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met
// Append the request path to the upstream base path.
reqPath, err := url.JoinPath(upURL.Path, r.URL.Path)
if err != nil {
logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstreamPath", upURL.Path), slog.F("requestPath", r.URL.Path))
logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstream_path", upURL.Path), slog.F("request_path", r.URL.Path))
http.Error(w, "failed to join upstream path", http.StatusInternalServerError)
span.SetStatus(codes.Error, "failed to join upstream path: "+err.Error())
return
@@ -96,7 +96,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met
}
// Inject provider auth.
provider.InjectAuthHeader(&req.Header)
prov.InjectAuthHeader(&req.Header)
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
@@ -113,7 +113,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
proxy.Transport = apidump.NewPassthroughMiddleware(t, provider.APIDumpDir(), provider.Name(), logger, quartz.NewReal())
proxy.Transport = apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal())
proxy.ServeHTTP(w, r)
}
+1 -1
View File
@@ -1,4 +1,4 @@
package aibridge
package aibridge //nolint:testpackage // tests unexported newPassthroughRouter
import (
"net/http"
+60 -61
View File
@@ -73,7 +73,7 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi
}
}
func (p *Anthropic) Type() string {
func (*Anthropic) Type() string {
return config.ProviderAnthropic
}
@@ -85,11 +85,11 @@ func (p *Anthropic) RoutePrefix() string {
return fmt.Sprintf("/%s", p.Name())
}
func (p *Anthropic) BridgedRoutes() []string {
func (*Anthropic) BridgedRoutes() []string {
return []string{routeMessages}
}
func (p *Anthropic) PassthroughRoutes() []string {
func (*Anthropic) PassthroughRoutes() []string {
return []string{
"/v1/models",
"/v1/models/", // See https://pkg.go.dev/net/http#hdr-Trailing_slash_redirection-ServeMux.
@@ -98,77 +98,76 @@ func (p *Anthropic) PassthroughRoutes() []string {
}
}
func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
id := uuid.New()
_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
defer tracing.EndSpanErr(span, &outErr)
path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix())
switch path {
case routeMessages:
payload, err := io.ReadAll(r.Body)
if err != nil {
return nil, xerrors.Errorf("read body: %w", err)
}
reqPayload, err := messages.NewMessagesRequestPayload(payload)
if err != nil {
return nil, xerrors.Errorf("unmarshal request body: %w", err)
}
cfg := p.cfg
cfg.ExtraHeaders = extractAnthropicHeaders(r)
// At this point the request contains only LLM provider headers.
// Any Coder-specific authentication has already been stripped.
//
// In centralized mode neither Authorization nor X-Api-Key is
// present, so cfg keeps the centralized key unchanged.
//
// In BYOK mode the user's LLM credentials survive intact.
// If X-Api-Key is present the user has a personal API key;
// overwrite the centralized key with it. If Authorization is
// present the user authenticated directly with provider;
// set BYOKBearerToken and clear the centralized key.
// When both are present, X-Api-Key takes priority to match
// claude-code behavior.
credKind := intercept.CredentialKindCentralized
credSecret := cfg.Key
authHeaderName := p.AuthHeader()
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
cfg.Key = apiKey
authHeaderName = "X-Api-Key"
credKind = intercept.CredentialKindBYOK
credSecret = apiKey
} else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" {
cfg.BYOKBearerToken = token
cfg.Key = ""
authHeaderName = "Authorization"
credKind = intercept.CredentialKindBYOK
credSecret = token
}
cred := intercept.NewCredentialInfo(credKind, credSecret)
var interceptor intercept.Interceptor
if reqPayload.Stream() {
interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
} else {
interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
}
span.SetAttributes(interceptor.TraceAttributes(r)...)
return interceptor, nil
if path != routeMessages {
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
return nil, ErrUnknownRoute
}
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
return nil, UnknownRoute
payload, err := io.ReadAll(r.Body)
if err != nil {
return nil, xerrors.Errorf("read body: %w", err)
}
reqPayload, err := messages.NewRequestPayload(payload)
if err != nil {
return nil, xerrors.Errorf("unmarshal request body: %w", err)
}
cfg := p.cfg
cfg.ExtraHeaders = extractAnthropicHeaders(r)
// At this point the request contains only LLM provider headers.
// Any Coder-specific authentication has already been stripped.
//
// In centralized mode neither Authorization nor X-Api-Key is
// present, so cfg keeps the centralized key unchanged.
//
// In BYOK mode the user's LLM credentials survive intact.
// If X-Api-Key is present the user has a personal API key;
// overwrite the centralized key with it. If Authorization is
// present the user authenticated directly with provider;
// set BYOKBearerToken and clear the centralized key.
// When both are present, X-Api-Key takes priority to match
// claude-code behavior.
credKind := intercept.CredentialKindCentralized
credSecret := cfg.Key
authHeaderName := p.AuthHeader()
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
cfg.Key = apiKey
authHeaderName = "X-Api-Key"
credKind = intercept.CredentialKindBYOK
credSecret = apiKey
} else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" {
cfg.BYOKBearerToken = token
cfg.Key = ""
authHeaderName = "Authorization"
credKind = intercept.CredentialKindBYOK
credSecret = token
}
cred := intercept.NewCredentialInfo(credKind, credSecret)
var interceptor intercept.Interceptor
if reqPayload.Stream() {
interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
} else {
interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
}
span.SetAttributes(interceptor.TraceAttributes(r)...)
return interceptor, nil
}
func (p *Anthropic) BaseURL() string {
return p.cfg.BaseURL
}
func (p *Anthropic) AuthHeader() string {
func (*Anthropic) AuthHeader() string {
return "X-Api-Key"
}
+3 -3
View File
@@ -1,4 +1,4 @@
package provider
package provider //nolint:testpackage // tests unexported internals
import (
"bytes"
@@ -146,7 +146,7 @@ func TestAnthropic_CreateInterceptor(t *testing.T) {
assert.Empty(t, receivedHeaders.Get("Authorization"), "client Authorization header must not reach upstream")
})
t.Run("UnknownRoute", func(t *testing.T) {
t.Run("ErrUnknownRoute", func(t *testing.T) {
t.Parallel()
body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}]}`
@@ -155,7 +155,7 @@ func TestAnthropic_CreateInterceptor(t *testing.T) {
interceptor, err := provider.CreateInterceptor(w, req, testTracer)
require.ErrorIs(t, err, UnknownRoute)
require.ErrorIs(t, err, ErrUnknownRoute)
require.Nil(t, interceptor)
})
}
+7 -7
View File
@@ -72,7 +72,7 @@ func NewCopilot(cfg config.Copilot) *Copilot {
}
}
func (p *Copilot) Type() string {
func (*Copilot) Type() string {
return config.ProviderCopilot
}
@@ -88,14 +88,14 @@ func (p *Copilot) RoutePrefix() string {
return fmt.Sprintf("/%s", p.Name())
}
func (p *Copilot) BridgedRoutes() []string {
func (*Copilot) BridgedRoutes() []string {
return []string{
routeCopilotChatCompletions,
routeCopilotResponses,
}
}
func (p *Copilot) PassthroughRoutes() []string {
func (*Copilot) PassthroughRoutes() []string {
return []string{
"/models",
"/models/",
@@ -105,7 +105,7 @@ func (p *Copilot) PassthroughRoutes() []string {
}
}
func (p *Copilot) AuthHeader() string {
func (*Copilot) AuthHeader() string {
return "Authorization"
}
@@ -113,7 +113,7 @@ func (p *Copilot) AuthHeader() string {
// Copilot uses per-user tokens passed in the original Authorization header,
// rather than a global key configured at the provider level.
// The original Authorization header flows through untouched from the client.
func (p *Copilot) InjectAuthHeader(_ *http.Header) {}
func (*Copilot) InjectAuthHeader(_ *http.Header) {}
func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker {
return p.circuitBreaker
@@ -170,7 +170,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
if err != nil {
return nil, xerrors.Errorf("read body: %w", err)
}
reqPayload, err := responses.NewResponsesRequestPayload(payload)
reqPayload, err := responses.NewRequestPayload(payload)
if err != nil {
return nil, xerrors.Errorf("unmarshal request body: %w", err)
}
@@ -183,7 +183,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
default:
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
return nil, UnknownRoute
return nil, ErrUnknownRoute
}
span.SetAttributes(interceptor.TraceAttributes(r)...)
+3 -3
View File
@@ -1,4 +1,4 @@
package provider
package provider //nolint:testpackage // tests unexported internals
import (
"bytes"
@@ -300,7 +300,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) {
assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream")
})
t.Run("UnknownRoute", func(t *testing.T) {
t.Run("ErrUnknownRoute", func(t *testing.T) {
t.Parallel()
body := `{"model": "gpt-4.1", "messages": [{"role": "user", "content": "hello"}]}`
@@ -310,7 +310,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) {
interceptor, err := provider.CreateInterceptor(w, req, testTracer)
require.ErrorIs(t, err, UnknownRoute)
require.ErrorIs(t, err, ErrUnknownRoute)
require.Nil(t, interceptor)
})
}
+7 -7
View File
@@ -61,7 +61,7 @@ func NewOpenAI(cfg config.OpenAI) *OpenAI {
}
}
func (p *OpenAI) Type() string {
func (*OpenAI) Type() string {
return config.ProviderOpenAI
}
@@ -75,7 +75,7 @@ func (p *OpenAI) RoutePrefix() string {
return fmt.Sprintf("/%s/v1", p.Name())
}
func (p *OpenAI) BridgedRoutes() []string {
func (*OpenAI) BridgedRoutes() []string {
return []string{
routeChatCompletions,
routeResponses,
@@ -86,7 +86,7 @@ func (p *OpenAI) BridgedRoutes() []string {
// but must be passed through to the upstream.
// The /v1/completions legacy API is deprecated and will not be passed through.
// See https://platform.openai.com/docs/api-reference/completions.
func (p *OpenAI) PassthroughRoutes() []string {
func (*OpenAI) PassthroughRoutes() []string {
return []string{
// See https://pkg.go.dev/net/http#hdr-Trailing_slash_redirection-ServeMux.
// but without non trailing slash route requests to `/v1/conversations` are going to catch all
@@ -98,7 +98,7 @@ func (p *OpenAI) PassthroughRoutes() []string {
}
}
func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
func (p *OpenAI) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
id := uuid.New()
_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
@@ -141,7 +141,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
if err != nil {
return nil, xerrors.Errorf("read body: %w", err)
}
reqPayload, err := responses.NewResponsesRequestPayload(payload)
reqPayload, err := responses.NewRequestPayload(payload)
if err != nil {
return nil, xerrors.Errorf("unmarshal request body: %w", err)
}
@@ -153,7 +153,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
default:
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
return nil, UnknownRoute
return nil, ErrUnknownRoute
}
span.SetAttributes(interceptor.TraceAttributes(r)...)
return interceptor, nil
@@ -163,7 +163,7 @@ func (p *OpenAI) BaseURL() string {
return p.cfg.BaseURL
}
func (p *OpenAI) AuthHeader() string {
func (*OpenAI) AuthHeader() string {
return "Authorization"
}
+1 -1
View File
@@ -1,4 +1,4 @@
package provider
package provider //nolint:testpackage // tests unexported internals
import (
"bytes"
+1 -1
View File
@@ -10,7 +10,7 @@ import (
"github.com/coder/coder/v2/aibridge/intercept"
)
var UnknownRoute = xerrors.New("unknown route")
var ErrUnknownRoute = xerrors.New("unknown route")
// Provider defines routes (bridged and passed through) for given provider.
// Bridged routes are processed by dedicated interceptors.
+12 -12
View File
@@ -14,19 +14,19 @@ import (
)
var (
_ Recorder = &RecorderWrapper{}
_ Recorder = &WrappedRecorder{}
_ Recorder = &AsyncRecorder{}
)
// RecorderWrapper is a convenience struct which implements RecorderClient and resolves a client before calling each method.
// WrappedRecorder is a convenience struct which implements RecorderClient and resolves a client before calling each method.
// It also sets the start/creation time of each record.
type RecorderWrapper struct {
type WrappedRecorder struct {
logger slog.Logger
tracer trace.Tracer
clientFn func() (Recorder, error)
}
func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) {
func (r *WrappedRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) {
ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer tracing.EndSpanErr(span, &outErr)
@@ -44,7 +44,7 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept
return err
}
func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) {
func (r *WrappedRecorder) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) {
ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer tracing.EndSpanErr(span, &outErr)
@@ -62,7 +62,7 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte
return err
}
func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) {
func (r *WrappedRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) {
ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer tracing.EndSpanErr(span, &outErr)
@@ -80,7 +80,7 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag
return err
}
func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) {
func (r *WrappedRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) {
ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer tracing.EndSpanErr(span, &outErr)
@@ -98,7 +98,7 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR
return err
}
func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) {
func (r *WrappedRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) {
ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer tracing.EndSpanErr(span, &outErr)
@@ -116,7 +116,7 @@ func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRec
return err
}
func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) (outErr error) {
func (r *WrappedRecorder) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) (outErr error) {
ctx, span := r.tracer.Start(ctx, "Intercept.RecordModelThought", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer tracing.EndSpanErr(span, &outErr)
@@ -134,8 +134,8 @@ func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThou
return err
}
func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *RecorderWrapper {
return &RecorderWrapper{
func NewWrappedRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *WrappedRecorder {
return &WrappedRecorder{
logger: logger,
tracer: tracer,
clientFn: clientFn,
@@ -185,7 +185,7 @@ func (a *AsyncRecorder) WithClient(client string) {
// RecordInterception must NOT be called asynchronously.
// If an interception cannot be recorded, the whole request should fail.
func (a *AsyncRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) error {
func (*AsyncRecorder) RecordInterception(context.Context, *InterceptionRecord) error {
panic("RecordInterception must not be called asynchronously")
}
+2 -2
View File
@@ -14,10 +14,10 @@ import (
var claudeCodePattern = regexp.MustCompile(`_session_(.+)$`) // Legacy format: save compilation on each call.
// guessSessionID attempts to retrieve a session ID which may have been sent by
// GuessSessionID attempts to retrieve a session ID which may have been sent by
// the client. We only attempt to retrieve sessions using methods recognized for
// the given client.
func guessSessionID(client Client, r *http.Request) *string {
func GuessSessionID(client Client, r *http.Request) *string {
switch client {
case ClientClaudeCode:
// Prefer the dedicated header (added in Claude Code v2.1.86+).
+39 -38
View File
@@ -1,4 +1,4 @@
package aibridge
package aibridge_test
import (
"io"
@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/utils"
)
@@ -16,7 +17,7 @@ func TestGuessSessionID(t *testing.T) {
cases := []struct {
name string
client Client
client aibridge.Client
body string
headers map[string]string
sessionID *string
@@ -24,177 +25,177 @@ func TestGuessSessionID(t *testing.T) {
// Claude Code.
{
name: "claude_code_header_takes_precedence",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
headers: map[string]string{"X-Claude-Code-Session-Id": "header-session-id"},
body: `{"metadata":{"user_id":"user_abc123_account_456_session_body-session-id"}}`,
sessionID: utils.PtrTo("header-session-id"),
},
{
name: "claude_code_header_only",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
headers: map[string]string{"X-Claude-Code-Session-Id": "aabb-ccdd"},
body: `{"model":"claude-3"}`,
sessionID: utils.PtrTo("aabb-ccdd"),
},
{
name: "claude_code_empty_header_falls_back_to_body",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
headers: map[string]string{"X-Claude-Code-Session-Id": ""},
body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`,
sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"),
},
{
name: "claude_code_whitespace_header_falls_back_to_body",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
headers: map[string]string{"X-Claude-Code-Session-Id": " "},
body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`,
sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"),
},
{
name: "claude_code_with_valid_session",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`,
sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"),
},
{
name: "claude_code_with_valid_session_new_format",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
body: `{"metadata":{"user_id":"{\"device_id\":\"45aa15c8c244ea2582f8144dde91a50ec3815851f6f648abef4ee15b173cc927\",\"account_uuid\":\"\",\"session_id\":\"54c1eb09-bc4c-4d2f-98eb-6d2ab2d5e2fe\"}"}}`,
sessionID: utils.PtrTo("54c1eb09-bc4c-4d2f-98eb-6d2ab2d5e2fe"),
},
{
name: "claude_code_new_format_empty_session_id",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
body: `{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"\"}"}}`,
},
{
name: "claude_code_new_format_no_session_id_field",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
body: `{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\"}"}}`,
},
{
name: "claude_code_missing_metadata",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
body: `{"model":"claude-3"}`,
},
{
name: "claude_code_missing_user_id",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
body: `{"metadata":{}}`,
},
{
name: "claude_code_user_id_without_session",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
body: `{"metadata":{"user_id":"user_abc123_account_456"}}`,
},
{
name: "claude_code_empty_body",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
body: ``,
},
{
name: "claude_code_invalid_json",
client: ClientClaudeCode,
client: aibridge.ClientClaudeCode,
body: `not json at all`,
},
// Codex.
{
name: "codex_with_session_header",
client: ClientCodex,
client: aibridge.ClientCodex,
headers: map[string]string{"session_id": "codex-session-123"},
sessionID: utils.PtrTo("codex-session-123"),
},
{
name: "codex_with_whitespace_in_header",
client: ClientCodex,
client: aibridge.ClientCodex,
headers: map[string]string{"session_id": " codex-session-123 "},
sessionID: utils.PtrTo("codex-session-123"),
},
{
name: "codex_without_session_header",
client: ClientCodex,
client: aibridge.ClientCodex,
},
// Other clients shouldn't use others' logic.
{
name: "unknown_client_returns_empty",
client: ClientUnknown,
client: aibridge.ClientUnknown,
body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`,
},
{
name: "zed_returns_empty",
client: ClientZed,
client: aibridge.ClientZed,
headers: map[string]string{"session_id": "zed-session"},
body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`,
},
// Mux.
{
name: "mux_with_workspace_header",
client: ClientMux,
client: aibridge.ClientMux,
headers: map[string]string{"X-Mux-Workspace-Id": "ws-abc-123"},
sessionID: utils.PtrTo("ws-abc-123"),
},
{
name: "mux_without_workspace_header",
client: ClientMux,
client: aibridge.ClientMux,
},
// Copilot VS Code.
{
name: "copilot_vsc_with_interaction_id",
client: ClientCopilotVSC,
client: aibridge.ClientCopilotVSC,
headers: map[string]string{"x-interaction-id": "interaction-xyz"},
sessionID: utils.PtrTo("interaction-xyz"),
},
{
name: "copilot_vsc_without_interaction_id",
client: ClientCopilotVSC,
client: aibridge.ClientCopilotVSC,
},
// Copilot CLI.
{
name: "copilot_cli_with_session_header",
client: ClientCopilotCLI,
client: aibridge.ClientCopilotCLI,
headers: map[string]string{"X-Client-Session-Id": "cli-sess-456"},
sessionID: utils.PtrTo("cli-sess-456"),
},
{
name: "copilot_cli_without_session_header",
client: ClientCopilotCLI,
client: aibridge.ClientCopilotCLI,
},
// Kilo.
{
name: "kilo_with_task_id",
client: ClientKilo,
client: aibridge.ClientKilo,
headers: map[string]string{"X-KILOCODE-TASKID": "task-789"},
sessionID: utils.PtrTo("task-789"),
},
{
name: "kilo_without_task_id",
client: ClientKilo,
client: aibridge.ClientKilo,
},
// Coder Agents.
{
name: "coder_agents_with_chat_id",
client: ClientCoderAgents,
client: aibridge.ClientCoderAgents,
headers: map[string]string{"X-Coder-Chat-Id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890"},
sessionID: utils.PtrTo("a1b2c3d4-e5f6-7890-abcd-ef1234567890"),
},
{
name: "coder_agents_without_chat_id",
client: ClientCoderAgents,
client: aibridge.ClientCoderAgents,
},
// Roo.
{
name: "roo_returns_empty",
client: ClientRoo,
client: aibridge.ClientRoo,
},
// Cursor.
{
name: "cursor_returns_empty",
client: ClientCursor,
client: aibridge.ClientCursor,
},
// Other cases.
{
name: "empty session ID value",
client: ClientKilo,
client: aibridge.ClientKilo,
headers: map[string]string{"X-KILOCODE-TASKID": " "},
sessionID: nil,
},
@@ -205,14 +206,14 @@ func TestGuessSessionID(t *testing.T) {
t.Parallel()
body := tc.body
req, err := http.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(body))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://localhost", strings.NewReader(body))
require.NoError(t, err)
for key, value := range tc.headers {
req.Header.Set(key, value)
}
got := guessSessionID(tc.client, req)
got := aibridge.GuessSessionID(tc.client, req)
require.Equal(t, tc.sessionID, got)
// Verify the body was restored and can be read again.
@@ -226,16 +227,16 @@ func TestGuessSessionID(t *testing.T) {
func TestUnreadableBody(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodPost, "http://localhost", &errReader{})
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://localhost", &errReader{})
require.NoError(t, err)
got := guessSessionID(ClientClaudeCode, req)
got := aibridge.GuessSessionID(aibridge.ClientClaudeCode, req)
require.Nil(t, got)
}
// errReader is an io.Reader that always returns an error.
type errReader struct{}
func (e *errReader) Read([]byte) (int, error) {
func (*errReader) Read([]byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}
+8
View File
@@ -18,11 +18,15 @@ func TestConcurrentGroup(t *testing.T) {
t.Parallel()
t.Run("no goroutines", func(t *testing.T) {
t.Parallel()
cg := utils.NewConcurrentGroup()
require.NoError(t, cg.Wait())
})
t.Run("multiple goroutines, all ok", func(t *testing.T) {
t.Parallel()
cg := utils.NewConcurrentGroup()
cg.Go(func() error {
return nil
@@ -34,6 +38,8 @@ func TestConcurrentGroup(t *testing.T) {
})
t.Run("multiple goroutines, one err", func(t *testing.T) {
t.Parallel()
cg := utils.NewConcurrentGroup()
oops := xerrors.New("oops")
cg.Go(func() error {
@@ -46,6 +52,8 @@ func TestConcurrentGroup(t *testing.T) {
})
t.Run("multiple goroutines, multiple errs", func(t *testing.T) {
t.Parallel()
cg := utils.NewConcurrentGroup()
oops := xerrors.New("oops")
eek := xerrors.New("eek")