Compare commits

...

1 Commits

Author SHA1 Message Date
Yevhenii Shcherbina 7bfa79f75d feat: add allow-byok option for ai-gateway 2026-04-12 14:12:01 +00:00
6 changed files with 29 additions and 8 deletions
+11
View File
@@ -3801,6 +3801,16 @@ Write out the current server config as YAML to stdout.`,
Group: &deploymentGroupAIBridge,
YAML: "send_actor_headers",
},
{
Name: "AI Gateway Allow BYOK",
Description: "Allow users to bring their own LLM API keys or subscriptions. When disabled, only centralized key authentication is permitted.",
Flag: "ai-gateway-allow-byok",
Env: "CODER_AI_GATEWAY_ALLOW_BYOK",
Value: &c.AI.BridgeConfig.AllowBYOK,
Default: "true",
Group: &deploymentGroupAIBridge,
YAML: "allow_byok",
},
{
Name: "AI Bridge Circuit Breaker Enabled",
Description: "Enable the circuit breaker to protect against cascading failures from upstream AI provider rate limits (429, 503, 529 overloaded).",
@@ -4048,6 +4058,7 @@ type AIBridgeConfig struct {
RateLimit serpent.Int64 `json:"rate_limit" typescript:",notnull"`
StructuredLogging serpent.Bool `json:"structured_logging" typescript:",notnull"`
SendActorHeaders serpent.Bool `json:"send_actor_headers" typescript:",notnull"`
AllowBYOK serpent.Bool `json:"allow_byok" typescript:",notnull"`
// Circuit breaker protects against cascading failures from upstream AI
// provider rate limits (429, 503, 529 overloaded).
CircuitBreakerEnabled serpent.Bool `json:"circuit_breaker_enabled" typescript:",notnull"`
+4 -1
View File
@@ -48,9 +48,11 @@ type Server struct {
cancelFn func()
shutdownOnce sync.Once
allowBYOK bool
}
func New(ctx context.Context, pool Pooler, rpcDialer Dialer, logger slog.Logger, tracer trace.Tracer) (*Server, error) {
func New(ctx context.Context, pool Pooler, rpcDialer Dialer, logger slog.Logger, tracer trace.Tracer, allowBYOK bool) (*Server, error) {
if rpcDialer == nil {
return nil, xerrors.Errorf("nil rpcDialer given")
}
@@ -66,6 +68,7 @@ func New(ctx context.Context, pool Pooler, rpcDialer Dialer, logger slog.Logger,
initConnectionCh: make(chan struct{}),
requestBridgePool: pool,
allowBYOK: allowBYOK,
}
daemon.wg.Add(1)
@@ -190,7 +190,7 @@ func TestIntegration(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return aiBridgeClient, nil
}, logger, tracer)
}, logger, tracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(ctx)
@@ -393,7 +393,7 @@ func TestIntegrationWithMetrics(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(ctx, pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return aiBridgeClient, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(ctx)
@@ -508,7 +508,7 @@ func TestIntegrationCircuitBreaker(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(ctx, pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return aiBridgeClient, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(ctx)
+3 -3
View File
@@ -43,7 +43,7 @@ func newTestServer(t *testing.T) (*aibridged.Server, *mock.MockDRPCClient, *mock
pool,
func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
srv.Shutdown(context.Background())
@@ -441,7 +441,7 @@ func TestServeHTTP_ActorHeaders(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(testutil.Context(t, testutil.WaitShort))
@@ -545,7 +545,7 @@ func TestRouting(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(testutil.Context(t, testutil.WaitShort))
+7
View File
@@ -56,6 +56,13 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
authMode = "byok"
}
if byok && !s.allowBYOK {
logger.Warn(ctx, "BYOK request rejected: not allowed by deployment configuration")
http.Error(rw, "Bring Your Own Key (BYOK) mode is not enabled. "+
"Contact your administrator to enable it with --ai-gateway-allow-byok.", http.StatusForbidden)
return
}
key := strings.TrimSpace(agplaibridge.ExtractAuthToken(r.Header))
if key == "" {
// Some clients (e.g. Claude) send a HEAD request
+1 -1
View File
@@ -86,7 +86,7 @@ func newAIBridgeDaemon(coderAPI *coderd.API) (*aibridged.Server, error) {
// Create daemon.
srv, err := aibridged.New(ctx, pool, func(dialCtx context.Context) (aibridged.DRPCClient, error) {
return coderAPI.CreateInMemoryAIBridgeServer(dialCtx)
}, logger, tracer)
}, logger, tracer, cfg.AllowBYOK.Value())
if err != nil {
return nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err)
}