update aibridge fies with lint changes
This commit is contained in:
+1
-1
@@ -62,5 +62,5 @@ func NewMetrics(reg prometheus.Registerer) *metrics.Metrics {
|
||||
}
|
||||
|
||||
func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) Recorder {
|
||||
return recorder.NewRecorder(logger, tracer, clientFn)
|
||||
return recorder.NewWrappedRecorder(logger, tracer, clientFn)
|
||||
}
|
||||
|
||||
+3
-3
@@ -180,8 +180,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
|
||||
|
||||
// We execute this before CreateInterceptor since the interceptors
|
||||
// read the request body and don't reset them.
|
||||
client := guessClient(r)
|
||||
sessionID := guessSessionID(client, r)
|
||||
client := GuessClient(r)
|
||||
sessionID := GuessSessionID(client, r)
|
||||
|
||||
interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer)
|
||||
if err != nil {
|
||||
@@ -276,7 +276,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
|
||||
log.Debug(ctx, "interception ended")
|
||||
}
|
||||
|
||||
asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()})
|
||||
_ = asyncRecorder.RecordInterceptionEnded(ctx, &recorder.InterceptionRecordEnded{ID: interceptor.ID().String()})
|
||||
|
||||
// Ensure all recording have completed before completing request.
|
||||
asyncRecorder.Wait()
|
||||
|
||||
+38
-57
@@ -1,4 +1,4 @@
|
||||
package aibridge
|
||||
package aibridge_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -7,16 +7,22 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/provider"
|
||||
)
|
||||
|
||||
func TestValidateProvider_Names(t *testing.T) {
|
||||
var bridgeTestTracer = otel.Tracer("bridge_test")
|
||||
|
||||
func TestValidateProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providers []provider.Provider
|
||||
@@ -25,94 +31,69 @@ func TestValidateProvider_Names(t *testing.T) {
|
||||
{
|
||||
name: "all_supported_providers",
|
||||
providers: []provider.Provider{
|
||||
NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
|
||||
NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
|
||||
aibridge.NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
|
||||
aibridge.NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default_names_and_base_urls",
|
||||
providers: []provider.Provider{
|
||||
NewOpenAIProvider(config.OpenAI{}),
|
||||
NewAnthropicProvider(config.Anthropic{}, nil),
|
||||
NewCopilotProvider(config.Copilot{}),
|
||||
aibridge.NewOpenAIProvider(config.OpenAI{}),
|
||||
aibridge.NewAnthropicProvider(config.Anthropic{}, nil),
|
||||
aibridge.NewCopilotProvider(config.Copilot{}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple_copilot_instances",
|
||||
providers: []provider.Provider{
|
||||
NewCopilotProvider(config.Copilot{}),
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "name_with_slashes",
|
||||
providers: []provider.Provider{
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
},
|
||||
expectErr: "invalid provider name",
|
||||
},
|
||||
{
|
||||
name: "name_with_spaces",
|
||||
providers: []provider.Provider{
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
},
|
||||
expectErr: "invalid provider name",
|
||||
},
|
||||
{
|
||||
name: "name_with_uppercase",
|
||||
providers: []provider.Provider{
|
||||
NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
},
|
||||
expectErr: "invalid provider name",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := validateProviders(tc.providers)
|
||||
if tc.expectErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProvider_DuplicateNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providers []provider.Provider
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
name: "unique_names",
|
||||
providers: []provider.Provider{
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "duplicate_base_url_different_names",
|
||||
providers: []provider.Provider{
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "duplicate_name",
|
||||
providers: []provider.Provider{
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
|
||||
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
|
||||
},
|
||||
expectErr: "duplicate provider name",
|
||||
},
|
||||
@@ -122,7 +103,7 @@ func TestValidateProvider_DuplicateNames(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := validateProviders(tc.providers)
|
||||
_, err := aibridge.NewRequestBridge(t.Context(), tc.providers, nil, nil, logger, nil, bridgeTestTracer)
|
||||
if tc.expectErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectErr)
|
||||
@@ -148,7 +129,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
|
||||
name: "openAI_no_base_path",
|
||||
requestPath: "/openai/v1/conversations",
|
||||
provider: func(baseURL string) provider.Provider {
|
||||
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
|
||||
return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
|
||||
},
|
||||
expectPath: "/conversations",
|
||||
},
|
||||
@@ -157,7 +138,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
|
||||
baseURLPath: "/v1",
|
||||
requestPath: "/openai/v1/conversations",
|
||||
provider: func(baseURL string) provider.Provider {
|
||||
return NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
|
||||
return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
|
||||
},
|
||||
expectPath: "/v1/conversations",
|
||||
},
|
||||
@@ -165,7 +146,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
|
||||
name: "anthropic_no_base_path",
|
||||
requestPath: "/anthropic/v1/models",
|
||||
provider: func(baseURL string) provider.Provider {
|
||||
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
|
||||
return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
|
||||
},
|
||||
expectPath: "/v1/models",
|
||||
},
|
||||
@@ -174,7 +155,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
|
||||
baseURLPath: "/v1",
|
||||
requestPath: "/anthropic/v1/models",
|
||||
provider: func(baseURL string) provider.Provider {
|
||||
return NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
|
||||
return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
|
||||
},
|
||||
expectPath: "/v1/v1/models",
|
||||
},
|
||||
@@ -182,7 +163,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
|
||||
name: "copilot_no_base_path",
|
||||
requestPath: "/copilot/models",
|
||||
provider: func(baseURL string) provider.Provider {
|
||||
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
|
||||
return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL})
|
||||
},
|
||||
expectPath: "/models",
|
||||
},
|
||||
@@ -191,7 +172,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
|
||||
baseURLPath: "/v1",
|
||||
requestPath: "/copilot/models",
|
||||
provider: func(baseURL string) provider.Provider {
|
||||
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
|
||||
return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL})
|
||||
},
|
||||
expectPath: "/v1/models",
|
||||
},
|
||||
@@ -210,14 +191,14 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
recorder := testutil.MockRecorder{}
|
||||
rec := testutil.MockRecorder{}
|
||||
prov := tc.provider(upstream.URL + tc.baseURLPath)
|
||||
bridge, err := NewRequestBridge(t.Context(), []provider.Provider{prov}, &recorder, nil, logger, nil, testTracer)
|
||||
bridge, err := aibridge.NewRequestBridge(t.Context(), []provider.Provider{prov}, &rec, nil, logger, nil, bridgeTestTracer)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("", tc.requestPath, nil)
|
||||
resp := httptest.NewRecorder()
|
||||
bridge.mux.ServeHTTP(resp, req)
|
||||
bridge.ServeHTTP(resp, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Contains(t, resp.Body.String(), upstreamRespBody)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -65,8 +67,8 @@ func (p *ProviderCircuitBreakers) isFailure(statusCode int) bool {
|
||||
return DefaultIsFailure(statusCode)
|
||||
}
|
||||
|
||||
// openErrorResponse returns the error response body when the circuit is open.
|
||||
func (p *ProviderCircuitBreakers) openErrorResponse() []byte {
|
||||
// openErrBody returns the error response body when the circuit is open.
|
||||
func (p *ProviderCircuitBreakers) openErrBody() []byte {
|
||||
if p.config.OpenErrorResponse != nil {
|
||||
return p.config.OpenErrorResponse()
|
||||
}
|
||||
@@ -77,7 +79,7 @@ func (p *ProviderCircuitBreakers) openErrorResponse() []byte {
|
||||
func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.CircuitBreaker[struct{}] {
|
||||
key := endpoint + ":" + model
|
||||
if v, ok := p.breakers.Load(key); ok {
|
||||
return v.(*gobreaker.CircuitBreaker[struct{}])
|
||||
return v.(*gobreaker.CircuitBreaker[struct{}]) //nolint:forcetypeassert // sync.Map always stores this type
|
||||
}
|
||||
|
||||
settings := gobreaker.Settings{
|
||||
@@ -97,11 +99,12 @@ func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.Circuit
|
||||
|
||||
cb := gobreaker.NewCircuitBreaker[struct{}](settings)
|
||||
actual, _ := p.breakers.LoadOrStore(key, cb)
|
||||
return actual.(*gobreaker.CircuitBreaker[struct{}])
|
||||
return actual.(*gobreaker.CircuitBreaker[struct{}]) //nolint:forcetypeassert // sync.Map always stores this type
|
||||
}
|
||||
|
||||
// statusCapturingWriter wraps http.ResponseWriter to capture the status code.
|
||||
// It also implements http.Flusher to support streaming responses.
|
||||
// It implements http.Flusher to support streaming and http.Hijacker to
|
||||
// satisfy the FullResponseWriter lint rule.
|
||||
type statusCapturingWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
@@ -130,6 +133,14 @@ func (w *statusCapturingWriter) Flush() {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *statusCapturingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
h, ok := w.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, xerrors.New("upstream ResponseWriter does not support hijacking")
|
||||
}
|
||||
return h.Hijack()
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter for interface checks.
|
||||
func (w *statusCapturingWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
@@ -167,7 +178,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds())))
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = w.Write(p.openErrorResponse())
|
||||
_, _ = w.Write(p.openErrBody())
|
||||
return ErrCircuitOpen
|
||||
}
|
||||
|
||||
@@ -187,7 +198,7 @@ func (p *ProviderCircuitBreakers) Provider() string {
|
||||
// OpenErrorResponse returns the error response body when the circuit is open.
|
||||
// This is exposed for handlers to use when responding to rejected requests.
|
||||
func (p *ProviderCircuitBreakers) OpenErrorResponse() []byte {
|
||||
return p.openErrorResponse()
|
||||
return p.openErrBody()
|
||||
}
|
||||
|
||||
// StateToGaugeValue converts gobreaker.State to a gauge value.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package circuitbreaker
|
||||
package circuitbreaker_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/sony/gobreaker/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/circuitbreaker"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
)
|
||||
|
||||
@@ -20,7 +21,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {
|
||||
sonnetCalls := atomic.Int32{}
|
||||
haikuCalls := atomic.Int32{}
|
||||
|
||||
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
||||
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
||||
FailureThreshold: 1,
|
||||
Interval: time.Minute,
|
||||
Timeout: time.Minute,
|
||||
@@ -48,7 +49,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return nil
|
||||
})
|
||||
assert.True(t, errors.Is(err, ErrCircuitOpen))
|
||||
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
|
||||
assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call
|
||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
|
||||
@@ -69,7 +70,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
|
||||
messagesCalls := atomic.Int32{}
|
||||
completionsCalls := atomic.Int32{}
|
||||
|
||||
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
||||
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
||||
FailureThreshold: 1,
|
||||
Interval: time.Minute,
|
||||
Timeout: time.Minute,
|
||||
@@ -95,7 +96,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return nil
|
||||
})
|
||||
assert.True(t, errors.Is(err, ErrCircuitOpen))
|
||||
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
|
||||
assert.Equal(t, int32(1), messagesCalls.Load()) // No new call
|
||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
|
||||
@@ -116,7 +117,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
|
||||
// Custom IsFailure that treats 502 as failure
|
||||
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
||||
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
||||
FailureThreshold: 1,
|
||||
Interval: time.Minute,
|
||||
Timeout: time.Minute,
|
||||
@@ -143,7 +144,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return nil
|
||||
})
|
||||
assert.True(t, errors.Is(err, ErrCircuitOpen))
|
||||
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
|
||||
assert.Equal(t, int32(1), calls.Load()) // No new call
|
||||
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
}
|
||||
@@ -158,7 +159,7 @@ func TestExecute_OnStateChange(t *testing.T) {
|
||||
to gobreaker.State
|
||||
}
|
||||
|
||||
cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
||||
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
||||
FailureThreshold: 1,
|
||||
Interval: time.Minute,
|
||||
Timeout: time.Minute,
|
||||
@@ -177,10 +178,11 @@ func TestExecute_OnStateChange(t *testing.T) {
|
||||
|
||||
// Trip circuit
|
||||
w := httptest.NewRecorder()
|
||||
cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error {
|
||||
err := cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error {
|
||||
rw.WriteHeader(http.StatusTooManyRequests)
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify state change callback was called with correct parameters
|
||||
assert.Len(t, stateChanges, 1)
|
||||
@@ -208,14 +210,14 @@ func TestDefaultIsFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
|
||||
assert.Equal(t, tt.isFailure, circuitbreaker.DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateToGaugeValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, float64(0), StateToGaugeValue(gobreaker.StateClosed))
|
||||
assert.Equal(t, float64(0.5), StateToGaugeValue(gobreaker.StateHalfOpen))
|
||||
assert.Equal(t, float64(1), StateToGaugeValue(gobreaker.StateOpen))
|
||||
assert.Equal(t, float64(0), circuitbreaker.StateToGaugeValue(gobreaker.StateClosed))
|
||||
assert.Equal(t, float64(0.5), circuitbreaker.StateToGaugeValue(gobreaker.StateHalfOpen))
|
||||
assert.Equal(t, float64(1), circuitbreaker.StateToGaugeValue(gobreaker.StateOpen))
|
||||
}
|
||||
|
||||
+2
-2
@@ -24,10 +24,10 @@ const (
|
||||
ClientUnknown Client = "Unknown"
|
||||
)
|
||||
|
||||
// guessClient attempts to guess the client application from the request headers.
|
||||
// GuessClient attempts to guess the client application from the request headers.
|
||||
// Not all clients set proper user agent headers, so this is a best-effort approach.
|
||||
// Based on https://github.com/coder/aibridge/issues/20#issuecomment-3769444101.
|
||||
func guessClient(r *http.Request) Client {
|
||||
func GuessClient(r *http.Request) Client {
|
||||
userAgent := strings.ToLower(r.UserAgent())
|
||||
originator := r.Header.Get("originator")
|
||||
|
||||
|
||||
+23
-21
@@ -1,10 +1,12 @@
|
||||
package aibridge
|
||||
package aibridge_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
)
|
||||
|
||||
func TestGuessClient(t *testing.T) {
|
||||
@@ -14,93 +16,93 @@ func TestGuessClient(t *testing.T) {
|
||||
name string
|
||||
userAgent string
|
||||
headers map[string]string
|
||||
wantClient Client
|
||||
wantClient aibridge.Client
|
||||
}{
|
||||
{
|
||||
name: "mux",
|
||||
userAgent: "mux/0.19.0-next.2.gcceff159 ai-sdk/openai/3.0.36 ai-sdk/provider-utils/4.0.15 runtime/node.js/22",
|
||||
wantClient: ClientMux,
|
||||
wantClient: aibridge.ClientMux,
|
||||
},
|
||||
{
|
||||
name: "claude_code",
|
||||
userAgent: "claude-cli/2.0.67 (external, cli)",
|
||||
wantClient: ClientClaudeCode,
|
||||
wantClient: aibridge.ClientClaudeCode,
|
||||
},
|
||||
{
|
||||
name: "codex_cli",
|
||||
userAgent: "codex_cli_rs/0.87.0 (Mac OS 26.2.0; arm64) ghostty/1.3.0-main_250877ef",
|
||||
wantClient: ClientCodex,
|
||||
wantClient: aibridge.ClientCodex,
|
||||
},
|
||||
{
|
||||
name: "zed",
|
||||
userAgent: "Zed/0.219.4+stable.119.abc123 (macos; aarch64)",
|
||||
wantClient: ClientZed,
|
||||
wantClient: aibridge.ClientZed,
|
||||
},
|
||||
{
|
||||
name: "github_copilot_vsc",
|
||||
userAgent: "GitHubCopilotChat/0.37.2026011603",
|
||||
wantClient: ClientCopilotVSC,
|
||||
wantClient: aibridge.ClientCopilotVSC,
|
||||
},
|
||||
{
|
||||
name: "github_copilot_cli",
|
||||
userAgent: "copilot/0.0.403 (client/cli linux v24.11.1)",
|
||||
wantClient: ClientCopilotCLI,
|
||||
wantClient: aibridge.ClientCopilotCLI,
|
||||
},
|
||||
{
|
||||
name: "kilo_code_user_agent",
|
||||
userAgent: "kilo-code/5.1.0 (darwin 25.2.0; arm64) node/22.21.1",
|
||||
wantClient: ClientKilo,
|
||||
wantClient: aibridge.ClientKilo,
|
||||
},
|
||||
{
|
||||
name: "kilo_code_originator",
|
||||
headers: map[string]string{"Originator": "kilo-code"},
|
||||
wantClient: ClientKilo,
|
||||
wantClient: aibridge.ClientKilo,
|
||||
},
|
||||
{
|
||||
name: "roo_code_user_agent",
|
||||
userAgent: "roo-code/3.45.0 (darwin 25.2.0; arm64) node/22.21.1",
|
||||
wantClient: ClientRoo,
|
||||
wantClient: aibridge.ClientRoo,
|
||||
},
|
||||
{
|
||||
name: "roo_code_originator",
|
||||
headers: map[string]string{"Originator": "roo-code"},
|
||||
wantClient: ClientRoo,
|
||||
wantClient: aibridge.ClientRoo,
|
||||
},
|
||||
{
|
||||
name: "coder_agents",
|
||||
userAgent: "coder-agents/v2.24.0 (linux/amd64)",
|
||||
wantClient: ClientCoderAgents,
|
||||
wantClient: aibridge.ClientCoderAgents,
|
||||
},
|
||||
{
|
||||
name: "coder_agents_dev",
|
||||
userAgent: "coder-agents/v0.0.0-devel (darwin/arm64)",
|
||||
wantClient: ClientCoderAgents,
|
||||
wantClient: aibridge.ClientCoderAgents,
|
||||
},
|
||||
{
|
||||
name: "charm_crush",
|
||||
userAgent: "Charm Crush/0.1.11",
|
||||
wantClient: ClientCrush,
|
||||
wantClient: aibridge.ClientCrush,
|
||||
},
|
||||
{
|
||||
name: "cursor_x_cursor_client_version",
|
||||
userAgent: "connect-es/1.6.1",
|
||||
headers: map[string]string{"X-Cursor-client-version": "0.50.0"},
|
||||
wantClient: ClientCursor,
|
||||
wantClient: aibridge.ClientCursor,
|
||||
},
|
||||
{
|
||||
name: "cursor_x_cursor_some_other_header",
|
||||
headers: map[string]string{"x-cursor-client-version": "abc123"},
|
||||
wantClient: ClientCursor,
|
||||
wantClient: aibridge.ClientCursor,
|
||||
},
|
||||
{
|
||||
name: "unknown_client",
|
||||
userAgent: "ccclaude-cli/calude-with-wrong-prefix",
|
||||
wantClient: ClientUnknown,
|
||||
wantClient: aibridge.ClientUnknown,
|
||||
},
|
||||
{
|
||||
name: "empty_user_agent",
|
||||
userAgent: "",
|
||||
wantClient: ClientUnknown,
|
||||
wantClient: aibridge.ClientUnknown,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -108,7 +110,7 @@ func TestGuessClient(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "", nil)
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set("User-Agent", tt.userAgent)
|
||||
@@ -116,7 +118,7 @@ func TestGuessClient(t *testing.T) {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
got := guessClient(req)
|
||||
got := aibridge.GuessClient(req)
|
||||
require.Equal(t, tt.wantClient, got)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package context
|
||||
package context_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
aibcontext "github.com/coder/coder/v2/aibridge/context"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
)
|
||||
|
||||
@@ -17,10 +18,10 @@ func TestAsActor(t *testing.T) {
|
||||
metadata := recorder.Metadata{"key": "value"}
|
||||
|
||||
// When: storing an actor in the context
|
||||
ctx := AsActor(context.Background(), "actor-123", metadata)
|
||||
ctx := aibcontext.AsActor(context.Background(), "actor-123", metadata)
|
||||
|
||||
// Then: the actor should be retrievable with correct ID and metadata
|
||||
actor := ActorFromContext(ctx)
|
||||
actor := aibcontext.ActorFromContext(ctx)
|
||||
require.NotNil(t, actor)
|
||||
assert.Equal(t, "actor-123", actor.ID)
|
||||
assert.Equal(t, "value", actor.Metadata["key"])
|
||||
@@ -33,10 +34,10 @@ func TestActorFromContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: a context with an actor
|
||||
ctx := AsActor(context.Background(), "test-id", recorder.Metadata{})
|
||||
ctx := aibcontext.AsActor(context.Background(), "test-id", recorder.Metadata{})
|
||||
|
||||
// When: extracting the actor from context
|
||||
actor := ActorFromContext(ctx)
|
||||
actor := aibcontext.ActorFromContext(ctx)
|
||||
|
||||
// Then: the actor should be returned with correct ID
|
||||
require.NotNil(t, actor)
|
||||
@@ -50,7 +51,7 @@ func TestActorFromContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// When: extracting the actor from context
|
||||
actor := ActorFromContext(ctx)
|
||||
actor := aibcontext.ActorFromContext(ctx)
|
||||
|
||||
// Then: nil should be returned
|
||||
assert.Nil(t, actor)
|
||||
@@ -64,10 +65,10 @@ func TestActorIDFromContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: a context with an actor
|
||||
ctx := AsActor(context.Background(), "test-actor-id", recorder.Metadata{})
|
||||
ctx := aibcontext.AsActor(context.Background(), "test-actor-id", recorder.Metadata{})
|
||||
|
||||
// When: extracting the actor ID from context
|
||||
got := ActorIDFromContext(ctx)
|
||||
got := aibcontext.ActorIDFromContext(ctx)
|
||||
|
||||
// Then: the actor ID should be returned
|
||||
assert.Equal(t, "test-actor-id", got)
|
||||
@@ -80,7 +81,7 @@ func TestActorIDFromContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// When: extracting the actor ID from context
|
||||
got := ActorIDFromContext(ctx)
|
||||
got := aibcontext.ActorIDFromContext(ctx)
|
||||
|
||||
// Then: an empty string should be returned
|
||||
assert.Empty(t, got)
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
These fixtures were created by adding logging middleware to API calls to view the raw requests/responses.
|
||||
|
||||
```go
|
||||
...
|
||||
opts = append(opts, option.WithMiddleware(LoggingMiddleware))
|
||||
...
|
||||
|
||||
func LoggingMiddleware(req *http.Request, next option.MiddlewareNext) (res *http.Response, err error) {
|
||||
reqOut, _ := httputil.DumpRequest(req, true)
|
||||
|
||||
// Forward the request to the next handler
|
||||
res, err = next(req)
|
||||
fmt.Printf("[req] %s\n", reqOut)
|
||||
|
||||
// Handle stuff after the request
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
respOut, _ := httputil.DumpResponse(res, true)
|
||||
fmt.Printf("[resp] %s\n", respOut)
|
||||
|
||||
return res, err
|
||||
}
|
||||
```
|
||||
@@ -92,7 +92,7 @@ var (
|
||||
OaiResponsesBlockingConversation []byte
|
||||
|
||||
//go:embed openai/responses/blocking/http_error.txtar
|
||||
OaiResponsesBlockingHttpErr []byte
|
||||
OaiResponsesBlockingHTTPErr []byte
|
||||
|
||||
//go:embed openai/responses/blocking/prev_response_id.txtar
|
||||
OaiResponsesBlockingPrevResponseID []byte
|
||||
@@ -139,7 +139,7 @@ var (
|
||||
OaiResponsesStreamingConversation []byte
|
||||
|
||||
//go:embed openai/responses/streaming/http_error.txtar
|
||||
OaiResponsesStreamingHttpErr []byte
|
||||
OaiResponsesStreamingHTTPErr []byte
|
||||
|
||||
//go:embed openai/responses/streaming/prev_response_id.txtar
|
||||
OaiResponsesStreamingPrevResponseID []byte
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package intercept
|
||||
package intercept_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -7,14 +7,15 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/context"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
)
|
||||
|
||||
func TestNilActor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Nil(t, ActorHeadersAsOpenAIOpts(nil))
|
||||
require.Nil(t, ActorHeadersAsAnthropicOpts(nil))
|
||||
require.Nil(t, intercept.ActorHeadersAsOpenAIOpts(nil))
|
||||
require.Nil(t, intercept.ActorHeadersAsAnthropicOpts(nil))
|
||||
}
|
||||
|
||||
func TestBasic(t *testing.T) {
|
||||
@@ -28,9 +29,9 @@ func TestBasic(t *testing.T) {
|
||||
// We can't peek inside since these opts require an internal type to apply onto.
|
||||
// All we can do is check the length.
|
||||
// See TestActorHeaders for an integration test.
|
||||
oaiOpts := ActorHeadersAsOpenAIOpts(actor)
|
||||
oaiOpts := intercept.ActorHeadersAsOpenAIOpts(actor)
|
||||
require.Len(t, oaiOpts, 1)
|
||||
antOpts := ActorHeadersAsAnthropicOpts(actor)
|
||||
antOpts := intercept.ActorHeadersAsAnthropicOpts(actor)
|
||||
require.Len(t, antOpts, 1)
|
||||
}
|
||||
|
||||
@@ -49,8 +50,8 @@ func TestBasicAndMetadata(t *testing.T) {
|
||||
// We can't peek inside since these opts require an internal type to apply onto.
|
||||
// All we can do is check the length.
|
||||
// See TestActorHeaders for an integration test.
|
||||
oaiOpts := ActorHeadersAsOpenAIOpts(actor)
|
||||
oaiOpts := intercept.ActorHeadersAsOpenAIOpts(actor)
|
||||
require.Len(t, oaiOpts, 1+len(actor.Metadata))
|
||||
antOpts := ActorHeadersAsAnthropicOpts(actor)
|
||||
antOpts := intercept.ActorHeadersAsAnthropicOpts(actor)
|
||||
require.Len(t, antOpts, 1+len(actor.Metadata))
|
||||
}
|
||||
|
||||
@@ -105,10 +105,11 @@ func (d *dumper) dumpRequest(req *http.Request) error {
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write request header terminator: %w", err)
|
||||
}
|
||||
buf.Write(prettyBody)
|
||||
buf.WriteByte('\n')
|
||||
// bytes.Buffer writes to in-memory storage and never return errors.
|
||||
_, _ = buf.Write(prettyBody)
|
||||
_ = buf.WriteByte('\n')
|
||||
|
||||
return os.WriteFile(dumpPath, buf.Bytes(), 0o644)
|
||||
return os.WriteFile(dumpPath, buf.Bytes(), 0o600)
|
||||
}
|
||||
|
||||
func (d *dumper) dumpResponse(resp *http.Response) error {
|
||||
@@ -129,19 +130,19 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
|
||||
return xerrors.Errorf("write response header terminator: %w", err)
|
||||
}
|
||||
|
||||
// Wrap the response body to capture it as it streams
|
||||
if resp.Body != nil {
|
||||
resp.Body = &streamingBodyDumper{
|
||||
body: resp.Body,
|
||||
dumpPath: dumpPath,
|
||||
headerData: headerBuf.Bytes(),
|
||||
logger: func(err error) {
|
||||
d.logger.Named("apidump").Warn(context.Background(), "failed to initialize response dump", slog.Error(err))
|
||||
},
|
||||
}
|
||||
} else {
|
||||
if resp.Body == nil {
|
||||
// No body, just write headers
|
||||
return os.WriteFile(dumpPath, headerBuf.Bytes(), 0o644)
|
||||
return os.WriteFile(dumpPath, headerBuf.Bytes(), 0o600)
|
||||
}
|
||||
|
||||
// Wrap the response body to capture it as it streams
|
||||
resp.Body = &streamingBodyDumper{
|
||||
body: resp.Body,
|
||||
dumpPath: dumpPath,
|
||||
headerData: headerBuf.Bytes(),
|
||||
logger: func(err error) {
|
||||
d.logger.Named("apidump").Warn(context.Background(), "failed to initialize response dump", slog.Error(err))
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -152,7 +153,7 @@ func (d *dumper) dumpResponse(resp *http.Response) error {
|
||||
// for deterministic output.
|
||||
// `sensitive` and `overrides` must both supply keys in canoncialized form.
|
||||
// See [textproto.MIMEHeader].
|
||||
func (d *dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error {
|
||||
func (*dumper) writeRedactedHeaders(w io.Writer, headers http.Header, sensitive map[string]struct{}, overrides map[string]string) error {
|
||||
// Collect all header keys including overrides.
|
||||
headerKeys := make([]string, 0, len(headers)+len(overrides))
|
||||
seen := make(map[string]struct{}, len(headers)+len(overrides))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package apidump
|
||||
package apidump //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -39,7 +39,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) {
|
||||
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
|
||||
require.NotNil(t, middleware)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`)))
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`)))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add sensitive headers that should be redacted
|
||||
@@ -52,7 +52,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) {
|
||||
req.Header.Set("User-Agent", "test-client")
|
||||
|
||||
// Call middleware with a mock next function
|
||||
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
|
||||
resp, err := middleware(req, func(r *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Status: "200 OK",
|
||||
@@ -62,6 +62,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) {
|
||||
}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read the request dump file
|
||||
modelDir := filepath.Join(tmpDir, "openai", "gpt-4")
|
||||
@@ -96,7 +97,7 @@ func TestBridgedMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) {
|
||||
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
|
||||
require.NotNil(t, middleware)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Call middleware with a response containing sensitive headers
|
||||
@@ -166,11 +167,11 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) {
|
||||
require.NotNil(t, middleware)
|
||||
|
||||
originalBody := `{"messages": [{"role": "user", "content": "hello"}]}`
|
||||
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody)))
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody)))
|
||||
require.NoError(t, err)
|
||||
|
||||
var capturedBody []byte
|
||||
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
|
||||
resp2, err := middleware(req, func(r *http.Request) (*http.Response, error) {
|
||||
// Read the body in the next handler to verify it's still available
|
||||
capturedBody, _ = io.ReadAll(r.Body)
|
||||
return &http.Response{
|
||||
@@ -182,6 +183,7 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) {
|
||||
}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer resp2.Body.Close()
|
||||
|
||||
// Verify the body was preserved for the next handler
|
||||
require.Equal(t, originalBody, string(capturedBody))
|
||||
@@ -199,10 +201,10 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) {
|
||||
middleware := NewBridgeMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk)
|
||||
require.NotNil(t, middleware)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`)))
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`)))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
|
||||
resp3, err := middleware(req, func(r *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Status: "200 OK",
|
||||
@@ -212,6 +214,7 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) {
|
||||
}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer resp3.Body.Close()
|
||||
|
||||
// Verify files are created with sanitized model name
|
||||
modelDir := filepath.Join(tmpDir, "google", "gemini-1.5-pro")
|
||||
@@ -278,7 +281,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) {
|
||||
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
|
||||
require.NotNil(t, middleware)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set all sensitive headers
|
||||
@@ -290,7 +293,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) {
|
||||
req.Header.Set("Proxy-Authorization", "Basic proxy-creds")
|
||||
req.Header.Set("X-Amz-Security-Token", "aws-security-token")
|
||||
|
||||
_, err = middleware(req, func(r *http.Request) (*http.Response, error) {
|
||||
resp4, err := middleware(req, func(r *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Status: "200 OK",
|
||||
@@ -300,6 +303,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) {
|
||||
}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer resp4.Body.Close()
|
||||
|
||||
modelDir := filepath.Join(tmpDir, "openai", "gpt-4")
|
||||
reqDumpPath := findDumpFile(t, modelDir, SuffixRequest)
|
||||
@@ -355,10 +359,10 @@ func TestPassthroughMiddleware(t *testing.T) {
|
||||
|
||||
rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil)
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "https://api.openai.com/v1/models", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
resp, err := rt.RoundTrip(req) //nolint:bodyclose // resp is nil on error
|
||||
require.ErrorIs(t, err, innerErr)
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
@@ -399,7 +403,7 @@ func TestPassthroughMiddleware(t *testing.T) {
|
||||
|
||||
rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body)))
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body)))
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer sk-secret-key-12345")
|
||||
resp, err := rt.RoundTrip(req)
|
||||
@@ -409,7 +413,7 @@ func TestPassthroughMiddleware(t *testing.T) {
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
// Second request should create new req/resp files
|
||||
req2, err := http.NewRequest(http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body)))
|
||||
req2, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body)))
|
||||
require.NoError(t, err)
|
||||
resp2, err := rt.RoundTrip(req2)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package apidump
|
||||
package apidump //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
@@ -37,7 +37,7 @@ func (s *streamingBodyDumper) init() {
|
||||
// Write headers first.
|
||||
if _, err := s.file.Write(s.headerData); err != nil {
|
||||
s.initErr = xerrors.Errorf("write headers: %w", err)
|
||||
s.file.Close()
|
||||
_ = s.file.Close() // best-effort cleanup on header write failure
|
||||
s.file = nil
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package apidump
|
||||
package apidump //nolint:testpackage // shares test helpers with apidump_test.go
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -28,7 +28,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) {
|
||||
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
|
||||
require.NotNil(t, middleware)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate a streaming response with multiple chunks
|
||||
@@ -42,10 +42,12 @@ func TestMiddleware_StreamingResponse(t *testing.T) {
|
||||
// Create a pipe to simulate streaming
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close() //nolint:revive // error handled via pipe read side
|
||||
for _, chunk := range chunks {
|
||||
pw.Write([]byte(chunk))
|
||||
if _, err := pw.Write([]byte(chunk)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
pw.Close()
|
||||
}()
|
||||
|
||||
resp, err := middleware(req, func(r *http.Request) (*http.Response, error) {
|
||||
@@ -65,7 +67,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) {
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
receivedData.Write(buf[:n])
|
||||
_, _ = receivedData.Write(buf[:n]) // bytes.Buffer.Write never fails
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
@@ -104,7 +106,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) {
|
||||
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
|
||||
require.NotNil(t, middleware)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
|
||||
require.NoError(t, err)
|
||||
|
||||
originalRespBody := `{"choices": [{"message": {"content": "hi"}}]}`
|
||||
@@ -118,6 +120,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) {
|
||||
}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Verify the response body is still readable after middleware
|
||||
capturedBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -79,9 +79,9 @@ func (i *interceptionBase) Credential() intercept.CredentialInfo {
|
||||
return i.credential
|
||||
}
|
||||
|
||||
func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
func (i *interceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.logger = logger
|
||||
i.recorder = recorder
|
||||
i.recorder = rec
|
||||
i.mcpProxy = mcpProxy
|
||||
}
|
||||
|
||||
@@ -98,13 +98,13 @@ func (i *interceptionBase) CorrelatingToolCallID() *string {
|
||||
return &msg.OfTool.ToolCallID
|
||||
}
|
||||
|
||||
func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
|
||||
func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
|
||||
return []attribute.KeyValue{
|
||||
attribute.String(tracing.RequestPath, r.URL.Path),
|
||||
attribute.String(tracing.InterceptionID, s.id.String()),
|
||||
attribute.String(tracing.InterceptionID, i.id.String()),
|
||||
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
|
||||
attribute.String(tracing.Provider, s.providerName),
|
||||
attribute.String(tracing.Model, s.Model()),
|
||||
attribute.String(tracing.Provider, i.providerName),
|
||||
attribute.String(tracing.Model, i.Model()),
|
||||
attribute.Bool(tracing.Streaming, streaming),
|
||||
}
|
||||
}
|
||||
@@ -114,10 +114,10 @@ func (i *interceptionBase) Model() string {
|
||||
return "coder-aibridge-unknown"
|
||||
}
|
||||
|
||||
return string(i.req.Model)
|
||||
return i.req.Model
|
||||
}
|
||||
|
||||
func (i *interceptionBase) newErrorResponse(err error) map[string]any {
|
||||
func (*interceptionBase) newErrorResponse(err error) map[string]any {
|
||||
return map[string]any{
|
||||
"error": true,
|
||||
"message": err.Error(),
|
||||
@@ -172,7 +172,7 @@ func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) {
|
||||
}
|
||||
|
||||
// writeUpstreamError marshals and writes a given error.
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *errorResponse) {
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *chatCompletionResponseError) {
|
||||
if oaiErr == nil {
|
||||
return
|
||||
}
|
||||
@@ -182,7 +182,7 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *err
|
||||
|
||||
out, err := json.Marshal(oaiErr)
|
||||
if err != nil {
|
||||
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", oaiErr)))
|
||||
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", oaiErr))
|
||||
// Response has to match expected format.
|
||||
_, _ = w.Write([]byte(`{
|
||||
"error": {
|
||||
@@ -227,13 +227,13 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
|
||||
in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */
|
||||
}
|
||||
|
||||
func getErrorResponse(err error) *errorResponse {
|
||||
func getErrorResponse(err error) *chatCompletionResponseError {
|
||||
var apiErr *openai.Error
|
||||
if !errors.As(err, &apiErr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &errorResponse{
|
||||
return &chatCompletionResponseError{
|
||||
ErrorObject: &shared.ErrorObject{
|
||||
Code: apiErr.Code,
|
||||
Message: apiErr.Message,
|
||||
@@ -243,15 +243,15 @@ func getErrorResponse(err error) *errorResponse {
|
||||
}
|
||||
}
|
||||
|
||||
var _ error = &errorResponse{}
|
||||
var _ error = &chatCompletionResponseError{}
|
||||
|
||||
type errorResponse struct {
|
||||
type chatCompletionResponseError struct {
|
||||
ErrorObject *shared.ErrorObject `json:"error"`
|
||||
StatusCode int `json:"-"`
|
||||
}
|
||||
|
||||
func newErrorResponse(msg error) *errorResponse {
|
||||
return &errorResponse{
|
||||
func newErrorResponse(msg error) *chatCompletionResponseError {
|
||||
return &chatCompletionResponseError{
|
||||
ErrorObject: &shared.ErrorObject{
|
||||
Code: "error",
|
||||
Message: msg.Error(),
|
||||
@@ -260,7 +260,7 @@ func newErrorResponse(msg error) *errorResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *errorResponse) Error() string {
|
||||
func (a *chatCompletionResponseError) Error() string {
|
||||
if a.ErrorObject == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package chatcompletions
|
||||
package chatcompletions //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
@@ -50,16 +50,16 @@ func NewBlockingInterceptor(
|
||||
}}
|
||||
}
|
||||
|
||||
func (s *BlockingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
s.interceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
|
||||
func (i *BlockingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.interceptionBase.Setup(logger.Named("blocking"), rec, mcpProxy)
|
||||
}
|
||||
|
||||
func (s *BlockingInterception) Streaming() bool {
|
||||
func (*BlockingInterception) Streaming() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
|
||||
return s.interceptionBase.baseTraceAttributes(r, false)
|
||||
func (i *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
|
||||
return i.interceptionBase.baseTraceAttributes(r, false)
|
||||
}
|
||||
|
||||
func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) {
|
||||
|
||||
@@ -52,7 +52,7 @@ func (c *ChatCompletionNewParamsWrapper) lastUserPrompt() (*string, error) {
|
||||
// We only care if the last message was issued by a user.
|
||||
msg := c.Messages[len(c.Messages)-1]
|
||||
if msg.OfUser == nil {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // no user prompt found is not an error
|
||||
}
|
||||
|
||||
if msg.OfUser.Content.OfString.String() != "" {
|
||||
@@ -69,5 +69,5 @@ func (c *ChatCompletionNewParamsWrapper) lastUserPrompt() (*string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // no text content found is not an error
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package chatcompletions
|
||||
package chatcompletions //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -114,6 +114,8 @@ func TestOpenAILastUserPrompt(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := tt.wrapper.lastUserPrompt()
|
||||
|
||||
if tt.expectError {
|
||||
@@ -144,7 +146,7 @@ func generatePayload(messageCount int) []byte {
|
||||
}
|
||||
// Use realistic message content size
|
||||
content := fmt.Sprintf("This is message number %d with some realistic content that might appear in a conversation.", i+1)
|
||||
messages = append(messages, fmt.Sprintf(`{"role": "%s", "content": "%s"}`, role, content))
|
||||
messages = append(messages, fmt.Sprintf(`{"role": %q, "content": %q}`, role, content))
|
||||
}
|
||||
|
||||
return []byte(fmt.Sprintf(`{
|
||||
|
||||
@@ -54,16 +54,16 @@ func NewStreamingInterceptor(
|
||||
}}
|
||||
}
|
||||
|
||||
func (i *StreamingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.interceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
|
||||
func (i *StreamingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.interceptionBase.Setup(logger.Named("streaming"), rec, mcpProxy)
|
||||
}
|
||||
|
||||
func (i *StreamingInterception) Streaming() bool {
|
||||
func (*StreamingInterception) Streaming() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
|
||||
return s.interceptionBase.baseTraceAttributes(r, true)
|
||||
func (i *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
|
||||
return i.interceptionBase.baseTraceAttributes(r, true)
|
||||
}
|
||||
|
||||
// ProcessRequest handles a request to /v1/chat/completions.
|
||||
@@ -189,16 +189,14 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
|
||||
})
|
||||
|
||||
toolCall = nil
|
||||
} else {
|
||||
} else if stream.Err() == nil {
|
||||
// When the provider responds with only tool calls (no text content),
|
||||
// no chunks are relayed to the client, so the stream is not yet
|
||||
// initiated. Initiate it here so the SSE headers are sent and the
|
||||
// ping ticker is started, preventing client timeout during tool invocation.
|
||||
// Only initiate if no stream error, if there's an error, we'll return
|
||||
// an HTTP error response instead of starting an SSE stream.
|
||||
if stream.Err() == nil {
|
||||
events.InitiateStream(w)
|
||||
}
|
||||
events.InitiateStream(w)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,43 +229,43 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
|
||||
})
|
||||
}
|
||||
|
||||
if events.IsStreaming() {
|
||||
// Check if the stream encountered any errors.
|
||||
if streamErr := stream.Err(); streamErr != nil {
|
||||
if eventstream.IsUnrecoverableError(streamErr) {
|
||||
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
|
||||
// We can't reflect an error back if there's a connection error or the request context was canceled.
|
||||
} else if oaiErr := getErrorResponse(streamErr); oaiErr != nil {
|
||||
logger.Warn(ctx, "openai stream error", slog.Error(streamErr))
|
||||
interceptionErr = oaiErr
|
||||
} else {
|
||||
logger.Warn(ctx, "unknown error", slog.Error(streamErr))
|
||||
// Unfortunately, the OpenAI SDK does not support parsing errors received in the stream
|
||||
// into known types (i.e. [shared.OverloadedError]).
|
||||
// See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171
|
||||
// All it does is wrap the payload in an error - which is all we can return, currently.
|
||||
interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr))
|
||||
}
|
||||
} else if lastErr != nil {
|
||||
// Otherwise check if any logical errors occurred during processing.
|
||||
logger.Warn(ctx, "stream failed", slog.Error(lastErr))
|
||||
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
|
||||
}
|
||||
|
||||
if interceptionErr != nil {
|
||||
payload, err := i.marshalErr(interceptionErr)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
|
||||
} else if err := events.Send(streamCtx, payload); err != nil {
|
||||
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if !events.IsStreaming() {
|
||||
// response/downstream Stream has not started yet; write error response and exit.
|
||||
i.writeUpstreamError(w, getErrorResponse(stream.Err()))
|
||||
return stream.Err()
|
||||
}
|
||||
|
||||
// Check if the stream encountered any errors.
|
||||
if streamErr := stream.Err(); streamErr != nil {
|
||||
if eventstream.IsUnrecoverableError(streamErr) {
|
||||
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
|
||||
// We can't reflect an error back if there's a connection error or the request context was canceled.
|
||||
} else if oaiErr := getErrorResponse(streamErr); oaiErr != nil {
|
||||
logger.Warn(ctx, "openai stream error", slog.Error(streamErr))
|
||||
interceptionErr = oaiErr
|
||||
} else {
|
||||
logger.Warn(ctx, "unknown stream error encountered", slog.Error(streamErr))
|
||||
// Unfortunately, the OpenAI SDK does not support parsing errors received in the stream
|
||||
// into known types (i.e. [shared.OverloadedError]).
|
||||
// See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171
|
||||
// All it does is wrap the payload in an error - which is all we can return, currently.
|
||||
interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr))
|
||||
}
|
||||
} else if lastErr != nil {
|
||||
// Otherwise check if any logical errors occurred during processing.
|
||||
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
|
||||
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
|
||||
}
|
||||
|
||||
if interceptionErr != nil {
|
||||
payload, err := i.marshalErr(interceptionErr)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error()))
|
||||
} else if err := events.Send(streamCtx, payload); err != nil {
|
||||
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
|
||||
}
|
||||
}
|
||||
|
||||
// No tool call, nothing more to do.
|
||||
if toolCall == nil {
|
||||
break
|
||||
@@ -390,11 +388,12 @@ func (i *StreamingInterception) marshalErr(err error) ([]byte, error) {
|
||||
return i.encodeForStream(data), nil
|
||||
}
|
||||
|
||||
func (i *StreamingInterception) encodeForStream(payload []byte) []byte {
|
||||
func (*StreamingInterception) encodeForStream(payload []byte) []byte {
|
||||
// bytes.Buffer writes to in-memory storage and never return errors.
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("data: ")
|
||||
buf.Write(payload)
|
||||
buf.WriteString("\n\n")
|
||||
_, _ = buf.WriteString("data: ")
|
||||
_, _ = buf.Write(payload)
|
||||
_, _ = buf.WriteString("\n\n")
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package chatcompletions
|
||||
package chatcompletions_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/chatcompletions"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
)
|
||||
|
||||
@@ -73,7 +74,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
|
||||
Key: "test-key",
|
||||
}
|
||||
|
||||
req := &ChatCompletionNewParamsWrapper{
|
||||
req := &chatcompletions.ChatCompletionNewParamsWrapper{
|
||||
ChatCompletionNewParams: openai.ChatCompletionNewParams{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
@@ -88,7 +89,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
|
||||
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)
|
||||
|
||||
tracer := otel.Tracer("test")
|
||||
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{})
|
||||
interceptor := chatcompletions.NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{})
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
|
||||
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package intercept
|
||||
package intercept_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
)
|
||||
|
||||
func TestPrepareClientHeaders(t *testing.T) {
|
||||
@@ -14,7 +16,7 @@ func TestPrepareClientHeaders(t *testing.T) {
|
||||
t.Run("nil input returns empty header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := PrepareClientHeaders(nil)
|
||||
result := intercept.PrepareClientHeaders(nil)
|
||||
require.Empty(t, result)
|
||||
})
|
||||
|
||||
@@ -29,7 +31,7 @@ func TestPrepareClientHeaders(t *testing.T) {
|
||||
"X-Custom": {"preserved"},
|
||||
}
|
||||
|
||||
result := PrepareClientHeaders(input)
|
||||
result := intercept.PrepareClientHeaders(input)
|
||||
|
||||
assert.Empty(t, result.Get("Connection"))
|
||||
assert.Empty(t, result.Get("Keep-Alive"))
|
||||
@@ -48,7 +50,7 @@ func TestPrepareClientHeaders(t *testing.T) {
|
||||
"X-Custom": {"preserved"},
|
||||
}
|
||||
|
||||
result := PrepareClientHeaders(input)
|
||||
result := intercept.PrepareClientHeaders(input)
|
||||
|
||||
assert.Empty(t, result.Get("Host"))
|
||||
assert.Empty(t, result.Get("Accept-Encoding"))
|
||||
@@ -65,7 +67,7 @@ func TestPrepareClientHeaders(t *testing.T) {
|
||||
"X-Custom": {"preserved"},
|
||||
}
|
||||
|
||||
result := PrepareClientHeaders(input)
|
||||
result := intercept.PrepareClientHeaders(input)
|
||||
|
||||
assert.Empty(t, result.Get("Authorization"))
|
||||
assert.Empty(t, result.Get("X-Api-Key"))
|
||||
@@ -79,7 +81,7 @@ func TestPrepareClientHeaders(t *testing.T) {
|
||||
"X-Custom": {"value-1", "value-2"},
|
||||
}
|
||||
|
||||
result := PrepareClientHeaders(input)
|
||||
result := intercept.PrepareClientHeaders(input)
|
||||
|
||||
require.Equal(t, []string{"value-1", "value-2"}, result["X-Custom"])
|
||||
})
|
||||
@@ -93,7 +95,7 @@ func TestPrepareClientHeaders(t *testing.T) {
|
||||
}
|
||||
originalCopy := input.Clone()
|
||||
|
||||
_ = PrepareClientHeaders(input)
|
||||
_ = intercept.PrepareClientHeaders(input)
|
||||
|
||||
require.Equal(t, originalCopy, input)
|
||||
})
|
||||
@@ -113,7 +115,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
|
||||
"User-Agent": {"claude-code/1.0"},
|
||||
}
|
||||
|
||||
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
|
||||
assert.Equal(t, "Bearer sk-provider-key", result.Get("Authorization"))
|
||||
assert.Equal(t, "claude-code/1.0", result.Get("User-Agent"))
|
||||
@@ -131,7 +133,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
|
||||
"Anthropic-Beta": {"prompt-caching-2024-07-31"},
|
||||
}
|
||||
|
||||
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key")
|
||||
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "X-Api-Key")
|
||||
|
||||
assert.Equal(t, "sk-ant-provider-key", result.Get("X-Api-Key"))
|
||||
assert.Empty(t, result.Get("Authorization"))
|
||||
@@ -151,7 +153,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
|
||||
"User-Agent": {"claude-code/1.0"},
|
||||
}
|
||||
|
||||
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
|
||||
assert.Equal(t, "Bearer sk-key", result.Get("Authorization"))
|
||||
assert.Equal(t, "user-123", result.Get("X-Ai-Bridge-Actor-Id"))
|
||||
@@ -174,7 +176,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
|
||||
"User-Agent": {"claude-code/1.0"},
|
||||
}
|
||||
|
||||
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
|
||||
assert.Empty(t, result.Get("Connection"))
|
||||
assert.Empty(t, result.Get("Host"))
|
||||
@@ -192,7 +194,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
|
||||
"User-Agent": {"claude-code/1.0"},
|
||||
}
|
||||
|
||||
result := BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
result := intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
|
||||
assert.Empty(t, result.Get("Authorization"))
|
||||
assert.Equal(t, "claude-code/1.0", result.Get("User-Agent"))
|
||||
@@ -211,7 +213,7 @@ func TestBuildUpstreamHeaders(t *testing.T) {
|
||||
sdkCopy := sdkHeader.Clone()
|
||||
clientCopy := clientHeaders.Clone()
|
||||
|
||||
_ = BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
_ = intercept.BuildUpstreamHeaders(sdkHeader, clientHeaders, "Authorization")
|
||||
|
||||
require.Equal(t, sdkCopy, sdkHeader)
|
||||
require.Equal(t, clientCopy, clientHeaders)
|
||||
|
||||
@@ -32,7 +32,6 @@ type EventStream struct {
|
||||
initiated atomic.Bool
|
||||
initiateOnce sync.Once
|
||||
|
||||
closeOnce sync.Once
|
||||
shutdownOnce sync.Once
|
||||
eventsCh chan event
|
||||
|
||||
@@ -133,7 +132,7 @@ func (s *EventStream) Start(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if err := flush(w); err != nil {
|
||||
s.logger.Warn(ctx, "failed to flush", slog.Error(err))
|
||||
s.logger.Warn(ctx, "failed to flush event stream", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -240,8 +239,7 @@ func flush(w http.ResponseWriter) (err error) {
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Likely a broken connection, don't spam the logs.
|
||||
if r := recover(); r != nil { //nolint:revive,staticcheck // Intentionally swallowed; likely a broken connection.
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ type Interceptor interface {
|
||||
ID() uuid.UUID
|
||||
// Setup injects some required dependencies. This MUST be called before using the interceptor
|
||||
// to process requests.
|
||||
Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier)
|
||||
Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier)
|
||||
// Model returns the model in use for this [Interceptor].
|
||||
Model() string
|
||||
// ProcessRequest handles the HTTP request.
|
||||
|
||||
@@ -65,7 +65,7 @@ var bedrockSupportedBetaFlags = map[string]bool{
|
||||
type interceptionBase struct {
|
||||
id uuid.UUID
|
||||
providerName string
|
||||
reqPayload MessagesRequestPayload
|
||||
reqPayload RequestPayload
|
||||
|
||||
cfg aibconfig.Anthropic
|
||||
bedrockCfg *aibconfig.AWSBedrock
|
||||
@@ -90,9 +90,9 @@ func (i *interceptionBase) Credential() intercept.CredentialInfo {
|
||||
return i.credential
|
||||
}
|
||||
|
||||
func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
func (i *interceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.logger = logger
|
||||
i.recorder = recorder
|
||||
i.recorder = rec
|
||||
i.mcpProxy = mcpProxy
|
||||
}
|
||||
|
||||
@@ -116,15 +116,15 @@ func (i *interceptionBase) Model() string {
|
||||
return i.reqPayload.model()
|
||||
}
|
||||
|
||||
func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
|
||||
func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
|
||||
return []attribute.KeyValue{
|
||||
attribute.String(tracing.RequestPath, r.URL.Path),
|
||||
attribute.String(tracing.InterceptionID, s.id.String()),
|
||||
attribute.String(tracing.InterceptionID, i.id.String()),
|
||||
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
|
||||
attribute.String(tracing.Provider, s.providerName),
|
||||
attribute.String(tracing.Model, s.Model()),
|
||||
attribute.String(tracing.Provider, i.providerName),
|
||||
attribute.String(tracing.Model, i.Model()),
|
||||
attribute.Bool(tracing.Streaming, streaming),
|
||||
attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil),
|
||||
attribute.Bool(tracing.IsBedrock, i.bedrockCfg != nil),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,24 +174,22 @@ func (i *interceptionBase) disableParallelToolCalls() {
|
||||
}
|
||||
|
||||
// extractModelThoughts returns any thinking blocks that were returned in the response.
|
||||
func (i *interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recorder.ModelThoughtRecord {
|
||||
func (*interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recorder.ModelThoughtRecord {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var thoughtRecords []*recorder.ModelThoughtRecord
|
||||
for _, block := range msg.Content {
|
||||
switch variant := block.AsAny().(type) {
|
||||
case anthropic.ThinkingBlock:
|
||||
if variant.Thinking == "" {
|
||||
continue
|
||||
}
|
||||
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
|
||||
Content: variant.Thinking,
|
||||
Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking},
|
||||
})
|
||||
}
|
||||
// anthropic.RedactedThinkingBlock also exists, but there's nothing useful we can capture.
|
||||
variant, ok := block.AsAny().(anthropic.ThinkingBlock)
|
||||
if !ok || variant.Thinking == "" {
|
||||
continue
|
||||
}
|
||||
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
|
||||
Content: variant.Thinking,
|
||||
Metadata: recorder.Metadata{"source": recorder.ThoughtSourceThinking},
|
||||
})
|
||||
}
|
||||
return thoughtRecords
|
||||
}
|
||||
@@ -264,7 +262,7 @@ func (i *interceptionBase) withBody() option.RequestOption {
|
||||
return option.WithRequestBody("application/json", []byte(i.reqPayload))
|
||||
}
|
||||
|
||||
func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
|
||||
func (*interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
|
||||
if cfg == nil {
|
||||
return nil, xerrors.New("nil config given")
|
||||
}
|
||||
@@ -405,7 +403,7 @@ func filterBedrockBetaFlags(headers http.Header, model string) {
|
||||
}
|
||||
|
||||
// writeUpstreamError marshals and writes a given error.
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *ErrorResponse) {
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *messagesResponseError) {
|
||||
if antErr == nil {
|
||||
return
|
||||
}
|
||||
@@ -415,7 +413,7 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err
|
||||
|
||||
out, err := json.Marshal(antErr)
|
||||
if err != nil {
|
||||
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", antErr)))
|
||||
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", antErr))
|
||||
// Response has to match expected format.
|
||||
// See https://docs.claude.com/en/api/errors#error-shapes.
|
||||
_, _ = w.Write([]byte(fmt.Sprintf(`{
|
||||
@@ -487,7 +485,7 @@ func accumulateUsage(dest, src any) {
|
||||
}
|
||||
}
|
||||
|
||||
func getErrorResponse(err error) *ErrorResponse {
|
||||
func getErrorResponse(err error) *messagesResponseError {
|
||||
var apierr *anthropic.Error
|
||||
if !errors.As(err, &apierr) {
|
||||
return nil
|
||||
@@ -505,7 +503,7 @@ func getErrorResponse(err error) *ErrorResponse {
|
||||
typ = string(detail.Type)
|
||||
}
|
||||
|
||||
return &ErrorResponse{
|
||||
return &messagesResponseError{
|
||||
ErrorResponse: &anthropic.ErrorResponse{
|
||||
Error: anthropic.ErrorObjectUnion{
|
||||
Message: msg,
|
||||
@@ -517,16 +515,16 @@ func getErrorResponse(err error) *ErrorResponse {
|
||||
}
|
||||
}
|
||||
|
||||
var _ error = &ErrorResponse{}
|
||||
var _ error = &messagesResponseError{}
|
||||
|
||||
type ErrorResponse struct {
|
||||
type messagesResponseError struct {
|
||||
*anthropic.ErrorResponse
|
||||
|
||||
StatusCode int `json:"-"`
|
||||
}
|
||||
|
||||
func newErrorResponse(msg error) *ErrorResponse {
|
||||
return &ErrorResponse{
|
||||
func newErrorResponse(msg error) *messagesResponseError {
|
||||
return &messagesResponseError{
|
||||
ErrorResponse: &shared.ErrorResponse{
|
||||
Error: shared.ErrorObjectUnion{
|
||||
Message: msg.Error(),
|
||||
@@ -536,7 +534,7 @@ func newErrorResponse(msg error) *ErrorResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *ErrorResponse) Error() string {
|
||||
func (a *messagesResponseError) Error() string {
|
||||
if a.ErrorResponse == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package messages
|
||||
package messages //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -197,6 +197,8 @@ func TestAWSBedrockValidation(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := &interceptionBase{}
|
||||
opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)
|
||||
|
||||
@@ -212,7 +214,10 @@ func TestAWSBedrockValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccumulateUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Usage to Usage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dest := &anthropic.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
@@ -253,6 +258,8 @@ func TestAccumulateUsage(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("MessageDeltaUsage to MessageDeltaUsage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dest := &anthropic.MessageDeltaUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
@@ -283,6 +290,8 @@ func TestAccumulateUsage(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Usage to MessageDeltaUsage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dest := &anthropic.MessageDeltaUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
@@ -317,6 +326,8 @@ func TestAccumulateUsage(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("MessageDeltaUsage to Usage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dest := &anthropic.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
@@ -354,6 +365,8 @@ func TestAccumulateUsage(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Nil or unsupported types", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test with nil dest
|
||||
var nilUsage *anthropic.Usage
|
||||
source := anthropic.Usage{InputTokens: 10}
|
||||
@@ -763,10 +776,10 @@ func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload {
|
||||
func mustMessagesPayload(t *testing.T, requestBody string) RequestPayload {
|
||||
t.Helper()
|
||||
|
||||
payload, err := NewMessagesRequestPayload([]byte(requestBody))
|
||||
payload, err := NewRequestPayload([]byte(requestBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
return payload
|
||||
@@ -777,11 +790,11 @@ type mockServerProxier struct {
|
||||
tools []*mcp.Tool
|
||||
}
|
||||
|
||||
func (m *mockServerProxier) Init(context.Context) error {
|
||||
func (*mockServerProxier) Init(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServerProxier) Shutdown(context.Context) error {
|
||||
func (*mockServerProxier) Shutdown(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -798,8 +811,8 @@ func (m *mockServerProxier) GetTool(id string) *mcp.Tool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) {
|
||||
return nil, nil
|
||||
func (*mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) {
|
||||
return nil, nil //nolint:nilnil // mock: no-op implementation
|
||||
}
|
||||
|
||||
func TestFilterBedrockBetaFlags(t *testing.T) {
|
||||
|
||||
@@ -31,7 +31,7 @@ type BlockingInterception struct {
|
||||
|
||||
func NewBlockingInterceptor(
|
||||
id uuid.UUID,
|
||||
reqPayload MessagesRequestPayload,
|
||||
reqPayload RequestPayload,
|
||||
providerName string,
|
||||
cfg config.Anthropic,
|
||||
bedrockCfg *config.AWSBedrock,
|
||||
@@ -53,15 +53,15 @@ func NewBlockingInterceptor(
|
||||
}}
|
||||
}
|
||||
|
||||
func (i *BlockingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.interceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
|
||||
func (i *BlockingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.interceptionBase.Setup(logger.Named("blocking"), rec, mcpProxy)
|
||||
}
|
||||
|
||||
func (i *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
|
||||
return i.interceptionBase.baseTraceAttributes(r, false)
|
||||
}
|
||||
|
||||
func (s *BlockingInterception) Streaming() bool {
|
||||
func (*BlockingInterception) Streaming() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -82,12 +82,12 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request.
|
||||
// RequestPayload is raw JSON bytes of an Anthropic Messages API request.
|
||||
// Methods provide package-specific reads and rewrites while preserving the
|
||||
// original body for upstream pass-through.
|
||||
type MessagesRequestPayload []byte
|
||||
type RequestPayload []byte
|
||||
|
||||
func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) {
|
||||
func NewRequestPayload(raw []byte) (RequestPayload, error) {
|
||||
if len(bytes.TrimSpace(raw)) == 0 {
|
||||
return nil, xerrors.New("messages empty request body")
|
||||
}
|
||||
@@ -95,10 +95,10 @@ func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) {
|
||||
return nil, xerrors.New("messages invalid JSON request body")
|
||||
}
|
||||
|
||||
return MessagesRequestPayload(raw), nil
|
||||
return RequestPayload(raw), nil
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) Stream() bool {
|
||||
func (p RequestPayload) Stream() bool {
|
||||
v := gjson.GetBytes(p, messagesReqPathStream)
|
||||
if !v.IsBool() {
|
||||
return false
|
||||
@@ -106,11 +106,11 @@ func (p MessagesRequestPayload) Stream() bool {
|
||||
return v.Bool()
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) model() string {
|
||||
func (p RequestPayload) model() string {
|
||||
return gjson.GetBytes(p, messagesReqPathModel).Str
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) correlatingToolCallID() *string {
|
||||
func (p RequestPayload) correlatingToolCallID() *string {
|
||||
messages := gjson.GetBytes(p, messagesReqPathMessages)
|
||||
if !messages.IsArray() {
|
||||
return nil
|
||||
@@ -147,7 +147,7 @@ func (p MessagesRequestPayload) correlatingToolCallID() *string {
|
||||
// lastUserPrompt returns the prompt text from the last user message. If no prompt
|
||||
// is found, it returns empty string, false, nil. Unexpected shapes are treated as
|
||||
// unsupported and do not fail the request path.
|
||||
func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) {
|
||||
func (p RequestPayload) lastUserPrompt() (string, bool, error) {
|
||||
messages := gjson.GetBytes(p, messagesReqPathMessages)
|
||||
if !messages.Exists() || messages.Type == gjson.Null {
|
||||
return "", false, nil
|
||||
@@ -195,7 +195,7 @@ func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) (MessagesRequestPayload, error) {
|
||||
func (p RequestPayload) injectTools(injected []anthropic.ToolUnionParam) (RequestPayload, error) {
|
||||
if len(injected) == 0 {
|
||||
return p, nil
|
||||
}
|
||||
@@ -221,7 +221,7 @@ func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam)
|
||||
return p.set(messagesReqPathTools, allTools)
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPayload, error) {
|
||||
func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) {
|
||||
toolChoice := gjson.GetBytes(p, messagesReqPathToolChoice)
|
||||
|
||||
// If no tool_choice was defined, assume auto.
|
||||
@@ -258,7 +258,7 @@ func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPaylo
|
||||
}
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (MessagesRequestPayload, error) {
|
||||
func (p RequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (RequestPayload, error) {
|
||||
if len(newMessages) == 0 {
|
||||
return p, nil
|
||||
}
|
||||
@@ -285,11 +285,11 @@ func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.Message
|
||||
return p.set(messagesReqPathMessages, allMessages)
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) withModel(model string) (MessagesRequestPayload, error) {
|
||||
func (p RequestPayload) withModel(model string) (RequestPayload, error) {
|
||||
return p.set(messagesReqPathModel, model)
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) {
|
||||
func (p RequestPayload) messages() ([]json.RawMessage, error) {
|
||||
messages := gjson.GetBytes(p, messagesReqPathMessages)
|
||||
if !messages.Exists() || messages.Type == gjson.Null {
|
||||
return nil, nil
|
||||
@@ -301,7 +301,7 @@ func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) {
|
||||
return p.resultToRawMessage(messages.Array()), nil
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) {
|
||||
func (p RequestPayload) tools() ([]json.RawMessage, error) {
|
||||
tools := gjson.GetBytes(p, messagesReqPathTools)
|
||||
if !tools.Exists() || tools.Type == gjson.Null {
|
||||
return nil, nil
|
||||
@@ -313,7 +313,7 @@ func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) {
|
||||
return p.resultToRawMessage(tools.Array()), nil
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage {
|
||||
func (RequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage {
|
||||
// gjson.Result conversion to json.RawMessage is needed because
|
||||
// gjson.Result does not implement json.Marshaler — would
|
||||
// serialize its struct fields instead of the raw JSON it represents.
|
||||
@@ -326,7 +326,7 @@ func (p MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.
|
||||
|
||||
// convertAdaptiveThinkingForBedrock converts thinking.type "adaptive" to "enabled" with a calculated budget_tokens
|
||||
// conversion is needed for Bedrock models that does not support the "adaptive" thinking.type
|
||||
func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesRequestPayload, error) {
|
||||
func (p RequestPayload) convertAdaptiveThinkingForBedrock() (RequestPayload, error) {
|
||||
thinkingType := gjson.GetBytes(p, messagesReqPathThinkingType)
|
||||
if thinkingType.String() != constAdaptive {
|
||||
return p, nil
|
||||
@@ -377,7 +377,7 @@ func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesReq
|
||||
// removed when the corresponding flag is absent from the Anthropic-Beta header.
|
||||
// Model-specific beta flags must already be filtered from the header before
|
||||
// calling this method (see filterBedrockBetaFlags).
|
||||
func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Header) (MessagesRequestPayload, error) {
|
||||
func (p RequestPayload) removeUnsupportedBedrockFields(headers http.Header) (RequestPayload, error) {
|
||||
var payloadMap map[string]any
|
||||
if err := json.Unmarshal(p, &payloadMap); err != nil {
|
||||
return p, xerrors.Errorf("failed to unmarshal request payload when removing unsupported Bedrock fields: %w", err)
|
||||
@@ -400,13 +400,13 @@ func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Head
|
||||
if err != nil {
|
||||
return p, xerrors.Errorf("failed to marshal request payload when removing unsupported Bedrock fields: %w", err)
|
||||
}
|
||||
return MessagesRequestPayload(result), nil
|
||||
return RequestPayload(result), nil
|
||||
}
|
||||
|
||||
func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) {
|
||||
func (p RequestPayload) set(path string, value any) (RequestPayload, error) {
|
||||
out, err := sjson.SetBytes(p, path, value)
|
||||
if err != nil {
|
||||
return p, xerrors.Errorf("set %s: %w", path, err)
|
||||
}
|
||||
return MessagesRequestPayload(out), nil
|
||||
return RequestPayload(out), nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package messages
|
||||
package messages //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
)
|
||||
|
||||
func TestNewMessagesRequestPayload(t *testing.T) {
|
||||
func TestNewRequestPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
@@ -42,7 +42,7 @@ func TestNewMessagesRequestPayload(t *testing.T) {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload, err := NewMessagesRequestPayload(testCase.requestBody)
|
||||
payload, err := NewRequestPayload(testCase.requestBody)
|
||||
if testCase.expectError {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, payload)
|
||||
@@ -50,12 +50,12 @@ func TestNewMessagesRequestPayload(t *testing.T) {
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, MessagesRequestPayload(testCase.requestBody), payload)
|
||||
require.Equal(t, RequestPayload(testCase.requestBody), payload)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesRequestPayloadStream(t *testing.T) {
|
||||
func TestRequestPayloadStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
@@ -97,7 +97,7 @@ func TestMessagesRequestPayloadStream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesRequestPayloadModel(t *testing.T) {
|
||||
func TestRequestPayloadModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
@@ -132,7 +132,7 @@ func TestMessagesRequestPayloadModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) {
|
||||
func TestRequestPayloadLastUserPrompt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
@@ -229,7 +229,7 @@ func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) {
|
||||
func TestRequestPayloadCorrelatingToolCallID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
@@ -266,7 +266,7 @@ func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesRequestPayloadInjectTools(t *testing.T) {
|
||||
func TestRequestPayloadInjectTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`)
|
||||
@@ -291,7 +291,7 @@ func TestMessagesRequestPayloadInjectTools(t *testing.T) {
|
||||
require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String())
|
||||
}
|
||||
|
||||
func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
|
||||
func TestRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
@@ -361,7 +361,7 @@ func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) {
|
||||
func TestRequestPayloadDisableParallelToolCalls(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
@@ -451,7 +451,7 @@ func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesRequestPayloadAppendedMessages(t *testing.T) {
|
||||
func TestRequestPayloadAppendedMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`)
|
||||
|
||||
@@ -36,7 +36,7 @@ type StreamingInterception struct {
|
||||
|
||||
func NewStreamingInterceptor(
|
||||
id uuid.UUID,
|
||||
reqPayload MessagesRequestPayload,
|
||||
reqPayload RequestPayload,
|
||||
providerName string,
|
||||
cfg config.Anthropic,
|
||||
bedrockCfg *config.AWSBedrock,
|
||||
@@ -58,16 +58,16 @@ func NewStreamingInterceptor(
|
||||
}}
|
||||
}
|
||||
|
||||
func (s *StreamingInterception) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
s.interceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
|
||||
func (i *StreamingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.interceptionBase.Setup(logger.Named("streaming"), rec, mcpProxy)
|
||||
}
|
||||
|
||||
func (s *StreamingInterception) Streaming() bool {
|
||||
func (*StreamingInterception) Streaming() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
|
||||
return s.interceptionBase.baseTraceAttributes(r, true)
|
||||
func (i *StreamingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
|
||||
return i.interceptionBase.baseTraceAttributes(r, true)
|
||||
}
|
||||
|
||||
// ProcessRequest handles a request to /v1/messages.
|
||||
@@ -156,7 +156,7 @@ newStream:
|
||||
for {
|
||||
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
|
||||
if err := streamCtx.Err(); err != nil {
|
||||
lastErr = xerrors.Errorf("stream exit: %w", err)
|
||||
interceptionErr = xerrors.Errorf("stream exit: %w", err)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -178,8 +178,7 @@ newStream:
|
||||
// Tool-related handling.
|
||||
switch event.Type {
|
||||
case string(constant.ValueOf[constant.ContentBlockStart]()):
|
||||
switch block := event.AsContentBlockStart().ContentBlock.AsAny().(type) {
|
||||
case anthropic.ToolUseBlock:
|
||||
if block, ok := event.AsContentBlockStart().ContentBlock.AsAny().(anthropic.ToolUseBlock); ok {
|
||||
lastToolName = block.Name
|
||||
|
||||
if i.mcpProxy != nil && i.mcpProxy.GetTool(block.Name) != nil {
|
||||
@@ -306,8 +305,7 @@ newStream:
|
||||
foundTools int
|
||||
)
|
||||
for _, block := range message.Content {
|
||||
switch variant := block.AsAny().(type) {
|
||||
case anthropic.ToolUseBlock:
|
||||
if variant, ok := block.AsAny().(anthropic.ToolUseBlock); ok {
|
||||
foundTools++
|
||||
if variant.Name == name {
|
||||
input = variant.Input
|
||||
@@ -430,24 +428,23 @@ newStream:
|
||||
// Causes a new stream to be run with updated messages.
|
||||
isFirst = false
|
||||
continue newStream
|
||||
} else {
|
||||
// Find all the non-injected tools and track their uses.
|
||||
for _, block := range message.Content {
|
||||
switch variant := block.AsAny().(type) {
|
||||
case anthropic.ToolUseBlock:
|
||||
if i.mcpProxy != nil && i.mcpProxy.GetTool(variant.Name) != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
|
||||
InterceptionID: i.ID().String(),
|
||||
MsgID: message.ID,
|
||||
ToolCallID: variant.ID,
|
||||
Tool: variant.Name,
|
||||
Args: variant.Input,
|
||||
Injected: false,
|
||||
})
|
||||
// Find all the non-injected tools and track their uses.
|
||||
for _, block := range message.Content {
|
||||
if variant, ok := block.AsAny().(anthropic.ToolUseBlock); ok {
|
||||
if i.mcpProxy != nil && i.mcpProxy.GetTool(variant.Name) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_ = i.recorder.RecordToolUsage(streamCtx, &recorder.ToolUsageRecord{
|
||||
InterceptionID: i.ID().String(),
|
||||
MsgID: message.ID,
|
||||
ToolCallID: variant.ID,
|
||||
Tool: variant.Name,
|
||||
Args: variant.Input,
|
||||
Injected: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -463,11 +460,10 @@ newStream:
|
||||
if eventstream.IsUnrecoverableError(err) {
|
||||
logger.Debug(ctx, "processing terminated", slog.Error(err))
|
||||
break // Stop processing if client disconnected or context canceled.
|
||||
} else {
|
||||
logger.Warn(ctx, "failed to relay event", slog.Error(err))
|
||||
lastErr = xerrors.Errorf("relay event: %w", err)
|
||||
break
|
||||
}
|
||||
logger.Warn(ctx, "failed to relay event", slog.Error(err))
|
||||
lastErr = xerrors.Errorf("relay event: %w", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,8 +473,8 @@ newStream:
|
||||
MsgID: message.ID,
|
||||
Prompt: prompt,
|
||||
})
|
||||
prompt = ""
|
||||
promptFound = false
|
||||
prompt = "" //nolint:ineffassign // reset to prevent double-recording across newStream iterations
|
||||
promptFound = false //nolint:ineffassign // reset to prevent double-recording across newStream iterations
|
||||
}
|
||||
|
||||
if events.IsStreaming() {
|
||||
@@ -491,7 +487,7 @@ newStream:
|
||||
logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr))
|
||||
interceptionErr = antErr
|
||||
} else {
|
||||
logger.Warn(ctx, "unknown error", slog.Error(streamErr))
|
||||
logger.Warn(ctx, "unknown stream error encountered", slog.Error(streamErr))
|
||||
// Unfortunately, the Anthropic SDK does not support parsing errors received in the stream
|
||||
// into known types (i.e. [shared.OverloadedError]).
|
||||
// See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174
|
||||
@@ -500,14 +496,14 @@ newStream:
|
||||
}
|
||||
} else if lastErr != nil {
|
||||
// Otherwise check if any logical errors occurred during processing.
|
||||
logger.Warn(ctx, "stream failed", slog.Error(lastErr))
|
||||
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
|
||||
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
|
||||
}
|
||||
|
||||
if interceptionErr != nil {
|
||||
payload, err := i.marshal(interceptionErr)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr)))
|
||||
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error()))
|
||||
} else if err := events.Send(streamCtx, payload); err != nil {
|
||||
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
|
||||
}
|
||||
@@ -518,11 +514,11 @@ newStream:
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer shutdownCancel()
|
||||
// Give the events stream 30 seconds (TODO: configurable) to gracefully shutdown.
|
||||
if err := events.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Warn(ctx, "event stream shutdown", slog.Error(err))
|
||||
}
|
||||
shutdownCancel()
|
||||
|
||||
// Cancel the stream context, we're now done.
|
||||
if interceptionErr != nil {
|
||||
@@ -537,8 +533,8 @@ newStream:
|
||||
return interceptionErr
|
||||
}
|
||||
|
||||
func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) {
|
||||
sj, err := sjson.Set(event.RawJSON(), "message.id", s.ID().String())
|
||||
func (i *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) {
|
||||
sj, err := sjson.Set(event.RawJSON(), "message.id", i.ID().String())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal event id failed: %w", err)
|
||||
}
|
||||
@@ -548,10 +544,10 @@ func (s *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventU
|
||||
return nil, xerrors.Errorf("marshal event usage failed: %w", err)
|
||||
}
|
||||
|
||||
return s.encodeForStream([]byte(sj), event.Type), nil
|
||||
return i.encodeForStream([]byte(sj), event.Type), nil
|
||||
}
|
||||
|
||||
func (s *StreamingInterception) marshal(payload any) ([]byte, error) {
|
||||
func (i *StreamingInterception) marshal(payload any) ([]byte, error) {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal payload: %w", err)
|
||||
@@ -567,29 +563,30 @@ func (s *StreamingInterception) marshal(payload any) ([]byte, error) {
|
||||
return nil, xerrors.Errorf("could not determine type from payload %q", data)
|
||||
}
|
||||
|
||||
return s.encodeForStream(data, eventType), nil
|
||||
return i.encodeForStream(data, eventType), nil
|
||||
}
|
||||
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/streaming#basic-streaming-request
|
||||
func (s *StreamingInterception) pingPayload() []byte {
|
||||
return s.encodeForStream([]byte(`{"type": "ping"}`), "ping")
|
||||
func (i *StreamingInterception) pingPayload() []byte {
|
||||
return i.encodeForStream([]byte(`{"type": "ping"}`), "ping")
|
||||
}
|
||||
|
||||
func (s *StreamingInterception) encodeForStream(payload []byte, typ string) []byte {
|
||||
func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte {
|
||||
// bytes.Buffer writes to in-memory storage and never return errors.
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("event: ")
|
||||
buf.WriteString(typ)
|
||||
buf.WriteString("\n")
|
||||
buf.WriteString("data: ")
|
||||
buf.Write(payload)
|
||||
buf.WriteString("\n\n")
|
||||
_, _ = buf.WriteString("event: ")
|
||||
_, _ = buf.WriteString(typ)
|
||||
_, _ = buf.WriteString("\n")
|
||||
_, _ = buf.WriteString("data: ")
|
||||
_, _ = buf.Write(payload)
|
||||
_, _ = buf.WriteString("\n\n")
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// newStream traces svc.NewStreaming() call.
|
||||
func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
|
||||
_, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
|
||||
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
defer span.End()
|
||||
|
||||
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, s.withBody())
|
||||
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, i.withBody())
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ type responsesInterceptionBase struct {
|
||||
// clientHeaders are the original HTTP headers from the client request.
|
||||
clientHeaders http.Header
|
||||
authHeaderName string
|
||||
reqPayload ResponsesRequestPayload
|
||||
reqPayload RequestPayload
|
||||
|
||||
cfg config.OpenAI
|
||||
recorder recorder.Recorder
|
||||
@@ -89,9 +89,9 @@ func (i *responsesInterceptionBase) Credential() intercept.CredentialInfo {
|
||||
return i.credential
|
||||
}
|
||||
|
||||
func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
func (i *responsesInterceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.logger = logger.With(slog.F("model", i.Model()))
|
||||
i.recorder = recorder
|
||||
i.recorder = rec
|
||||
i.mcpProxy = mcpProxy
|
||||
}
|
||||
|
||||
@@ -127,7 +127,13 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.
|
||||
// sendCustomErr sends custom responses.Error error to the client
|
||||
// it should only be called before any data is sent back to the client
|
||||
func (i *responsesInterceptionBase) sendCustomErr(ctx context.Context, w http.ResponseWriter, code int, err error) {
|
||||
respErr := responses.Error{
|
||||
// Same JSON shape as responses.Error but using a plain struct because
|
||||
// responses.Error embeds *http.Request whose GetBody func field
|
||||
// is not JSON-marshalable (SA1026).
|
||||
respErr := struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Code: strconv.Itoa(code),
|
||||
Message: err.Error(),
|
||||
}
|
||||
@@ -222,15 +228,15 @@ func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Conte
|
||||
|
||||
func (i *responsesInterceptionBase) parseFunctionCallJSONArgs(ctx context.Context, raw string) recorder.ToolArgs {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed != "" {
|
||||
var args recorder.ToolArgs
|
||||
if err := json.Unmarshal([]byte(trimmed), &args); err != nil {
|
||||
i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err))
|
||||
} else {
|
||||
return args
|
||||
}
|
||||
if trimmed == "" {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed
|
||||
var args recorder.ToolArgs
|
||||
if err := json.Unmarshal([]byte(trimmed), &args); err != nil {
|
||||
i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err))
|
||||
return trimmed
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, response *responses.Response) {
|
||||
@@ -264,7 +270,7 @@ func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, respon
|
||||
// extractModelThoughts extracts model thoughts from response output items.
|
||||
// It captures both reasoning summary items and commentary messages (message
|
||||
// output items with "phase": "commentary") as model thoughts.
|
||||
func (i *responsesInterceptionBase) extractModelThoughts(response *responses.Response) []*recorder.ModelThoughtRecord {
|
||||
func (*responsesInterceptionBase) extractModelThoughts(response *responses.Response) []*recorder.ModelThoughtRecord {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package responses
|
||||
package responses //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -359,13 +359,16 @@ func (mrw *mockResponseWriter) WriteHeader(statusCode int) {
|
||||
}
|
||||
|
||||
func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mrw := mockResponseWriter{}
|
||||
|
||||
respCopy := responseCopier{}
|
||||
body := "test_body"
|
||||
respCopy.buff.Write([]byte(body))
|
||||
_, _ = respCopy.buff.Write([]byte(body)) // bytes.Buffer.Write never fails
|
||||
|
||||
respCopy.forwardResp(&mrw)
|
||||
err := respCopy.forwardResp(&mrw)
|
||||
require.NoError(t, err)
|
||||
require.False(t, mrw.headerCalled)
|
||||
require.False(t, mrw.writeCalled)
|
||||
require.False(t, mrw.writeHeaderCalled)
|
||||
@@ -373,7 +376,8 @@ func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) {
|
||||
// after response is received data is forwarded
|
||||
respCopy.responseReceived.Store(true)
|
||||
|
||||
respCopy.forwardResp(&mrw)
|
||||
err = respCopy.forwardResp(&mrw)
|
||||
require.NoError(t, err)
|
||||
require.True(t, mrw.headerCalled)
|
||||
require.True(t, mrw.writeCalled)
|
||||
require.True(t, mrw.writeHeaderCalled)
|
||||
|
||||
@@ -28,7 +28,7 @@ type BlockingResponsesInterceptor struct {
|
||||
|
||||
func NewBlockingInterceptor(
|
||||
id uuid.UUID,
|
||||
reqPayload ResponsesRequestPayload,
|
||||
reqPayload RequestPayload,
|
||||
providerName string,
|
||||
cfg config.OpenAI,
|
||||
clientHeaders http.Header,
|
||||
@@ -50,11 +50,11 @@ func NewBlockingInterceptor(
|
||||
}
|
||||
}
|
||||
|
||||
func (i *BlockingResponsesInterceptor) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.responsesInterceptionBase.Setup(logger.Named("blocking"), recorder, mcpProxy)
|
||||
func (i *BlockingResponsesInterceptor) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.responsesInterceptionBase.Setup(logger.Named("blocking"), rec, mcpProxy)
|
||||
}
|
||||
|
||||
func (i *BlockingResponsesInterceptor) Streaming() bool {
|
||||
func (*BlockingResponsesInterceptor) Streaming() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -37,13 +37,13 @@ var (
|
||||
reqPathType = string(constant.ValueOf[constant.Type]())
|
||||
)
|
||||
|
||||
// ResponsesRequestPayload is raw JSON bytes of a Responses API request.
|
||||
// RequestPayload is raw JSON bytes of a Responses API request.
|
||||
// Methods provide package-specific reads and rewrites while preserving the
|
||||
// original body for upstream pass-through.
|
||||
// Note: No changes are made on schema error.
|
||||
type ResponsesRequestPayload []byte
|
||||
type RequestPayload []byte
|
||||
|
||||
func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) {
|
||||
func NewRequestPayload(raw []byte) (RequestPayload, error) {
|
||||
if len(bytes.TrimSpace(raw)) == 0 {
|
||||
return nil, xerrors.New("empty request body")
|
||||
}
|
||||
@@ -51,22 +51,22 @@ func NewResponsesRequestPayload(raw []byte) (ResponsesRequestPayload, error) {
|
||||
return nil, xerrors.New("invalid JSON payload")
|
||||
}
|
||||
|
||||
return ResponsesRequestPayload(raw), nil
|
||||
return RequestPayload(raw), nil
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) Stream() bool {
|
||||
func (p RequestPayload) Stream() bool {
|
||||
return gjson.GetBytes(p, reqPathStream).Bool()
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) model() string {
|
||||
func (p RequestPayload) model() string {
|
||||
return gjson.GetBytes(p, reqPathModel).String()
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) background() bool {
|
||||
func (p RequestPayload) background() bool {
|
||||
return gjson.GetBytes(p, reqPathBackground).Bool()
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) correlatingToolCallID() *string {
|
||||
func (p RequestPayload) correlatingToolCallID() *string {
|
||||
items := gjson.GetBytes(p, reqPathInput)
|
||||
if !items.IsArray() {
|
||||
return nil
|
||||
@@ -94,7 +94,7 @@ func (p ResponsesRequestPayload) correlatingToolCallID() *string {
|
||||
// item, or the string input value if present. If no prompt is found, it returns
|
||||
// empty string, false, nil. Unexpected shapes are treated as unsupported and do
|
||||
// not fail the request path.
|
||||
func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog.Logger) (string, bool, error) {
|
||||
func (p RequestPayload) lastUserPrompt(ctx context.Context, logger slog.Logger) (string, bool, error) {
|
||||
inputItems := gjson.GetBytes(p, reqPathInput)
|
||||
if !inputItems.Exists() || inputItems.Type == gjson.Null {
|
||||
return "", false, nil
|
||||
@@ -155,10 +155,10 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog
|
||||
}
|
||||
|
||||
if promptExists {
|
||||
sb.WriteByte('\n')
|
||||
_ = sb.WriteByte('\n') // strings.Builder.WriteByte never fails
|
||||
}
|
||||
promptExists = true
|
||||
sb.WriteString(text.Str)
|
||||
_, _ = sb.WriteString(text.Str) // strings.Builder.WriteString never fails
|
||||
}
|
||||
|
||||
if !promptExists {
|
||||
@@ -168,7 +168,7 @@ func (p ResponsesRequestPayload) lastUserPrompt(ctx context.Context, logger slog
|
||||
return sb.String(), true, nil
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam) (ResponsesRequestPayload, error) {
|
||||
func (p RequestPayload) injectTools(injected []responses.ToolUnionParam) (RequestPayload, error) {
|
||||
if len(injected) == 0 {
|
||||
return p, nil
|
||||
}
|
||||
@@ -189,11 +189,11 @@ func (p ResponsesRequestPayload) injectTools(injected []responses.ToolUnionParam
|
||||
return p.set(reqPathTools, allTools)
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) disableParallelToolCalls() (ResponsesRequestPayload, error) {
|
||||
func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) {
|
||||
return p.set(reqPathParallelToolCalls, false)
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInputItemUnionParam) (ResponsesRequestPayload, error) {
|
||||
func (p RequestPayload) appendInputItems(items []responses.ResponseInputItemUnionParam) (RequestPayload, error) {
|
||||
if len(items) == 0 {
|
||||
return p, nil
|
||||
}
|
||||
@@ -212,7 +212,7 @@ func (p ResponsesRequestPayload) appendInputItems(items []responses.ResponseInpu
|
||||
return p.set(reqPathInput, allInput)
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) inputItems() ([]any, error) {
|
||||
func (p RequestPayload) inputItems() ([]any, error) {
|
||||
input := gjson.GetBytes(p, reqPathInput)
|
||||
if !input.Exists() || input.Type == gjson.Null {
|
||||
return []any{}, nil
|
||||
@@ -235,7 +235,7 @@ func (p ResponsesRequestPayload) inputItems() ([]any, error) {
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) {
|
||||
func (p RequestPayload) toolItems() ([]json.RawMessage, error) {
|
||||
tools := gjson.GetBytes(p, reqPathTools)
|
||||
if !tools.Exists() {
|
||||
return nil, nil
|
||||
@@ -253,7 +253,7 @@ func (p ResponsesRequestPayload) toolItems() ([]json.RawMessage, error) {
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
func (p ResponsesRequestPayload) set(path string, value any) (ResponsesRequestPayload, error) {
|
||||
func (p RequestPayload) set(path string, value any) (RequestPayload, error) {
|
||||
updated, err := sjson.SetBytes(p, path, value)
|
||||
if err != nil {
|
||||
return p, xerrors.Errorf("failed to set value at path %s: %w", path, err)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package responses
|
||||
package responses //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
)
|
||||
|
||||
func TestNewResponsesRequestPayload(t *testing.T) {
|
||||
func TestNewRequestPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payloadWithWrongTypes := []byte(`{"model":123,"stream":"yes","input":42,"background":"nope"}`)
|
||||
@@ -42,7 +42,7 @@ func TestNewResponsesRequestPayload(t *testing.T) {
|
||||
err: "invalid JSON payload",
|
||||
},
|
||||
{
|
||||
// ResponsesRequestPayload just checks for JSON validity,
|
||||
// RequestPayload just checks for JSON validity,
|
||||
// schema errors are not surfaced here and
|
||||
// the original body is preserved for upstream handling
|
||||
// similar to how reverse proxy would behave.
|
||||
@@ -59,7 +59,7 @@ func TestNewResponsesRequestPayload(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload, err := NewResponsesRequestPayload(tc.raw)
|
||||
payload, err := NewRequestPayload(tc.raw)
|
||||
|
||||
if tc.err != "" {
|
||||
require.ErrorContains(t, err, tc.err)
|
||||
@@ -518,10 +518,10 @@ func injectedFunctionTool(name string) responses.ToolUnionParam {
|
||||
}
|
||||
}
|
||||
|
||||
func mustPayload(t *testing.T, raw []byte) ResponsesRequestPayload {
|
||||
func mustPayload(t *testing.T, raw []byte) RequestPayload {
|
||||
t.Helper()
|
||||
|
||||
payload, err := NewResponsesRequestPayload(raw)
|
||||
payload, err := NewRequestPayload(raw)
|
||||
require.NoError(t, err)
|
||||
return payload
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ type StreamingResponsesInterceptor struct {
|
||||
|
||||
func NewStreamingInterceptor(
|
||||
id uuid.UUID,
|
||||
reqPayload ResponsesRequestPayload,
|
||||
reqPayload RequestPayload,
|
||||
providerName string,
|
||||
cfg config.OpenAI,
|
||||
clientHeaders http.Header,
|
||||
@@ -57,11 +57,11 @@ func NewStreamingInterceptor(
|
||||
}
|
||||
}
|
||||
|
||||
func (i *StreamingResponsesInterceptor) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.responsesInterceptionBase.Setup(logger.Named("streaming"), recorder, mcpProxy)
|
||||
func (i *StreamingResponsesInterceptor) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
||||
i.responsesInterceptionBase.Setup(logger.Named("streaming"), rec, mcpProxy)
|
||||
}
|
||||
|
||||
func (i *StreamingResponsesInterceptor) Streaming() bool {
|
||||
func (*StreamingResponsesInterceptor) Streaming() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package integrationtest
|
||||
package integrationtest //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/fixtures"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/apidump"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/provider"
|
||||
)
|
||||
|
||||
@@ -114,23 +114,25 @@ func TestAPIDump(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Setup mock upstream server.
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
srv := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
srv := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
// Create temp dir for API dumps.
|
||||
dumpDir := t.TempDir()
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, srv.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, srv.URL,
|
||||
withCustomProvider(tc.providerFunc(srv.URL, dumpDir)),
|
||||
)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify dump files were created.
|
||||
@@ -187,6 +189,7 @@ func TestAPIDump(t *testing.T) {
|
||||
// Parse the dumped HTTP response.
|
||||
dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil)
|
||||
require.NoError(t, err)
|
||||
defer dumpResp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, dumpResp.StatusCode)
|
||||
dumpRespBody, err := io.ReadAll(dumpResp.Body)
|
||||
require.NoError(t, err)
|
||||
@@ -241,7 +244,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -252,16 +255,18 @@ func TestAPIDumpPassthrough(t *testing.T) {
|
||||
|
||||
dumpDir := t.TempDir()
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withCustomProvider(tc.providerFunc(upstream.URL, dumpDir)),
|
||||
)
|
||||
|
||||
bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Find dump files in the passthrough directory.
|
||||
passthroughDir := filepath.Join(dumpDir, tc.name, "passthrough")
|
||||
var reqDumpFile, respDumpFile string
|
||||
err := filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error {
|
||||
err = filepath.Walk(passthroughDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -299,6 +304,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil)
|
||||
require.NoError(t, err)
|
||||
defer dumpResp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, dumpResp.StatusCode)
|
||||
dumpRespBody, err := io.ReadAll(dumpResp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package integrationtest
|
||||
package integrationtest //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
|
||||
@@ -29,6 +28,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/fixtures"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/provider"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
@@ -78,18 +78,20 @@ func TestAnthropicMessages(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
// Make API call to aibridge for Anthropic /v1/messages
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Response-specific checks.
|
||||
@@ -210,17 +212,19 @@ func TestAnthropicMessagesModelThoughts(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
if tc.streaming {
|
||||
@@ -242,7 +246,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
|
||||
t.Run("invalid config", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Invalid bedrock config - missing region & base url
|
||||
@@ -254,11 +258,13 @@ func TestAWSBedrockIntegration(t *testing.T) {
|
||||
SmallFastModel: "test-haiku",
|
||||
}
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, "http://unused",
|
||||
bridgeServer := newBridgeTestServer(ctx, t, "http://unused",
|
||||
withCustomProvider(provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg)),
|
||||
)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool))
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fixtures.Request(t, fixtures.AntSingleBuiltinTool))
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
@@ -272,11 +278,11 @@ func TestAWSBedrockIntegration(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
// We define region here to validate that with Region & BaseURL defined, the latter takes precedence.
|
||||
bedrockCfg := &config.AWSBedrock{
|
||||
@@ -288,7 +294,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
|
||||
BaseURL: upstream.URL, // Use the mock server.
|
||||
}
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)),
|
||||
)
|
||||
|
||||
@@ -296,7 +302,9 @@ func TestAWSBedrockIntegration(t *testing.T) {
|
||||
// We override the AWS Bedrock client to route requests through our mock server.
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// For streaming responses, consume the body to allow the stream to complete.
|
||||
if streaming {
|
||||
@@ -396,11 +404,11 @@ func TestAWSBedrockIntegration(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, fixtures.AntSimpleBedrock)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bCfg := &config.AWSBedrock{
|
||||
Region: "us-west-2",
|
||||
@@ -411,7 +419,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
|
||||
BaseURL: upstream.URL,
|
||||
}
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withCustomProvider(provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bCfg)),
|
||||
)
|
||||
|
||||
@@ -419,9 +427,11 @@ func TestAWSBedrockIntegration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send with Anthropic-Beta header containing flags that should be filtered.
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, http.Header{
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, http.Header{
|
||||
"Anthropic-Beta": {"interleaved-thinking-2025-05-14,effort-2025-11-24,context-management-2025-06-27,prompt-caching-scope-2026-01-05"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
@@ -491,18 +501,20 @@ func TestOpenAIChatCompletions(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
// Make API call to aibridge for OpenAI /v1/chat/completions
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Response-specific checks.
|
||||
@@ -565,25 +577,27 @@ func TestOpenAIChatCompletions(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Setup mock server for multi-turn interaction.
|
||||
// First request → tool call response, second → tool response.
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix))
|
||||
|
||||
// Setup MCP proxies with the tool from the fixture
|
||||
mockMCP := setupMCPForTest(t, defaultTracer)
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withMCP(mockMCP),
|
||||
)
|
||||
|
||||
// Add the stream param to the request.
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", true)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify SSE headers are sent correctly
|
||||
@@ -756,18 +770,20 @@ func TestSimple(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL+tc.basePath)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL+tc.basePath)
|
||||
|
||||
// When: calling the "API server" with the fixture's request body.
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}})
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody, http.Header{"User-Agent": {tc.userAgent}})
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Then: I expect the upstream request to have the correct path.
|
||||
@@ -861,12 +877,12 @@ func TestSessionIDTracking(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withProvider(config.ProviderAnthropic))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withProvider(config.ProviderAnthropic))
|
||||
|
||||
reqBody := fix.Request()
|
||||
if tc.metadataSessionID != "" {
|
||||
@@ -875,11 +891,13 @@ func TestSessionIDTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody, tc.header)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Drain the body to let the stream complete.
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
interceptions := bridgeServer.Recorder.RecordedInterceptions()
|
||||
@@ -948,10 +966,12 @@ func TestFallthrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix))
|
||||
bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL+tc.basePath)
|
||||
upstream := newMockUpstream(t.Context(), t, newFixtureResponse(fix))
|
||||
bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL+tc.basePath)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodGet, tc.requestPath, nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
@@ -984,6 +1004,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
|
||||
|
||||
// Build the requirements & make the assertions which are common to all providers.
|
||||
bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, defaultTracer, pathAnthropicMessages, anthropicToolResultValidator(t))
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Ensure expected tool was invoked with expected input.
|
||||
toolUsages := bridgeServer.Recorder.RecordedToolUsages()
|
||||
@@ -1067,6 +1088,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
|
||||
|
||||
// Build the requirements & make the assertions which are common to all providers.
|
||||
bridgeServer, mockMCP, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, defaultTracer, pathOpenAIChatCompletions, openaiChatToolResultValidator(t))
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Ensure expected tool was invoked with expected input.
|
||||
toolUsages := bridgeServer.Recorder.RecordedToolUsages()
|
||||
@@ -1234,6 +1256,8 @@ func TestErrorHandling(t *testing.T) {
|
||||
|
||||
// Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected.
|
||||
t.Run("non-stream error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
fixture []byte
|
||||
@@ -1276,21 +1300,23 @@ func TestErrorHandling(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Setup mock server. Error fixtures contain raw HTTP
|
||||
// responses that may cause the bridge to retry.
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
// Add the stream param to the request.
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
tc.responseHandlerFn(resp)
|
||||
bridgeServer.Recorder.VerifyAllInterceptionsEnded(t)
|
||||
@@ -1347,17 +1373,19 @@ func TestErrorHandling(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Setup mock server.
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
upstream.StatusCode = http.StatusInternalServerError
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
tc.responseHandlerFn(resp)
|
||||
bridgeServer.Recorder.VerifyAllInterceptionsEnded(t)
|
||||
@@ -1394,7 +1422,7 @@ func TestStableRequestEncoding(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Setup MCP tools.
|
||||
@@ -1408,15 +1436,17 @@ func TestStableRequestEncoding(t *testing.T) {
|
||||
for i := range count {
|
||||
responses[i] = newFixtureResponse(fix)
|
||||
}
|
||||
upstream := newMockUpstream(t, ctx, responses...)
|
||||
upstream := newMockUpstream(ctx, t, responses...)
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withMCP(mockMCP),
|
||||
)
|
||||
|
||||
// Make multiple requests and verify they all have identical payloads.
|
||||
for range count {
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -1657,7 +1687,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Setup MCP tools conditionally.
|
||||
@@ -1669,9 +1699,9 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
|
||||
}
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withMCP(mockMCP),
|
||||
)
|
||||
|
||||
@@ -1679,7 +1709,9 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify tool_choice in the upstream request.
|
||||
@@ -1819,17 +1851,17 @@ func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
var opts []bridgeOption
|
||||
if tc.withInjectedTools {
|
||||
opts = append(opts, withMCP(setupMCPForTest(t, defaultTracer)))
|
||||
}
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...)
|
||||
|
||||
var (
|
||||
reqBody = fix.Request()
|
||||
@@ -1842,7 +1874,9 @@ func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) {
|
||||
reqBody, err = sjson.SetBytes(reqBody, "stream", streaming)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1872,13 +1906,13 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Create a mock server that captures the request body sent upstream.
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
// Inject adaptive thinking into the fixture request.
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"})
|
||||
@@ -1886,7 +1920,9 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) {
|
||||
reqBody, err = sjson.SetBytes(reqBody, "stream", streaming)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
@@ -1935,11 +1971,11 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution.
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
// Set environment variables that the SDK would automatically read.
|
||||
// These should NOT leak into upstream requests.
|
||||
@@ -1947,9 +1983,11 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
|
||||
t.Setenv(key, val)
|
||||
}
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Verify that environment values did not leak.
|
||||
@@ -2045,14 +2083,14 @@ func TestActorHeaders(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s/streaming=%v/send-headers=%v", tc.name, tc.streaming, send), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
metadataKey := "Username"
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withCustomProvider(tc.createProviderFn(upstream.URL, apiKey, send)),
|
||||
withActor(defaultActorID, recorder.Metadata{
|
||||
metadataKey: actorUsername,
|
||||
@@ -2063,7 +2101,9 @@ func TestActorHeaders(t *testing.T) {
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
// Drain the body so streaming responses complete without
|
||||
// a "connection reset" error in the mock upstream.
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package integrationtest
|
||||
package integrationtest //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -28,11 +28,11 @@ const (
|
||||
)
|
||||
|
||||
func anthropicSuccessResponse(model string) string {
|
||||
return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"%s","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model)
|
||||
return fmt.Sprintf(`{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":%q,"stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`, model)
|
||||
}
|
||||
|
||||
func openAISuccessResponse(model string) string {
|
||||
return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model)
|
||||
return fmt.Sprintf(`{"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":%q,"choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`, model)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_FullRecoveryCycle tests the complete circuit breaker lifecycle:
|
||||
@@ -130,31 +130,35 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL,
|
||||
withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)),
|
||||
withMetrics(m),
|
||||
withActor("test-user-id", nil),
|
||||
)
|
||||
|
||||
doRequest := func() *http.Response {
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
doRequest := func() int {
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
|
||||
require.NoError(t, err)
|
||||
return resp
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
return resp.StatusCode
|
||||
}
|
||||
|
||||
// Phase 1: Trip the circuit breaker
|
||||
// First FailureThreshold requests hit upstream, get 429
|
||||
for i := uint32(0); i < cbConfig.FailureThreshold; i++ {
|
||||
resp := doRequest()
|
||||
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
|
||||
status := doRequest()
|
||||
assert.Equal(t, http.StatusTooManyRequests, status)
|
||||
}
|
||||
//nolint:gosec // G115: test constant, no overflow risk
|
||||
assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load())
|
||||
|
||||
// Phase 2: Verify circuit is open
|
||||
// Request should be blocked by circuit breaker (no upstream call)
|
||||
resp := doRequest()
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
||||
status := doRequest()
|
||||
assert.Equal(t, http.StatusServiceUnavailable, status)
|
||||
//nolint:gosec // G115: test constant, no overflow risk
|
||||
assert.Equal(t, int32(cbConfig.FailureThreshold), upstreamCalls.Load(), "No new upstream call when circuit is open")
|
||||
|
||||
// Verify metrics show circuit is open
|
||||
@@ -175,8 +179,8 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
|
||||
|
||||
// Phase 4: Recovery - request in half-open state should succeed and close circuit
|
||||
upstreamCallsBefore := upstreamCalls.Load()
|
||||
resp = doRequest()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed in half-open state")
|
||||
status = doRequest()
|
||||
assert.Equal(t, http.StatusOK, status, "Request should succeed in half-open state")
|
||||
assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state")
|
||||
|
||||
// Verify circuit is now closed
|
||||
@@ -186,8 +190,8 @@ func TestCircuitBreaker_FullRecoveryCycle(t *testing.T) {
|
||||
// Phase 5: Verify circuit is fully functional again
|
||||
// Multiple requests should all succeed and reach upstream
|
||||
for i := 0; i < 3; i++ {
|
||||
resp = doRequest()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "Request should succeed after circuit closes")
|
||||
status = doRequest()
|
||||
assert.Equal(t, http.StatusOK, status, "Request should succeed after circuit closes")
|
||||
}
|
||||
|
||||
// All requests should have reached upstream
|
||||
@@ -283,28 +287,30 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL,
|
||||
withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)),
|
||||
withMetrics(m),
|
||||
withActor("test-user-id", nil),
|
||||
)
|
||||
|
||||
doRequest := func() *http.Response {
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
doRequest := func() int {
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
|
||||
require.NoError(t, err)
|
||||
return resp
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
return resp.StatusCode
|
||||
}
|
||||
|
||||
// Phase 1: Trip the circuit
|
||||
for i := uint32(0); i < cbConfig.FailureThreshold; i++ {
|
||||
resp := doRequest()
|
||||
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
|
||||
status := doRequest()
|
||||
assert.Equal(t, http.StatusTooManyRequests, status)
|
||||
}
|
||||
|
||||
// Verify circuit is open
|
||||
resp := doRequest()
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
||||
status := doRequest()
|
||||
assert.Equal(t, http.StatusServiceUnavailable, status)
|
||||
|
||||
trips := promtest.ToFloat64(m.CircuitBreakerTrips.WithLabelValues(tc.expectProvider, tc.expectEndpoint, tc.expectModel))
|
||||
assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1")
|
||||
@@ -314,13 +320,13 @@ func TestCircuitBreaker_HalfOpenFailure(t *testing.T) {
|
||||
|
||||
// Phase 3: Request in half-open state fails, circuit should re-open
|
||||
upstreamCallsBefore := upstreamCalls.Load()
|
||||
resp = doRequest()
|
||||
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "Request should fail in half-open state")
|
||||
status = doRequest()
|
||||
assert.Equal(t, http.StatusTooManyRequests, status, "Request should fail in half-open state")
|
||||
assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should reach upstream in half-open state")
|
||||
|
||||
// Circuit should be open again - next request should be rejected immediately
|
||||
resp = doRequest()
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Circuit should be open again after half-open failure")
|
||||
status = doRequest()
|
||||
assert.Equal(t, http.StatusServiceUnavailable, status, "Circuit should be open again after half-open failure")
|
||||
assert.Equal(t, upstreamCallsBefore+1, upstreamCalls.Load(), "Request should NOT reach upstream when circuit re-opens")
|
||||
|
||||
// Verify metrics: trips should be 2 now (tripped twice)
|
||||
@@ -429,28 +435,30 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL,
|
||||
withCustomProvider(tc.createProvider(mockUpstream.URL, cbConfig)),
|
||||
withMetrics(m),
|
||||
withActor("test-user-id", nil),
|
||||
)
|
||||
|
||||
doRequest := func() *http.Response {
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
doRequest := func() int {
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, []byte(tc.requestBody), tc.headers)
|
||||
require.NoError(t, err)
|
||||
return resp
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
return resp.StatusCode
|
||||
}
|
||||
|
||||
// Phase 1: Trip the circuit
|
||||
for i := uint32(0); i < cbConfig.FailureThreshold; i++ {
|
||||
resp := doRequest()
|
||||
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
|
||||
status := doRequest()
|
||||
assert.Equal(t, http.StatusTooManyRequests, status)
|
||||
}
|
||||
|
||||
// Verify circuit is open
|
||||
resp := doRequest()
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
||||
status := doRequest()
|
||||
assert.Equal(t, http.StatusServiceUnavailable, status)
|
||||
|
||||
// Phase 2: Wait for half-open state and switch upstream to success
|
||||
time.Sleep(cbConfig.Timeout + 10*time.Millisecond)
|
||||
@@ -466,8 +474,8 @@ func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resp := doRequest()
|
||||
responses <- resp.StatusCode
|
||||
status := doRequest()
|
||||
responses <- status
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -544,7 +552,7 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
|
||||
MaxRequests: 1,
|
||||
}
|
||||
ctx := t.Context()
|
||||
bridgeServer := newBridgeTestServer(t, ctx, mockUpstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, mockUpstream.URL,
|
||||
withCustomProvider(provider.NewAnthropic(config.Anthropic{
|
||||
BaseURL: mockUpstream.URL,
|
||||
Key: "test-key",
|
||||
@@ -554,27 +562,31 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
|
||||
withActor("test-user-id", nil),
|
||||
)
|
||||
|
||||
doRequest := func(model string) *http.Response {
|
||||
body := fmt.Sprintf(`{"model":"%s","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, []byte(body), http.Header{
|
||||
doRequest := func(model string) int {
|
||||
body := fmt.Sprintf(`{"model":%q,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`, model)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, []byte(body), http.Header{
|
||||
"x-api-key": {"test"},
|
||||
"anthropic-version": {"2023-06-01"},
|
||||
})
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
return resp
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
return resp.StatusCode
|
||||
}
|
||||
|
||||
// Phase 1: Trip the circuit for sonnet model
|
||||
for i := uint32(0); i < cbConfig.FailureThreshold; i++ {
|
||||
resp := doRequest("claude-sonnet-4-20250514")
|
||||
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
|
||||
status := doRequest("claude-sonnet-4-20250514")
|
||||
assert.Equal(t, http.StatusTooManyRequests, status)
|
||||
}
|
||||
//nolint:gosec // G115: test constant, no overflow risk
|
||||
assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load())
|
||||
|
||||
// Verify sonnet circuit is open
|
||||
resp := doRequest("claude-sonnet-4-20250514")
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "Sonnet circuit should be open")
|
||||
status := doRequest("claude-sonnet-4-20250514")
|
||||
assert.Equal(t, http.StatusServiceUnavailable, status, "Sonnet circuit should be open")
|
||||
//nolint:gosec // G115: test constant, no overflow risk
|
||||
assert.Equal(t, int32(cbConfig.FailureThreshold), sonnetCalls.Load(), "No new sonnet calls when circuit is open")
|
||||
|
||||
// Verify sonnet metrics show circuit is open
|
||||
@@ -585,14 +597,14 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
|
||||
assert.Equal(t, 1.0, sonnetState, "Sonnet CircuitBreakerState should be 1 (open)")
|
||||
|
||||
// Phase 2: Haiku model should still work (independent circuit)
|
||||
resp = doRequest("claude-3-5-haiku-20241022")
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should succeed while sonnet circuit is open")
|
||||
status = doRequest("claude-3-5-haiku-20241022")
|
||||
assert.Equal(t, http.StatusOK, status, "Haiku should succeed while sonnet circuit is open")
|
||||
assert.Equal(t, int32(1), haikuCalls.Load(), "Haiku call should reach upstream")
|
||||
|
||||
// Make multiple haiku requests - all should succeed
|
||||
for i := 0; i < 3; i++ {
|
||||
resp = doRequest("claude-3-5-haiku-20241022")
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "Haiku should continue to succeed")
|
||||
status = doRequest("claude-3-5-haiku-20241022")
|
||||
assert.Equal(t, http.StatusOK, status, "Haiku should continue to succeed")
|
||||
}
|
||||
assert.Equal(t, int32(4), haikuCalls.Load(), "All haiku calls should reach upstream")
|
||||
|
||||
@@ -607,8 +619,8 @@ func TestCircuitBreaker_PerModelIsolation(t *testing.T) {
|
||||
time.Sleep(cbConfig.Timeout + 10*time.Millisecond)
|
||||
sonnetShouldFail.Store(false)
|
||||
|
||||
resp = doRequest("claude-sonnet-4-20250514")
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode, "Sonnet should recover after timeout")
|
||||
status = doRequest("claude-sonnet-4-20250514")
|
||||
assert.Equal(t, http.StatusOK, status, "Sonnet should recover after timeout")
|
||||
|
||||
// Verify sonnet circuit is now closed
|
||||
sonnetState = promtest.ToFloat64(m.CircuitBreakerState.WithLabelValues(config.ProviderAnthropic, "/v1/messages", "claude-sonnet-4-20250514"))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package integrationtest
|
||||
package integrationtest //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
promtest "github.com/prometheus/client_golang/prometheus/testutil"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/fixtures"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/metrics"
|
||||
)
|
||||
|
||||
@@ -104,7 +104,7 @@ func TestMetrics_Interception(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "oai_responses_blocking_error",
|
||||
fixture: fixtures.OaiResponsesBlockingHttpErr,
|
||||
fixture: fixtures.OaiResponsesBlockingHTTPErr,
|
||||
path: pathOpenAIResponses,
|
||||
headers: http.Header{"User-Agent": []string{"codex/1.0.0"}},
|
||||
expectStatus: metrics.InterceptionCountStatusFailed,
|
||||
@@ -127,7 +127,7 @@ func TestMetrics_Interception(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "oai_responses_streaming_error",
|
||||
fixture: fixtures.OaiResponsesStreamingHttpErr,
|
||||
fixture: fixtures.OaiResponsesStreamingHTTPErr,
|
||||
path: pathOpenAIResponses,
|
||||
headers: http.Header{"Originator": []string{"roo-code"}},
|
||||
expectStatus: metrics.InterceptionCountStatusFailed,
|
||||
@@ -143,20 +143,22 @@ func TestMetrics_Interception(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
upstream.AllowOverflow = tc.allowOverflow
|
||||
|
||||
m := aibridge.NewMetrics(prometheus.NewRegistry())
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withMetrics(m),
|
||||
)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := promtest.ToFloat64(m.InterceptionCount.WithLabelValues(
|
||||
@@ -173,7 +175,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) {
|
||||
|
||||
fix := fixtures.Parse(t, fixtures.AntSimple)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
blockCh := make(chan struct{})
|
||||
@@ -185,7 +187,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) {
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
m := aibridge.NewMetrics(prometheus.NewRegistry())
|
||||
bridgeServer := newBridgeTestServer(t, ctx, srv.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, srv.URL,
|
||||
withMetrics(m),
|
||||
)
|
||||
|
||||
@@ -208,7 +210,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) {
|
||||
return promtest.ToFloat64(
|
||||
m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"),
|
||||
) == 1
|
||||
}, time.Second*10, time.Millisecond*50)
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
// Unblock request, await completion.
|
||||
close(blockCh)
|
||||
@@ -223,7 +225,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) {
|
||||
return promtest.ToFloat64(
|
||||
m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"),
|
||||
) == 0
|
||||
}, time.Second*10, time.Millisecond*50)
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestMetrics_PassthroughCount(t *testing.T) {
|
||||
@@ -233,11 +235,13 @@ func TestMetrics_PassthroughCount(t *testing.T) {
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
m := aibridge.NewMetrics(prometheus.NewRegistry())
|
||||
bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL,
|
||||
withMetrics(m),
|
||||
)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
count := promtest.ToFloat64(m.PassthroughCount.WithLabelValues(
|
||||
@@ -248,20 +252,22 @@ func TestMetrics_PassthroughCount(t *testing.T) {
|
||||
func TestMetrics_PromptCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, fixtures.OaiChatSimple)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
m := aibridge.NewMetrics(prometheus.NewRegistry())
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withMetrics(m),
|
||||
)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request(), http.Header{"User-Agent": []string{"claude-code/1.0.0"}})
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request(), http.Header{"User-Agent": []string{"claude-code/1.0.0"}})
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
prompts := promtest.ToFloat64(m.PromptCount.WithLabelValues(
|
||||
@@ -336,14 +342,14 @@ func TestMetrics_TokenUseCount(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
m := aibridge.NewMetrics(prometheus.NewRegistry())
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withMetrics(m),
|
||||
)
|
||||
|
||||
@@ -353,7 +359,9 @@ func TestMetrics_TokenUseCount(t *testing.T) {
|
||||
reqBody, err = sjson.SetBytes(reqBody, "stream", true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.reqPath, reqBody, nil)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.reqPath, reqBody, nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
|
||||
@@ -361,7 +369,7 @@ func TestMetrics_TokenUseCount(t *testing.T) {
|
||||
require.Eventually(t, func() bool {
|
||||
return promtest.ToFloat64(m.TokenUseCount.WithLabelValues(
|
||||
tc.expectProvider, tc.expectModel, "input", defaultActorID, string(aibridge.ClientUnknown))) > 0
|
||||
}, time.Second*10, time.Millisecond*50)
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
for label, expected := range tc.expectedLabels {
|
||||
require.Equal(t, expected, promtest.ToFloat64(m.TokenUseCount.WithLabelValues(
|
||||
@@ -375,20 +383,22 @@ func TestMetrics_TokenUseCount(t *testing.T) {
|
||||
func TestMetrics_NonInjectedToolUseCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
m := aibridge.NewMetrics(prometheus.NewRegistry())
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withMetrics(m),
|
||||
)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request())
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, fix.Request())
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := promtest.ToFloat64(m.NonInjectedToolUseCount.WithLabelValues(
|
||||
@@ -399,32 +409,34 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) {
|
||||
func TestMetrics_InjectedToolUseCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// First request returns the tool invocation, the second returns the mocked response to the tool result.
|
||||
fix := fixtures.Parse(t, fixtures.AntSingleInjectedTool)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix))
|
||||
|
||||
m := aibridge.NewMetrics(prometheus.NewRegistry())
|
||||
|
||||
// Setup mocked MCP server & tools.
|
||||
mockMCP := setupMCPForTest(t, defaultTracer)
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withMetrics(m),
|
||||
withMCP(mockMCP),
|
||||
)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request())
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request())
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait until full roundtrip has completed.
|
||||
require.Eventually(t, func() bool {
|
||||
return upstream.Calls.Load() == 2
|
||||
}, time.Second*10, time.Millisecond*50)
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
recorder := bridgeServer.Recorder
|
||||
require.Len(t, recorder.ToolUsages(), 1)
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
mcplib "github.com/mark3labs/mcp-go/mcp"
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
)
|
||||
|
||||
@@ -68,12 +68,12 @@ func setupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *mo
|
||||
|
||||
mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{proxy.Name(): proxy}, tracer)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
require.NoError(t, mgr.Shutdown(ctx))
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
require.NoError(t, mgr.Init(ctx))
|
||||
require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init")
|
||||
@@ -141,7 +141,7 @@ func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
|
||||
tool := mcplib.NewTool(name,
|
||||
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
|
||||
)
|
||||
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
|
||||
s.AddTool(tool, func(_ context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
|
||||
acc.addCall(request.Params.Name, request.Params.Arguments)
|
||||
if errMsg, ok := acc.getToolError(request.Params.Name); ok {
|
||||
return nil, xerrors.New(errMsg)
|
||||
|
||||
@@ -111,9 +111,9 @@ func (ms *mockUpstream) receivedRequests() []receivedRequest {
|
||||
// The test fails if the number of requests doesn't match the number of
|
||||
// responses (when AllowOverflow is not set, default).
|
||||
//
|
||||
// srv := newMockUpstream(t, ctx, newFixtureResponse(fix)) // simple
|
||||
// srv := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix)) // multi-turn
|
||||
func newMockUpstream(t *testing.T, ctx context.Context, responses ...upstreamResponse) *mockUpstream {
|
||||
// srv := newMockUpstream(ctx, t, newFixtureResponse(fix)) // simple
|
||||
// srv := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix)) // multi-turn
|
||||
func newMockUpstream(ctx context.Context, t *testing.T, responses ...upstreamResponse) *mockUpstream {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, responses, "at least one upstreamResponse required")
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package integrationtest
|
||||
package integrationtest //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/fixtures"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/provider"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
@@ -335,15 +336,17 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request(), http.Header{"User-Agent": {tc.userAgent}})
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request(), http.Header{"User-Agent": {tc.userAgent}})
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
got, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -416,7 +419,7 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// request with Background mode should be rejected before it reaches upstream
|
||||
@@ -426,11 +429,13 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) {
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
// Create a request with background mode enabled
|
||||
reqBytes := responsesRequestBytes(t, tc.streaming, keyVal{"background", true})
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
|
||||
require.Equal(t, http.StatusNotImplemented, resp.StatusCode)
|
||||
@@ -547,17 +552,17 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture[i])
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
var opts []bridgeOption
|
||||
if tc.withInjectedTools {
|
||||
opts = append(opts, withMCP(setupMCPForTest(t, defaultTracer)))
|
||||
}
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...)
|
||||
|
||||
var (
|
||||
reqBody = fix.Request()
|
||||
@@ -568,7 +573,9 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -631,14 +638,16 @@ func TestClientAndConnectionError(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// tc.addr may be an intentionally invalid URL; use withCustomProvider.
|
||||
bridgeServer := newBridgeTestServer(t, ctx, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey))))
|
||||
bridgeServer := newBridgeTestServer(ctx, t, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey))))
|
||||
|
||||
reqBytes := responsesRequestBytes(t, tc.streaming)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
|
||||
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
@@ -701,7 +710,7 @@ func TestUpstreamError(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -712,10 +721,12 @@ func TestUpstreamError(t *testing.T) {
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
reqBytes := responsesRequestBytes(t, tc.streaming)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, tc.statusCode, resp.StatusCode)
|
||||
require.Equal(t, tc.contentType, resp.Header.Get("Content-Type"))
|
||||
@@ -880,13 +891,13 @@ func TestResponsesInjectedTool(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Setup mock server for multi-turn interaction.
|
||||
// First request → tool call response, second → tool response.
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix), newFixtureToolResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix), newFixtureToolResponse(fix))
|
||||
|
||||
// Setup MCP server proxies (with mock tools).
|
||||
mockMCP := setupMCPForTest(t, defaultTracer)
|
||||
@@ -894,9 +905,11 @@ func TestResponsesInjectedTool(t *testing.T) {
|
||||
mockMCP.setToolError(tc.mcpToolName, tc.expectToolError)
|
||||
}
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withMCP(mockMCP))
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withMCP(mockMCP))
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request())
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request())
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
@@ -905,7 +918,7 @@ func TestResponsesInjectedTool(t *testing.T) {
|
||||
// Wait for both requests to be made (inner agentic loop).
|
||||
require.Eventually(t, func() bool {
|
||||
return upstream.Calls.Load() == 2
|
||||
}, time.Second*10, time.Millisecond*50)
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
// Verify the injected tool was invoked via MCP.
|
||||
invocations := mockMCP.getCallsByTool(tc.mcpToolName)
|
||||
@@ -1025,18 +1038,20 @@ func TestResponsesModelThoughts(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request())
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, fix.Request())
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
bridgeServer.Recorder.VerifyModelThoughtsRecorded(t, tc.expectedThoughts)
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -63,11 +62,13 @@ type bridgeTestServer struct {
|
||||
|
||||
// makeRequest builds and executes an HTTP request against this server.
|
||||
// Optional headers are applied after the default Content-Type.
|
||||
func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, body []byte, header ...http.Header) *http.Response {
|
||||
func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string, body []byte, header ...http.Header) (*http.Response, error) {
|
||||
t.Helper()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), method, s.URL+path, bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
for _, h := range header {
|
||||
for k, vals := range h {
|
||||
@@ -76,10 +77,7 @@ func (s *bridgeTestServer) makeRequest(t *testing.T, method string, path string,
|
||||
}
|
||||
}
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
return resp
|
||||
return http.DefaultClient.Do(req)
|
||||
}
|
||||
|
||||
type bridgeOption func(*bridgeConfig)
|
||||
@@ -133,8 +131,8 @@ func withActor(id string, md recorder.Metadata) bridgeOption {
|
||||
// - defaultTracer (unless withTracer)
|
||||
// - defaultActorID (unless withActor)
|
||||
func newBridgeTestServer(
|
||||
t *testing.T,
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
upstreamURL string,
|
||||
opts ...bridgeOption,
|
||||
) *bridgeTestServer {
|
||||
@@ -209,7 +207,7 @@ func setupInjectedToolTest(
|
||||
) (*bridgeTestServer, *mockMCP, *http.Response) {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
fix := fixtures.Parse(t, fixture)
|
||||
@@ -220,7 +218,7 @@ func setupInjectedToolTest(
|
||||
firstResp := newFixtureResponse(fix)
|
||||
toolResp := newFixtureToolResponse(fix)
|
||||
toolResp.OnRequest = toolRequestValidatorFn
|
||||
upstream := newMockUpstream(t, ctx, firstResp, toolResp)
|
||||
upstream := newMockUpstream(ctx, t, firstResp, toolResp)
|
||||
|
||||
mockMCP := setupMCPForTest(t, tracer)
|
||||
|
||||
@@ -230,19 +228,20 @@ func setupInjectedToolTest(
|
||||
withActor(defaultActorID, nil),
|
||||
}
|
||||
allOpts = append(allOpts, opts...)
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, allOpts...)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, allOpts...)
|
||||
|
||||
// Add the stream param to the request.
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, path, reqBody)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Wait both requests (initial + tool call result)
|
||||
require.Eventually(t, func() bool {
|
||||
return upstream.Calls.Load() == 2
|
||||
}, time.Second*10, time.Millisecond*50)
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
return bridgeServer, mockMCP, resp
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package integrationtest
|
||||
package integrationtest //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -20,6 +19,7 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/fixtures"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
)
|
||||
|
||||
@@ -43,6 +43,8 @@ func setupTracer(t *testing.T) (*tracetest.SpanRecorder, oteltrace.Tracer) {
|
||||
}
|
||||
|
||||
func TestTraceAnthropic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectNonStreaming := []expectTrace{
|
||||
{"Intercept", 1, codes.Unset},
|
||||
{"Intercept.CreateInterceptor", 1, codes.Unset},
|
||||
@@ -137,13 +139,15 @@ func TestTraceAnthropic(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
sr, tracer := setupTracer(t)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
opts := []bridgeOption{
|
||||
withTracer(tracer),
|
||||
@@ -151,11 +155,13 @@ func TestTraceAnthropic(t *testing.T) {
|
||||
if tc.bedrock {
|
||||
opts = append(opts, withProvider(providerBedrock))
|
||||
}
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...)
|
||||
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
bridgeServer.Close()
|
||||
|
||||
@@ -189,6 +195,8 @@ func TestTraceAnthropic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTraceAnthropicErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectNonStream := []expectTrace{
|
||||
{"Intercept", 1, codes.Error},
|
||||
{"Intercept.CreateInterceptor", 1, codes.Unset},
|
||||
@@ -247,13 +255,15 @@ func TestTraceAnthropicErr(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
sr, tracer := setupTracer(t)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
|
||||
opts := []bridgeOption{
|
||||
withTracer(tracer),
|
||||
@@ -261,11 +271,13 @@ func TestTraceAnthropicErr(t *testing.T) {
|
||||
if tc.bedrock {
|
||||
opts = append(opts, withProvider(providerBedrock))
|
||||
}
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...)
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, opts...)
|
||||
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
if tc.streaming {
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
} else {
|
||||
@@ -385,10 +397,11 @@ func TestInjectedToolsTrace(t *testing.T) {
|
||||
validatorFn = openaiChatToolResultValidator(t)
|
||||
}
|
||||
|
||||
bridgeServer, mockMCP, _ := setupInjectedToolTest(
|
||||
bridgeServer, mockMCP, resp := setupInjectedToolTest(
|
||||
t, tc.fixture, tc.streaming, tracer,
|
||||
tc.path, validatorFn, tc.opts...,
|
||||
)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Len(t, bridgeServer.Recorder.RecordedInterceptions(), 1)
|
||||
intcID := bridgeServer.Recorder.RecordedInterceptions()[0].ID
|
||||
@@ -417,6 +430,8 @@ func TestInjectedToolsTrace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTraceOpenAI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
fixture []byte
|
||||
@@ -529,20 +544,24 @@ func TestTraceOpenAI(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
sr, tracer := setupTracer(t)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL,
|
||||
upstream := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL,
|
||||
withTracer(tracer),
|
||||
)
|
||||
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
bridgeServer.Close()
|
||||
|
||||
@@ -569,6 +588,8 @@ func TestTraceOpenAI(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTraceOpenAIErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
fixture []byte
|
||||
@@ -647,7 +668,7 @@ func TestTraceOpenAIErr(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "trace_openai_responses_streaming_http_error",
|
||||
fixture: fixtures.OaiResponsesStreamingHttpErr,
|
||||
fixture: fixtures.OaiResponsesStreamingHTTPErr,
|
||||
streaming: true,
|
||||
allowOverflow: true, // 429 error causes retries
|
||||
|
||||
@@ -664,7 +685,7 @@ func TestTraceOpenAIErr(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "trace_openai_responses_blocking_http_error",
|
||||
fixture: fixtures.OaiResponsesBlockingHttpErr,
|
||||
fixture: fixtures.OaiResponsesBlockingHTTPErr,
|
||||
streaming: false,
|
||||
|
||||
path: pathOpenAIResponses,
|
||||
@@ -682,22 +703,26 @@ func TestTraceOpenAIErr(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
sr, tracer := setupTracer(t)
|
||||
|
||||
fix := fixtures.Parse(t, tc.fixture)
|
||||
|
||||
mockAPI := newMockUpstream(t, ctx, newFixtureResponse(fix))
|
||||
mockAPI := newMockUpstream(ctx, t, newFixtureResponse(fix))
|
||||
mockAPI.AllowOverflow = tc.allowOverflow
|
||||
bridgeServer := newBridgeTestServer(t, ctx, mockAPI.URL,
|
||||
bridgeServer := newBridgeTestServer(ctx, t, mockAPI.URL,
|
||||
withTracer(tracer),
|
||||
)
|
||||
|
||||
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
|
||||
require.NoError(t, err)
|
||||
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, tc.path, reqBody)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, tc.expectCode, resp.StatusCode)
|
||||
bridgeServer.Close()
|
||||
@@ -729,15 +754,17 @@ func TestTracePassthrough(t *testing.T) {
|
||||
|
||||
fix := fixtures.Parse(t, fixtures.OaiChatFallthrough)
|
||||
|
||||
upstream := newMockUpstream(t, t.Context(), newFixtureResponse(fix))
|
||||
upstream := newMockUpstream(t.Context(), t, newFixtureResponse(fix))
|
||||
|
||||
sr, tracer := setupTracer(t)
|
||||
|
||||
bridgeServer := newBridgeTestServer(t, t.Context(), upstream.URL,
|
||||
bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL,
|
||||
withTracer(tracer),
|
||||
)
|
||||
|
||||
resp := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil)
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodGet, "/openai/v1/models", nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
bridgeServer.Close()
|
||||
|
||||
@@ -755,6 +782,8 @@ func TestTracePassthrough(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewServerProxyManagerTraces(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sr, tracer := setupTracer(t)
|
||||
|
||||
serverName := "serverName"
|
||||
|
||||
@@ -26,14 +26,14 @@ type MockRecorder struct {
|
||||
interceptionsEnd map[string]*recorder.InterceptionRecordEnded
|
||||
}
|
||||
|
||||
func (m *MockRecorder) RecordInterception(ctx context.Context, req *recorder.InterceptionRecord) error {
|
||||
func (m *MockRecorder) RecordInterception(_ context.Context, req *recorder.InterceptionRecord) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.interceptions = append(m.interceptions, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorder.InterceptionRecordEnded) error {
|
||||
func (m *MockRecorder) RecordInterceptionEnded(_ context.Context, req *recorder.InterceptionRecordEnded) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.interceptionsEnd == nil {
|
||||
@@ -46,28 +46,28 @@ func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorde
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockRecorder) RecordPromptUsage(ctx context.Context, req *recorder.PromptUsageRecord) error {
|
||||
func (m *MockRecorder) RecordPromptUsage(_ context.Context, req *recorder.PromptUsageRecord) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.userPrompts = append(m.userPrompts, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockRecorder) RecordTokenUsage(ctx context.Context, req *recorder.TokenUsageRecord) error {
|
||||
func (m *MockRecorder) RecordTokenUsage(_ context.Context, req *recorder.TokenUsageRecord) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.tokenUsages = append(m.tokenUsages, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockRecorder) RecordToolUsage(ctx context.Context, req *recorder.ToolUsageRecord) error {
|
||||
func (m *MockRecorder) RecordToolUsage(_ context.Context, req *recorder.ToolUsageRecord) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.toolUsages = append(m.toolUsages, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockRecorder) RecordModelThought(ctx context.Context, req *recorder.ModelThoughtRecord) error {
|
||||
func (m *MockRecorder) RecordModelThought(_ context.Context, req *recorder.ModelThoughtRecord) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.modelThoughts = append(m.modelThoughts, req)
|
||||
|
||||
@@ -11,26 +11,26 @@ import (
|
||||
)
|
||||
|
||||
type MockProvider struct {
|
||||
Name_ string
|
||||
NameStr string
|
||||
URL string
|
||||
Bridged []string
|
||||
Passthrough []string
|
||||
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
|
||||
}
|
||||
|
||||
func (m *MockProvider) Type() string { return m.Name_ }
|
||||
func (m *MockProvider) Name() string { return m.Name_ }
|
||||
func (m *MockProvider) BaseURL() string { return m.URL }
|
||||
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.Name_) }
|
||||
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
|
||||
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
|
||||
func (m *MockProvider) AuthHeader() string { return "Authorization" }
|
||||
func (m *MockProvider) InjectAuthHeader(h *http.Header) {}
|
||||
func (m *MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
|
||||
func (m *MockProvider) APIDumpDir() string { return "" }
|
||||
func (m *MockProvider) Type() string { return m.NameStr }
|
||||
func (m *MockProvider) Name() string { return m.NameStr }
|
||||
func (m *MockProvider) BaseURL() string { return m.URL }
|
||||
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
|
||||
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
|
||||
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
|
||||
func (*MockProvider) AuthHeader() string { return "Authorization" }
|
||||
func (*MockProvider) InjectAuthHeader(_ *http.Header) {}
|
||||
func (*MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
|
||||
func (*MockProvider) APIDumpDir() string { return "" }
|
||||
func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) {
|
||||
if m.InterceptorFunc != nil {
|
||||
return m.InterceptorFunc(w, r, tracer)
|
||||
}
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // mock: no interceptor configured is not an error
|
||||
}
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package testutil
|
||||
|
||||
import "time"
|
||||
|
||||
// Shared test timeout and interval constants.
|
||||
// Using named constants avoids magic numbers and makes timeout policy
|
||||
// easy to adjust across the entire test suite.
|
||||
const (
|
||||
// WaitLong is the default timeout for test operations that may take a while
|
||||
// (e.g. integration tests with HTTP round-trips).
|
||||
WaitLong = 30 * time.Second
|
||||
|
||||
// WaitMedium is a timeout for moderately slow operations.
|
||||
WaitMedium = 10 * time.Second
|
||||
|
||||
// WaitShort is a timeout for operations expected to complete quickly.
|
||||
WaitShort = 5 * time.Second
|
||||
|
||||
// IntervalFast is a short polling interval for require.Eventually and similar.
|
||||
IntervalFast = 50 * time.Millisecond
|
||||
)
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
)
|
||||
|
||||
func TestGetClientInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
info := mcp.GetClientInfo()
|
||||
|
||||
assert.Equal(t, "coder/aibridge", info.Name)
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mcplib "github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
)
|
||||
|
||||
@@ -299,7 +299,7 @@ func TestToolInjectionOrder(t *testing.T) {
|
||||
|
||||
// Setup.
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Given: a MCP mock server offering a set of tools.
|
||||
|
||||
@@ -87,7 +87,7 @@ func (p *StreamableHTTPServerProxy) Init(ctx context.Context) (outErr error) {
|
||||
return xerrors.Errorf("MCP version negotiation failed; requested %q, accepts %q, received %q", version, strings.Join(mcp.ValidProtocolVersions, ","), result.ProtocolVersion)
|
||||
}
|
||||
|
||||
p.logger.Debug(ctx, "MCP client initialized", slog.F("name", result.ServerInfo.Name), slog.F("server_version", result.ServerInfo.Version))
|
||||
p.logger.Debug(ctx, "mcp client initialized", slog.F("name", result.ServerInfo.Name), slog.F("server_version", result.ServerInfo.Version))
|
||||
|
||||
tools, err := p.fetchTools(ctx)
|
||||
if err != nil {
|
||||
@@ -161,7 +161,7 @@ func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[strin
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (p *StreamableHTTPServerProxy) Shutdown(ctx context.Context) error {
|
||||
func (p *StreamableHTTPServerProxy) Shutdown(_ context.Context) error {
|
||||
if p.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
+12
-11
@@ -59,15 +59,15 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp
|
||||
ctx, span := tracer.Start(ctx, "Intercept.ProcessRequest.ToolCall", trace.WithAttributes(spanAttrs...))
|
||||
defer tracing.EndSpanErr(span, &outErr)
|
||||
|
||||
inputJson, err := json.Marshal(input)
|
||||
inputJSON, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
t.Logger.Warn(ctx, "failed to marshal tool input, will be omitted from span attrs", slog.Error(err))
|
||||
} else {
|
||||
strJson := string(inputJson)
|
||||
if len(strJson) > maxSpanInputAttrLen {
|
||||
strJson = strJson[:maxSpanInputAttrLen]
|
||||
strJSON := string(inputJSON)
|
||||
if len(strJSON) > maxSpanInputAttrLen {
|
||||
strJSON = strJSON[:maxSpanInputAttrLen]
|
||||
}
|
||||
span.SetAttributes(attribute.String(tracing.MCPInput, strJson))
|
||||
span.SetAttributes(attribute.String(tracing.MCPInput, strJSON))
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
@@ -88,7 +88,7 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp
|
||||
logFn(ctx, "injected tool invoked",
|
||||
slog.F("name", t.Name),
|
||||
slog.F("server", t.ServerName),
|
||||
slog.F("input", inputJson),
|
||||
slog.F("input", inputJSON),
|
||||
slog.F("duration_sec", time.Since(start).Seconds()),
|
||||
slog.Error(outErr),
|
||||
)
|
||||
@@ -106,12 +106,13 @@ func (t *Tool) Call(ctx context.Context, input any, tracer trace.Tracer) (_ *mcp
|
||||
// - https://community.openai.com/t/function-call-description-max-length/529902
|
||||
// - https://github.com/anthropics/claude-code/issues/2326
|
||||
func EncodeToolID(server, tool string) string {
|
||||
// strings.Builder writes to in-memory storage and never return errors.
|
||||
var sb strings.Builder
|
||||
sb.WriteString(injectedToolPrefix)
|
||||
sb.WriteString(injectedToolDelimiter)
|
||||
sb.WriteString(server)
|
||||
sb.WriteString(injectedToolDelimiter)
|
||||
sb.WriteString(tool)
|
||||
_, _ = sb.WriteString(injectedToolPrefix)
|
||||
_, _ = sb.WriteString(injectedToolDelimiter)
|
||||
_, _ = sb.WriteString(server)
|
||||
_, _ = sb.WriteString(injectedToolDelimiter)
|
||||
_, _ = sb.WriteString(tool)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
package mcpmock
|
||||
|
||||
//go:generate mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/coder/v2/aibridge/mcp ServerProxier
|
||||
//go:generate mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/aibridge/mcp ServerProxier
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/coder/coder/v2/aibridge/mcp (interfaces: ServerProxier)
|
||||
// Source: github.com/coder/aibridge/mcp (interfaces: ServerProxier)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/coder/v2/aibridge/mcp ServerProxier
|
||||
// mockgen -destination ./mcpmock.go -package mcpmock github.com/coder/aibridge/mcp ServerProxier
|
||||
//
|
||||
|
||||
// Package mcpmock is a generated GoMock package.
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
var baseLabels []string = []string{"provider", "model"}
|
||||
var baseLabels = []string{"provider", "model"}
|
||||
|
||||
const (
|
||||
InterceptionCountStatusFailed = "failed"
|
||||
|
||||
@@ -21,10 +21,10 @@ import (
|
||||
|
||||
// newPassthroughRouter returns a simple reverse-proxy implementation which will be used when a route is not handled specifically
|
||||
// by a [intercept.Provider].
|
||||
func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
|
||||
func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if m != nil {
|
||||
m.PassthroughCount.WithLabelValues(provider.Name(), r.URL.Path, r.Method).Add(1)
|
||||
m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1)
|
||||
}
|
||||
|
||||
ctx, span := tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(
|
||||
@@ -33,7 +33,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
upURL, err := url.Parse(provider.BaseURL())
|
||||
upURL, err := url.Parse(prov.BaseURL())
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to parse provider base URL", slog.Error(err))
|
||||
http.Error(w, "request error", http.StatusBadGateway)
|
||||
@@ -44,7 +44,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met
|
||||
// Append the request path to the upstream base path.
|
||||
reqPath, err := url.JoinPath(upURL.Path, r.URL.Path)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstreamPath", upURL.Path), slog.F("requestPath", r.URL.Path))
|
||||
logger.Warn(ctx, "failed to join upstream path", slog.Error(err), slog.F("upstream_path", upURL.Path), slog.F("request_path", r.URL.Path))
|
||||
http.Error(w, "failed to join upstream path", http.StatusInternalServerError)
|
||||
span.SetStatus(codes.Error, "failed to join upstream path: "+err.Error())
|
||||
return
|
||||
@@ -96,7 +96,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met
|
||||
}
|
||||
|
||||
// Inject provider auth.
|
||||
provider.InjectAuthHeader(&req.Header)
|
||||
prov.InjectAuthHeader(&req.Header)
|
||||
},
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
|
||||
logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
|
||||
@@ -113,7 +113,7 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
proxy.Transport = apidump.NewPassthroughMiddleware(t, provider.APIDumpDir(), provider.Name(), logger, quartz.NewReal())
|
||||
proxy.Transport = apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal())
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package aibridge
|
||||
package aibridge //nolint:testpackage // tests unexported newPassthroughRouter
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
@@ -73,7 +73,7 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Anthropic) Type() string {
|
||||
func (*Anthropic) Type() string {
|
||||
return config.ProviderAnthropic
|
||||
}
|
||||
|
||||
@@ -85,11 +85,11 @@ func (p *Anthropic) RoutePrefix() string {
|
||||
return fmt.Sprintf("/%s", p.Name())
|
||||
}
|
||||
|
||||
func (p *Anthropic) BridgedRoutes() []string {
|
||||
func (*Anthropic) BridgedRoutes() []string {
|
||||
return []string{routeMessages}
|
||||
}
|
||||
|
||||
func (p *Anthropic) PassthroughRoutes() []string {
|
||||
func (*Anthropic) PassthroughRoutes() []string {
|
||||
return []string{
|
||||
"/v1/models",
|
||||
"/v1/models/", // See https://pkg.go.dev/net/http#hdr-Trailing_slash_redirection-ServeMux.
|
||||
@@ -98,77 +98,76 @@ func (p *Anthropic) PassthroughRoutes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
|
||||
func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
|
||||
id := uuid.New()
|
||||
_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
|
||||
defer tracing.EndSpanErr(span, &outErr)
|
||||
|
||||
path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix())
|
||||
switch path {
|
||||
case routeMessages:
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
reqPayload, err := messages.NewMessagesRequestPayload(payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal request body: %w", err)
|
||||
}
|
||||
|
||||
cfg := p.cfg
|
||||
cfg.ExtraHeaders = extractAnthropicHeaders(r)
|
||||
|
||||
// At this point the request contains only LLM provider headers.
|
||||
// Any Coder-specific authentication has already been stripped.
|
||||
//
|
||||
// In centralized mode neither Authorization nor X-Api-Key is
|
||||
// present, so cfg keeps the centralized key unchanged.
|
||||
//
|
||||
// In BYOK mode the user's LLM credentials survive intact.
|
||||
// If X-Api-Key is present the user has a personal API key;
|
||||
// overwrite the centralized key with it. If Authorization is
|
||||
// present the user authenticated directly with provider;
|
||||
// set BYOKBearerToken and clear the centralized key.
|
||||
// When both are present, X-Api-Key takes priority to match
|
||||
// claude-code behavior.
|
||||
credKind := intercept.CredentialKindCentralized
|
||||
credSecret := cfg.Key
|
||||
authHeaderName := p.AuthHeader()
|
||||
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
|
||||
cfg.Key = apiKey
|
||||
authHeaderName = "X-Api-Key"
|
||||
credKind = intercept.CredentialKindBYOK
|
||||
credSecret = apiKey
|
||||
} else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" {
|
||||
cfg.BYOKBearerToken = token
|
||||
cfg.Key = ""
|
||||
authHeaderName = "Authorization"
|
||||
credKind = intercept.CredentialKindBYOK
|
||||
credSecret = token
|
||||
}
|
||||
|
||||
cred := intercept.NewCredentialInfo(credKind, credSecret)
|
||||
|
||||
var interceptor intercept.Interceptor
|
||||
if reqPayload.Stream() {
|
||||
interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
|
||||
} else {
|
||||
interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
|
||||
}
|
||||
span.SetAttributes(interceptor.TraceAttributes(r)...)
|
||||
return interceptor, nil
|
||||
if path != routeMessages {
|
||||
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
|
||||
return nil, ErrUnknownRoute
|
||||
}
|
||||
|
||||
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
|
||||
return nil, UnknownRoute
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
reqPayload, err := messages.NewRequestPayload(payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal request body: %w", err)
|
||||
}
|
||||
|
||||
cfg := p.cfg
|
||||
cfg.ExtraHeaders = extractAnthropicHeaders(r)
|
||||
|
||||
// At this point the request contains only LLM provider headers.
|
||||
// Any Coder-specific authentication has already been stripped.
|
||||
//
|
||||
// In centralized mode neither Authorization nor X-Api-Key is
|
||||
// present, so cfg keeps the centralized key unchanged.
|
||||
//
|
||||
// In BYOK mode the user's LLM credentials survive intact.
|
||||
// If X-Api-Key is present the user has a personal API key;
|
||||
// overwrite the centralized key with it. If Authorization is
|
||||
// present the user authenticated directly with provider;
|
||||
// set BYOKBearerToken and clear the centralized key.
|
||||
// When both are present, X-Api-Key takes priority to match
|
||||
// claude-code behavior.
|
||||
credKind := intercept.CredentialKindCentralized
|
||||
credSecret := cfg.Key
|
||||
authHeaderName := p.AuthHeader()
|
||||
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
|
||||
cfg.Key = apiKey
|
||||
authHeaderName = "X-Api-Key"
|
||||
credKind = intercept.CredentialKindBYOK
|
||||
credSecret = apiKey
|
||||
} else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" {
|
||||
cfg.BYOKBearerToken = token
|
||||
cfg.Key = ""
|
||||
authHeaderName = "Authorization"
|
||||
credKind = intercept.CredentialKindBYOK
|
||||
credSecret = token
|
||||
}
|
||||
|
||||
cred := intercept.NewCredentialInfo(credKind, credSecret)
|
||||
|
||||
var interceptor intercept.Interceptor
|
||||
if reqPayload.Stream() {
|
||||
interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
|
||||
} else {
|
||||
interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
|
||||
}
|
||||
span.SetAttributes(interceptor.TraceAttributes(r)...)
|
||||
return interceptor, nil
|
||||
}
|
||||
|
||||
func (p *Anthropic) BaseURL() string {
|
||||
return p.cfg.BaseURL
|
||||
}
|
||||
|
||||
func (p *Anthropic) AuthHeader() string {
|
||||
func (*Anthropic) AuthHeader() string {
|
||||
return "X-Api-Key"
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package provider
|
||||
package provider //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -146,7 +146,7 @@ func TestAnthropic_CreateInterceptor(t *testing.T) {
|
||||
assert.Empty(t, receivedHeaders.Get("Authorization"), "client Authorization header must not reach upstream")
|
||||
})
|
||||
|
||||
t.Run("UnknownRoute", func(t *testing.T) {
|
||||
t.Run("ErrUnknownRoute", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}]}`
|
||||
@@ -155,7 +155,7 @@ func TestAnthropic_CreateInterceptor(t *testing.T) {
|
||||
|
||||
interceptor, err := provider.CreateInterceptor(w, req, testTracer)
|
||||
|
||||
require.ErrorIs(t, err, UnknownRoute)
|
||||
require.ErrorIs(t, err, ErrUnknownRoute)
|
||||
require.Nil(t, interceptor)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ func NewCopilot(cfg config.Copilot) *Copilot {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Copilot) Type() string {
|
||||
func (*Copilot) Type() string {
|
||||
return config.ProviderCopilot
|
||||
}
|
||||
|
||||
@@ -88,14 +88,14 @@ func (p *Copilot) RoutePrefix() string {
|
||||
return fmt.Sprintf("/%s", p.Name())
|
||||
}
|
||||
|
||||
func (p *Copilot) BridgedRoutes() []string {
|
||||
func (*Copilot) BridgedRoutes() []string {
|
||||
return []string{
|
||||
routeCopilotChatCompletions,
|
||||
routeCopilotResponses,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Copilot) PassthroughRoutes() []string {
|
||||
func (*Copilot) PassthroughRoutes() []string {
|
||||
return []string{
|
||||
"/models",
|
||||
"/models/",
|
||||
@@ -105,7 +105,7 @@ func (p *Copilot) PassthroughRoutes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Copilot) AuthHeader() string {
|
||||
func (*Copilot) AuthHeader() string {
|
||||
return "Authorization"
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ func (p *Copilot) AuthHeader() string {
|
||||
// Copilot uses per-user tokens passed in the original Authorization header,
|
||||
// rather than a global key configured at the provider level.
|
||||
// The original Authorization header flows through untouched from the client.
|
||||
func (p *Copilot) InjectAuthHeader(_ *http.Header) {}
|
||||
func (*Copilot) InjectAuthHeader(_ *http.Header) {}
|
||||
|
||||
func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker {
|
||||
return p.circuitBreaker
|
||||
@@ -170,7 +170,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read body: %w", err)
|
||||
}
|
||||
reqPayload, err := responses.NewResponsesRequestPayload(payload)
|
||||
reqPayload, err := responses.NewRequestPayload(payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal request body: %w", err)
|
||||
}
|
||||
@@ -183,7 +183,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
|
||||
|
||||
default:
|
||||
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
|
||||
return nil, UnknownRoute
|
||||
return nil, ErrUnknownRoute
|
||||
}
|
||||
|
||||
span.SetAttributes(interceptor.TraceAttributes(r)...)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package provider
|
||||
package provider //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -300,7 +300,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) {
|
||||
assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set upstream")
|
||||
})
|
||||
|
||||
t.Run("UnknownRoute", func(t *testing.T) {
|
||||
t.Run("ErrUnknownRoute", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := `{"model": "gpt-4.1", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
@@ -310,7 +310,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) {
|
||||
|
||||
interceptor, err := provider.CreateInterceptor(w, req, testTracer)
|
||||
|
||||
require.ErrorIs(t, err, UnknownRoute)
|
||||
require.ErrorIs(t, err, ErrUnknownRoute)
|
||||
require.Nil(t, interceptor)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func NewOpenAI(cfg config.OpenAI) *OpenAI {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAI) Type() string {
|
||||
func (*OpenAI) Type() string {
|
||||
return config.ProviderOpenAI
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ func (p *OpenAI) RoutePrefix() string {
|
||||
return fmt.Sprintf("/%s/v1", p.Name())
|
||||
}
|
||||
|
||||
func (p *OpenAI) BridgedRoutes() []string {
|
||||
func (*OpenAI) BridgedRoutes() []string {
|
||||
return []string{
|
||||
routeChatCompletions,
|
||||
routeResponses,
|
||||
@@ -86,7 +86,7 @@ func (p *OpenAI) BridgedRoutes() []string {
|
||||
// but must be passed through to the upstream.
|
||||
// The /v1/completions legacy API is deprecated and will not be passed through.
|
||||
// See https://platform.openai.com/docs/api-reference/completions.
|
||||
func (p *OpenAI) PassthroughRoutes() []string {
|
||||
func (*OpenAI) PassthroughRoutes() []string {
|
||||
return []string{
|
||||
// See https://pkg.go.dev/net/http#hdr-Trailing_slash_redirection-ServeMux.
|
||||
// but without non trailing slash route requests to `/v1/conversations` are going to catch all
|
||||
@@ -98,7 +98,7 @@ func (p *OpenAI) PassthroughRoutes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
|
||||
func (p *OpenAI) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
|
||||
id := uuid.New()
|
||||
|
||||
_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
|
||||
@@ -141,7 +141,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read body: %w", err)
|
||||
}
|
||||
reqPayload, err := responses.NewResponsesRequestPayload(payload)
|
||||
reqPayload, err := responses.NewRequestPayload(payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal request body: %w", err)
|
||||
}
|
||||
@@ -153,7 +153,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
|
||||
|
||||
default:
|
||||
span.SetStatus(codes.Error, "unknown route: "+r.URL.Path)
|
||||
return nil, UnknownRoute
|
||||
return nil, ErrUnknownRoute
|
||||
}
|
||||
span.SetAttributes(interceptor.TraceAttributes(r)...)
|
||||
return interceptor, nil
|
||||
@@ -163,7 +163,7 @@ func (p *OpenAI) BaseURL() string {
|
||||
return p.cfg.BaseURL
|
||||
}
|
||||
|
||||
func (p *OpenAI) AuthHeader() string {
|
||||
func (*OpenAI) AuthHeader() string {
|
||||
return "Authorization"
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package provider
|
||||
package provider //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
)
|
||||
|
||||
var UnknownRoute = xerrors.New("unknown route")
|
||||
var ErrUnknownRoute = xerrors.New("unknown route")
|
||||
|
||||
// Provider defines routes (bridged and passed through) for given provider.
|
||||
// Bridged routes are processed by dedicated interceptors.
|
||||
|
||||
@@ -14,19 +14,19 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
_ Recorder = &RecorderWrapper{}
|
||||
_ Recorder = &WrappedRecorder{}
|
||||
_ Recorder = &AsyncRecorder{}
|
||||
)
|
||||
|
||||
// RecorderWrapper is a convenience struct which implements RecorderClient and resolves a client before calling each method.
|
||||
// WrappedRecorder is a convenience struct which implements RecorderClient and resolves a client before calling each method.
|
||||
// It also sets the start/creation time of each record.
|
||||
type RecorderWrapper struct {
|
||||
type WrappedRecorder struct {
|
||||
logger slog.Logger
|
||||
tracer trace.Tracer
|
||||
clientFn func() (Recorder, error)
|
||||
}
|
||||
|
||||
func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) {
|
||||
func (r *WrappedRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) (outErr error) {
|
||||
ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterception", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
defer tracing.EndSpanErr(span, &outErr)
|
||||
|
||||
@@ -44,7 +44,7 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) {
|
||||
func (r *WrappedRecorder) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) (outErr error) {
|
||||
ctx, span := r.tracer.Start(ctx, "Intercept.RecordInterceptionEnded", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
defer tracing.EndSpanErr(span, &outErr)
|
||||
|
||||
@@ -62,7 +62,7 @@ func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *Inte
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) {
|
||||
func (r *WrappedRecorder) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) (outErr error) {
|
||||
ctx, span := r.tracer.Start(ctx, "Intercept.RecordPromptUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
defer tracing.EndSpanErr(span, &outErr)
|
||||
|
||||
@@ -80,7 +80,7 @@ func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsag
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) {
|
||||
func (r *WrappedRecorder) RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) (outErr error) {
|
||||
ctx, span := r.tracer.Start(ctx, "Intercept.RecordTokenUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
defer tracing.EndSpanErr(span, &outErr)
|
||||
|
||||
@@ -98,7 +98,7 @@ func (r *RecorderWrapper) RecordTokenUsage(ctx context.Context, req *TokenUsageR
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) {
|
||||
func (r *WrappedRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecord) (outErr error) {
|
||||
ctx, span := r.tracer.Start(ctx, "Intercept.RecordToolUsage", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
defer tracing.EndSpanErr(span, &outErr)
|
||||
|
||||
@@ -116,7 +116,7 @@ func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRec
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) (outErr error) {
|
||||
func (r *WrappedRecorder) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) (outErr error) {
|
||||
ctx, span := r.tracer.Start(ctx, "Intercept.RecordModelThought", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
defer tracing.EndSpanErr(span, &outErr)
|
||||
|
||||
@@ -134,8 +134,8 @@ func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThou
|
||||
return err
|
||||
}
|
||||
|
||||
func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *RecorderWrapper {
|
||||
return &RecorderWrapper{
|
||||
func NewWrappedRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *WrappedRecorder {
|
||||
return &WrappedRecorder{
|
||||
logger: logger,
|
||||
tracer: tracer,
|
||||
clientFn: clientFn,
|
||||
@@ -185,7 +185,7 @@ func (a *AsyncRecorder) WithClient(client string) {
|
||||
|
||||
// RecordInterception must NOT be called asynchronously.
|
||||
// If an interception cannot be recorded, the whole request should fail.
|
||||
func (a *AsyncRecorder) RecordInterception(ctx context.Context, req *InterceptionRecord) error {
|
||||
func (*AsyncRecorder) RecordInterception(context.Context, *InterceptionRecord) error {
|
||||
panic("RecordInterception must not be called asynchronously")
|
||||
}
|
||||
|
||||
|
||||
+2
-2
@@ -14,10 +14,10 @@ import (
|
||||
|
||||
var claudeCodePattern = regexp.MustCompile(`_session_(.+)$`) // Legacy format: save compilation on each call.
|
||||
|
||||
// guessSessionID attempts to retrieve a session ID which may have been sent by
|
||||
// GuessSessionID attempts to retrieve a session ID which may have been sent by
|
||||
// the client. We only attempt to retrieve sessions using methods recognized for
|
||||
// the given client.
|
||||
func guessSessionID(client Client, r *http.Request) *string {
|
||||
func GuessSessionID(client Client, r *http.Request) *string {
|
||||
switch client {
|
||||
case ClientClaudeCode:
|
||||
// Prefer the dedicated header (added in Claude Code v2.1.86+).
|
||||
|
||||
+39
-38
@@ -1,4 +1,4 @@
|
||||
package aibridge
|
||||
package aibridge_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
)
|
||||
|
||||
@@ -16,7 +17,7 @@ func TestGuessSessionID(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
client Client
|
||||
client aibridge.Client
|
||||
body string
|
||||
headers map[string]string
|
||||
sessionID *string
|
||||
@@ -24,177 +25,177 @@ func TestGuessSessionID(t *testing.T) {
|
||||
// Claude Code.
|
||||
{
|
||||
name: "claude_code_header_takes_precedence",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
headers: map[string]string{"X-Claude-Code-Session-Id": "header-session-id"},
|
||||
body: `{"metadata":{"user_id":"user_abc123_account_456_session_body-session-id"}}`,
|
||||
sessionID: utils.PtrTo("header-session-id"),
|
||||
},
|
||||
{
|
||||
name: "claude_code_header_only",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
headers: map[string]string{"X-Claude-Code-Session-Id": "aabb-ccdd"},
|
||||
body: `{"model":"claude-3"}`,
|
||||
sessionID: utils.PtrTo("aabb-ccdd"),
|
||||
},
|
||||
{
|
||||
name: "claude_code_empty_header_falls_back_to_body",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
headers: map[string]string{"X-Claude-Code-Session-Id": ""},
|
||||
body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`,
|
||||
sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"),
|
||||
},
|
||||
{
|
||||
name: "claude_code_whitespace_header_falls_back_to_body",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
headers: map[string]string{"X-Claude-Code-Session-Id": " "},
|
||||
body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`,
|
||||
sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"),
|
||||
},
|
||||
{
|
||||
name: "claude_code_with_valid_session",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
body: `{"metadata":{"user_id":"user_abc123_account_456_session_f47ac10b-58cc-4372-a567-0e02b2c3d479"}}`,
|
||||
sessionID: utils.PtrTo("f47ac10b-58cc-4372-a567-0e02b2c3d479"),
|
||||
},
|
||||
{
|
||||
name: "claude_code_with_valid_session_new_format",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
body: `{"metadata":{"user_id":"{\"device_id\":\"45aa15c8c244ea2582f8144dde91a50ec3815851f6f648abef4ee15b173cc927\",\"account_uuid\":\"\",\"session_id\":\"54c1eb09-bc4c-4d2f-98eb-6d2ab2d5e2fe\"}"}}`,
|
||||
sessionID: utils.PtrTo("54c1eb09-bc4c-4d2f-98eb-6d2ab2d5e2fe"),
|
||||
},
|
||||
{
|
||||
name: "claude_code_new_format_empty_session_id",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
body: `{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\",\"session_id\":\"\"}"}}`,
|
||||
},
|
||||
{
|
||||
name: "claude_code_new_format_no_session_id_field",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
body: `{"metadata":{"user_id":"{\"device_id\":\"abc\",\"account_uuid\":\"\"}"}}`,
|
||||
},
|
||||
{
|
||||
name: "claude_code_missing_metadata",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
body: `{"model":"claude-3"}`,
|
||||
},
|
||||
{
|
||||
name: "claude_code_missing_user_id",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
body: `{"metadata":{}}`,
|
||||
},
|
||||
{
|
||||
name: "claude_code_user_id_without_session",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
body: `{"metadata":{"user_id":"user_abc123_account_456"}}`,
|
||||
},
|
||||
{
|
||||
name: "claude_code_empty_body",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
body: ``,
|
||||
},
|
||||
{
|
||||
name: "claude_code_invalid_json",
|
||||
client: ClientClaudeCode,
|
||||
client: aibridge.ClientClaudeCode,
|
||||
body: `not json at all`,
|
||||
},
|
||||
// Codex.
|
||||
{
|
||||
name: "codex_with_session_header",
|
||||
client: ClientCodex,
|
||||
client: aibridge.ClientCodex,
|
||||
headers: map[string]string{"session_id": "codex-session-123"},
|
||||
sessionID: utils.PtrTo("codex-session-123"),
|
||||
},
|
||||
{
|
||||
name: "codex_with_whitespace_in_header",
|
||||
client: ClientCodex,
|
||||
client: aibridge.ClientCodex,
|
||||
headers: map[string]string{"session_id": " codex-session-123 "},
|
||||
sessionID: utils.PtrTo("codex-session-123"),
|
||||
},
|
||||
{
|
||||
name: "codex_without_session_header",
|
||||
client: ClientCodex,
|
||||
client: aibridge.ClientCodex,
|
||||
},
|
||||
// Other clients shouldn't use others' logic.
|
||||
{
|
||||
name: "unknown_client_returns_empty",
|
||||
client: ClientUnknown,
|
||||
client: aibridge.ClientUnknown,
|
||||
body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`,
|
||||
},
|
||||
{
|
||||
name: "zed_returns_empty",
|
||||
client: ClientZed,
|
||||
client: aibridge.ClientZed,
|
||||
headers: map[string]string{"session_id": "zed-session"},
|
||||
body: `{"metadata":{"user_id":"user_abc_account_456_session_some-id"}}`,
|
||||
},
|
||||
// Mux.
|
||||
{
|
||||
name: "mux_with_workspace_header",
|
||||
client: ClientMux,
|
||||
client: aibridge.ClientMux,
|
||||
headers: map[string]string{"X-Mux-Workspace-Id": "ws-abc-123"},
|
||||
sessionID: utils.PtrTo("ws-abc-123"),
|
||||
},
|
||||
{
|
||||
name: "mux_without_workspace_header",
|
||||
client: ClientMux,
|
||||
client: aibridge.ClientMux,
|
||||
},
|
||||
// Copilot VS Code.
|
||||
{
|
||||
name: "copilot_vsc_with_interaction_id",
|
||||
client: ClientCopilotVSC,
|
||||
client: aibridge.ClientCopilotVSC,
|
||||
headers: map[string]string{"x-interaction-id": "interaction-xyz"},
|
||||
sessionID: utils.PtrTo("interaction-xyz"),
|
||||
},
|
||||
{
|
||||
name: "copilot_vsc_without_interaction_id",
|
||||
client: ClientCopilotVSC,
|
||||
client: aibridge.ClientCopilotVSC,
|
||||
},
|
||||
// Copilot CLI.
|
||||
{
|
||||
name: "copilot_cli_with_session_header",
|
||||
client: ClientCopilotCLI,
|
||||
client: aibridge.ClientCopilotCLI,
|
||||
headers: map[string]string{"X-Client-Session-Id": "cli-sess-456"},
|
||||
sessionID: utils.PtrTo("cli-sess-456"),
|
||||
},
|
||||
{
|
||||
name: "copilot_cli_without_session_header",
|
||||
client: ClientCopilotCLI,
|
||||
client: aibridge.ClientCopilotCLI,
|
||||
},
|
||||
// Kilo.
|
||||
{
|
||||
name: "kilo_with_task_id",
|
||||
client: ClientKilo,
|
||||
client: aibridge.ClientKilo,
|
||||
headers: map[string]string{"X-KILOCODE-TASKID": "task-789"},
|
||||
sessionID: utils.PtrTo("task-789"),
|
||||
},
|
||||
{
|
||||
name: "kilo_without_task_id",
|
||||
client: ClientKilo,
|
||||
client: aibridge.ClientKilo,
|
||||
},
|
||||
// Coder Agents.
|
||||
{
|
||||
name: "coder_agents_with_chat_id",
|
||||
client: ClientCoderAgents,
|
||||
client: aibridge.ClientCoderAgents,
|
||||
headers: map[string]string{"X-Coder-Chat-Id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890"},
|
||||
sessionID: utils.PtrTo("a1b2c3d4-e5f6-7890-abcd-ef1234567890"),
|
||||
},
|
||||
{
|
||||
name: "coder_agents_without_chat_id",
|
||||
client: ClientCoderAgents,
|
||||
client: aibridge.ClientCoderAgents,
|
||||
},
|
||||
// Roo.
|
||||
{
|
||||
name: "roo_returns_empty",
|
||||
client: ClientRoo,
|
||||
client: aibridge.ClientRoo,
|
||||
},
|
||||
// Cursor.
|
||||
{
|
||||
name: "cursor_returns_empty",
|
||||
client: ClientCursor,
|
||||
client: aibridge.ClientCursor,
|
||||
},
|
||||
// Other cases.
|
||||
{
|
||||
name: "empty session ID value",
|
||||
client: ClientKilo,
|
||||
client: aibridge.ClientKilo,
|
||||
headers: map[string]string{"X-KILOCODE-TASKID": " "},
|
||||
sessionID: nil,
|
||||
},
|
||||
@@ -205,14 +206,14 @@ func TestGuessSessionID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := tc.body
|
||||
req, err := http.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(body))
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://localhost", strings.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
for key, value := range tc.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
got := guessSessionID(tc.client, req)
|
||||
got := aibridge.GuessSessionID(tc.client, req)
|
||||
require.Equal(t, tc.sessionID, got)
|
||||
|
||||
// Verify the body was restored and can be read again.
|
||||
@@ -226,16 +227,16 @@ func TestGuessSessionID(t *testing.T) {
|
||||
func TestUnreadableBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "http://localhost", &errReader{})
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://localhost", &errReader{})
|
||||
require.NoError(t, err)
|
||||
|
||||
got := guessSessionID(ClientClaudeCode, req)
|
||||
got := aibridge.GuessSessionID(aibridge.ClientClaudeCode, req)
|
||||
require.Nil(t, got)
|
||||
}
|
||||
|
||||
// errReader is an io.Reader that always returns an error.
|
||||
type errReader struct{}
|
||||
|
||||
func (e *errReader) Read([]byte) (int, error) {
|
||||
func (*errReader) Read([]byte) (int, error) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
@@ -18,11 +18,15 @@ func TestConcurrentGroup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("no goroutines", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cg := utils.NewConcurrentGroup()
|
||||
require.NoError(t, cg.Wait())
|
||||
})
|
||||
|
||||
t.Run("multiple goroutines, all ok", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cg := utils.NewConcurrentGroup()
|
||||
cg.Go(func() error {
|
||||
return nil
|
||||
@@ -34,6 +38,8 @@ func TestConcurrentGroup(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("multiple goroutines, one err", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cg := utils.NewConcurrentGroup()
|
||||
oops := xerrors.New("oops")
|
||||
cg.Go(func() error {
|
||||
@@ -46,6 +52,8 @@ func TestConcurrentGroup(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("multiple goroutines, multiple errs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cg := utils.NewConcurrentGroup()
|
||||
oops := xerrors.New("oops")
|
||||
eek := xerrors.New("eek")
|
||||
|
||||
Reference in New Issue
Block a user