Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a58828fb39 | |||
| 67103a6e61 | |||
| a46f13b321 | |||
| 5a2756c38f | |||
| 6cec0956dd | |||
| d3e5e8b1bb | |||
| 5af92abb28 | |||
| 3acab4a7bc | |||
| 5b4eef620f |
@@ -0,0 +1,101 @@
|
||||
package aibridged
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/aibridge"
|
||||
)
|
||||
|
||||
var _ aibridge.Provider = &AmpProvider{}
|
||||
|
||||
const (
|
||||
ProviderAmp = "amp"
|
||||
ampRouteMessages = "/amp/v1/messages"
|
||||
)
|
||||
|
||||
type (
|
||||
AmpConfig = aibridge.ProviderConfig
|
||||
)
|
||||
|
||||
type AmpProvider struct {
|
||||
cfg AmpConfig
|
||||
}
|
||||
|
||||
func NewAmpProvider(cfg AmpConfig) *AmpProvider {
|
||||
if cfg.BaseURL == "" {
|
||||
cfg.BaseURL = "https://ampcode.com/api/provider/anthropic"
|
||||
}
|
||||
return &AmpProvider{cfg: cfg}
|
||||
}
|
||||
|
||||
func (p *AmpProvider) Name() string {
|
||||
return ProviderAmp
|
||||
}
|
||||
|
||||
func (p *AmpProvider) BaseURL() string {
|
||||
return p.cfg.BaseURL
|
||||
}
|
||||
|
||||
// BridgedRoutes returns routes that will be intercepted.
|
||||
func (p *AmpProvider) BridgedRoutes() []string {
|
||||
return []string{ampRouteMessages}
|
||||
}
|
||||
|
||||
// PassthroughRoutes returns routes that are proxied directly.
|
||||
func (p *AmpProvider) PassthroughRoutes() []string {
|
||||
return []string{
|
||||
"/v1/models",
|
||||
"/v1/models/",
|
||||
"/v1/messages/count_tokens",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AmpProvider) AuthHeader() string {
|
||||
return "X-Api-Key"
|
||||
}
|
||||
|
||||
// InjectAuthHeader Amp already makes the request with X-Api-Key containing the authenticated user's API key
|
||||
// One key per user instead of a global key.
|
||||
func (p *AmpProvider) InjectAuthHeader(h *http.Header) {}
|
||||
|
||||
// CreateInterceptor creates an interceptor for the request.
|
||||
// Reuses Anthropic's interceptor since Amp uses the same API format.
|
||||
func (p *AmpProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (aibridge.Interceptor, error) {
|
||||
// Capture the API key from the incoming request
|
||||
apiKey := r.Header.Get("X-Api-Key")
|
||||
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
id := uuid.New()
|
||||
|
||||
switch r.URL.Path {
|
||||
case ampRouteMessages:
|
||||
var req aibridge.MessageNewParamsWrapper
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return nil, xerrors.Errorf("failed to unmarshal request: %w", err)
|
||||
}
|
||||
|
||||
// Reuse Anthropic interceptors as Amp uses the same API format
|
||||
ampCfg := aibridge.AnthropicConfig{
|
||||
BaseURL: p.cfg.BaseURL,
|
||||
Key: apiKey,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
return aibridge.NewAnthropicMessagesStreamingInterception(id, &req, ampCfg, nil, tracer), nil
|
||||
}
|
||||
|
||||
return aibridge.NewAnthropicMessagesBlockingInterception(id, &req, ampCfg, nil, tracer), nil
|
||||
}
|
||||
|
||||
return nil, aibridge.UnknownRoute
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
# AI Proxy Certificate Setup
|
||||
|
||||
This document describes how to set up MITM certificates for the AI proxy, including cross-signing for proxy chaining scenarios.
|
||||
|
||||
## Overview
|
||||
|
||||
The AI proxy uses MITM (Man-in-the-Middle) to intercept HTTPS traffic to AI providers. When chaining through an upstream SSL-bumping proxy (like Squid), both proxies need coordinated certificate trust.
|
||||
|
||||
## Certificate Hierarchy
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ Upstream Proxy Root CA │
|
||||
│ (e.g., Squid Root CA) │
|
||||
│ Self-signed │
|
||||
└──────────────┬──────────────────────┘
|
||||
│ signs
|
||||
▼
|
||||
┌─────────────────────────────────────┐
|
||||
│ Downstream Proxy CA (intermediate) │
|
||||
│ (AI Proxy's MITM CA) │
|
||||
│ Cross-signed by upstream │
|
||||
└──────────────┬──────────────────────┘
|
||||
│ signs
|
||||
▼
|
||||
┌─────────────────────────────────────┐
|
||||
│ Leaf certificates │
|
||||
│ (Generated per-site for MITM) │
|
||||
└─────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Creating a New CA Key Pair
|
||||
|
||||
### 1. Generate a new private key
|
||||
|
||||
```sh
|
||||
openssl genrsa -out mitm.key 2048
|
||||
chmod 400 mitm.key
|
||||
```
|
||||
|
||||
### 2. Create a self-signed CA certificate
|
||||
|
||||
```sh
|
||||
openssl req -new -x509 -days 365 \
|
||||
-key mitm.key \
|
||||
-out mitm.crt \
|
||||
-subj "/CN=AI Proxy CA"
|
||||
```
|
||||
|
||||
## Cross-Signing Against an Existing CA
|
||||
|
||||
Cross-signing allows your CA to be trusted by clients that already trust another root CA. This is essential for proxy chaining where an upstream proxy does SSL bumping.
|
||||
|
||||
### 1. Create a Certificate Signing Request (CSR) from your key
|
||||
|
||||
```sh
|
||||
openssl req -new \
|
||||
-key mitm.key \
|
||||
-out mitm.csr \
|
||||
-subj "/CN=AI Proxy CA"
|
||||
```
|
||||
|
||||
### 2. Create an extensions file for CA certificates
|
||||
|
||||
```sh
|
||||
cat > ca_extensions.cnf << 'EOF'
|
||||
basicConstraints=CA:TRUE
|
||||
keyUsage=keyCertSign,cRLSign
|
||||
EOF
|
||||
```
|
||||
|
||||
### 3. Sign the CSR with the upstream CA
|
||||
|
||||
```sh
|
||||
openssl x509 -req \
|
||||
-in mitm.csr \
|
||||
-CA upstream-ca.crt \
|
||||
-CAkey upstream-ca.key \
|
||||
-CAcreateserial \
|
||||
-out mitm-cross-signed.crt \
|
||||
-days 365 \
|
||||
-extfile ca_extensions.cnf
|
||||
```
|
||||
|
||||
### 4. Create a certificate chain file
|
||||
|
||||
The chain file should contain your certificate followed by the upstream CA:
|
||||
|
||||
```sh
|
||||
cat mitm-cross-signed.crt upstream-ca.crt > mitm-chain.crt
|
||||
```
|
||||
|
||||
### 5. Verify the chain
|
||||
|
||||
```sh
|
||||
# Check the signing relationship
|
||||
openssl x509 -in mitm-cross-signed.crt -noout -subject -issuer
|
||||
|
||||
# Verify the chain is valid
|
||||
openssl verify -CAfile upstream-ca.crt mitm-cross-signed.crt
|
||||
```
|
||||
|
||||
Expected output:
|
||||
```
|
||||
subject=CN = AI Proxy CA
|
||||
issuer=CN = Upstream Root CA
|
||||
mitm-cross-signed.crt: OK
|
||||
```
|
||||
|
||||
## Proxy Chaining Architecture
|
||||
|
||||
### Request Flow
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Client
|
||||
participant AIProxy as AI Proxy<br/>(Downstream)
|
||||
participant Squid as Squid<br/>(Upstream)
|
||||
participant Coder as Coder Server<br/>(aibridge)
|
||||
participant Anthropic as Anthropic API
|
||||
|
||||
Note over Client,Anthropic: Phase 1: Tunnel Establishment
|
||||
|
||||
Client->>AIProxy: CONNECT api.anthropic.com:443<br/>Proxy-Authorization: Basic <coder-token>
|
||||
AIProxy->>Squid: CONNECT api.anthropic.com:443
|
||||
Squid->>Anthropic: TCP Connection
|
||||
Anthropic-->>Squid: Connected
|
||||
Squid-->>AIProxy: 200 Connection Established
|
||||
AIProxy-->>Client: 200 Connection Established
|
||||
|
||||
Note over Client,Anthropic: Phase 2: TLS Handshakes (MITM)
|
||||
|
||||
Client->>AIProxy: TLS ClientHello
|
||||
AIProxy->>AIProxy: Generate cert for api.anthropic.com<br/>signed by AI Proxy CA
|
||||
AIProxy-->>Client: TLS ServerHello<br/>(AI Proxy's cert)
|
||||
Client->>AIProxy: TLS Finished
|
||||
|
||||
AIProxy->>Squid: TLS ClientHello
|
||||
Squid->>Squid: Generate cert for api.anthropic.com<br/>signed by Squid CA
|
||||
Squid-->>AIProxy: TLS ServerHello<br/>(Squid's cert)
|
||||
Note over AIProxy: Validates Squid's cert<br/>using UpstreamProxyCACert
|
||||
AIProxy->>Squid: TLS Finished
|
||||
|
||||
Squid->>Anthropic: TLS ClientHello
|
||||
Anthropic-->>Squid: TLS ServerHello<br/>(Real cert)
|
||||
Squid->>Anthropic: TLS Finished
|
||||
|
||||
Note over Client,Anthropic: Phase 3: Request Interception & Routing
|
||||
|
||||
Client->>AIProxy: POST /v1/messages<br/>(to api.anthropic.com)
|
||||
AIProxy->>AIProxy: Decrypt request<br/>Extract coder-token<br/>Rewrite URL to aibridge
|
||||
|
||||
AIProxy->>Squid: POST /api/v2/aibridge/anthropic/v1/messages<br/>(to Coder server)<br/>Authorization: Bearer <coder-token>
|
||||
Squid->>Squid: Decrypt, log, re-encrypt
|
||||
Squid->>Coder: POST /api/v2/aibridge/anthropic/v1/messages
|
||||
|
||||
Note over Coder: Validate token<br/>Record usage<br/>Forward to provider
|
||||
|
||||
Coder->>Anthropic: POST /v1/messages<br/>(with Anthropic API key)
|
||||
Anthropic-->>Coder: Response (streaming or JSON)
|
||||
Coder-->>Squid: Response
|
||||
Squid-->>AIProxy: Response
|
||||
AIProxy-->>Client: Response
|
||||
```
|
||||
|
||||
### Trust Relationships
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph Client["Client Machine"]
|
||||
ClientTrust["Trusts: Squid Root CA<br/>(squid-ca.crt)"]
|
||||
end
|
||||
|
||||
subgraph AIProxy["AI Proxy (Downstream)"]
|
||||
AIProxyCA["AI Proxy CA<br/>- Signs certs for clients<br/>- Cross-signed by Squid CA"]
|
||||
AIProxyTrust["Trusts: Squid Root CA<br/>(via UpstreamProxyCACert)"]
|
||||
end
|
||||
|
||||
subgraph Squid["Squid Proxy (Upstream)"]
|
||||
SquidCA["Squid Root CA<br/>- Signs certs for AI Proxy<br/>- Self-signed root"]
|
||||
end
|
||||
|
||||
subgraph Internet["Internet"]
|
||||
RealCA["Public CAs<br/>(DigiCert, Let's Encrypt, etc.)"]
|
||||
Target["Target Servers"]
|
||||
end
|
||||
|
||||
ClientTrust -->|validates chain via| SquidCA
|
||||
AIProxyCA -->|signed by| SquidCA
|
||||
AIProxyTrust -->|validates| SquidCA
|
||||
Squid -->|trusts| RealCA
|
||||
RealCA -->|signs| Target
|
||||
```
|
||||
|
||||
### Certificate Files Reference
|
||||
|
||||
| ID | File | Purpose | Used by |
|
||||
|----|--------------------------------------------|---------------------------------------------------|------------------------------------------------|
|
||||
| A | Upstream Root CA key (`squid-ca.key`) | Signs intermediate CA; Signs fake certs for Squid | Upstream Proxy (Squid) |
|
||||
| B | Upstream Root CA cert (`squid-ca.crt`) | Trust anchor for entire chain | AI Proxy (upstream trust); Client (root trust) |
|
||||
| C | AI Proxy CA key (`mitm.key`) | Signs fake certificates for client connections | AI Proxy |
|
||||
| D | AI Proxy CA cert (`mitm-cross-signed.crt`) | Intermediate CA, signed by upstream root | AI Proxy (part of chain served to clients) |
|
||||
| E | Certificate chain (`mitm-chain.crt`) | D + B combined, full chain | AI Proxy (loads as CA cert file) |
|
||||
|
||||
> **Note**: Clients only need to trust `squid-ca.crt` (B). If this is already in the system trust store (e.g., corporate proxy CA), no additional client configuration is needed.
|
||||
|
||||
### Certificate Signing (Setup Time)
|
||||
|
||||
How the cross-signed certificate chain is created:
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A["A: squid-ca.key"] -->|signs| D["D: mitm-cross-signed.crt"]
|
||||
C["C: mitm.key"] -->|generates CSR| D
|
||||
D -->|concatenate| E["E: mitm-chain.crt"]
|
||||
B["B: squid-ca.crt"] -->|concatenate| E
|
||||
```
|
||||
|
||||
### Certificate Trust (Runtime)
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
Client -->|trusts| B["B: squid-ca.crt"]
|
||||
AIProxy["AI Proxy"] -->|trusts| B
|
||||
Squid -->|trusts| PublicCAs["Public CAs"]
|
||||
```
|
||||
|
||||
> Clients validate the chain: leaf cert → D (intermediate) → B (root). Only B needs to be trusted.
|
||||
|
||||
### Certificate Usage (Runtime)
|
||||
|
||||
Which keys sign fake certificates:
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
C["C: mitm.key"] -->|signs fake certs| AIProxy["AI Proxy"]
|
||||
A["A: squid-ca.key"] -->|signs fake certs| Squid
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Verifying Cross-Signing
|
||||
|
||||
Check that your certificate shows a different issuer than subject:
|
||||
|
||||
```sh
|
||||
openssl x509 -in mitm-cross-signed.crt -noout -subject -issuer
|
||||
```
|
||||
|
||||
If both are the same, the certificate is self-signed, not cross-signed.
|
||||
|
||||
### Testing the Chain
|
||||
|
||||
```sh
|
||||
# Test with curl through the proxy
|
||||
curl -x http://localhost:8888 \
|
||||
--cacert /path/to/ai-proxy-ca.crt \
|
||||
https://api.anthropic.com/v1/messages
|
||||
```
|
||||
@@ -0,0 +1,452 @@
|
||||
package aiproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/elazarl/goproxy"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// goproxyLogger adapts slog.Logger to goproxy's Logger interface.
|
||||
type goproxyLogger struct {
|
||||
ctx context.Context
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
func (l *goproxyLogger) Printf(format string, v ...any) {
|
||||
// goproxy's format includes "[%03d] " session prefix and trailing newline.
|
||||
// We strip the newline since slog adds its own.
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
msg = strings.TrimSuffix(msg, "\n")
|
||||
l.logger.Debug(l.ctx, msg)
|
||||
}
|
||||
|
||||
// certCache implements goproxy.CertStorage to cache generated leaf certificates.
|
||||
type certCache struct {
|
||||
mu sync.RWMutex
|
||||
certs map[string]*tls.Certificate
|
||||
}
|
||||
|
||||
func (c *certCache) Fetch(hostname string, gen func() (*tls.Certificate, error)) (*tls.Certificate, error) {
|
||||
c.mu.RLock()
|
||||
cert, ok := c.certs[hostname]
|
||||
c.mu.RUnlock()
|
||||
if ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if cert, ok := c.certs[hostname]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
cert, err := gen()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.certs[hostname] = cert
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
logger slog.Logger
|
||||
proxy *goproxy.ProxyHttpServer
|
||||
httpServer *http.Server
|
||||
coderAccessURL *url.URL
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
ListenAddr string
|
||||
CertFile string
|
||||
KeyFile string
|
||||
CoderAccessURL string
|
||||
// UpstreamProxy is the URL of an upstream HTTP proxy to chain requests through.
|
||||
// If empty, requests are made directly to targets.
|
||||
// Format: http://[user:pass@]host:port or https://[user:pass@]host:port
|
||||
UpstreamProxy string
|
||||
// UpstreamProxyCACert is the PEM-encoded CA certificate to trust for the upstream
|
||||
// proxy's TLS interception. Required when chaining through an SSL-bumping proxy.
|
||||
UpstreamProxyCACert []byte
|
||||
}
|
||||
|
||||
func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) {
|
||||
logger.Info(ctx, "initializing AI proxy server")
|
||||
|
||||
// Load CA certificate for MITM
|
||||
if err := loadMitmCertificate(opts.CertFile, opts.KeyFile); err != nil {
|
||||
return nil, xerrors.Errorf("failed to load MITM certificate: %w", err)
|
||||
}
|
||||
|
||||
// Parse coderAccessURL once at startup - invalid URL is a fatal config error
|
||||
coderAccessURL, err := url.Parse(opts.CoderAccessURL)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid CoderAccessURL %q: %w", opts.CoderAccessURL, err)
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
coderAccessURL: coderAccessURL,
|
||||
}
|
||||
|
||||
proxy := goproxy.NewProxyHttpServer()
|
||||
proxy.Verbose = true
|
||||
// proxy.Logger = &goproxyLogger{ctx: ctx, logger: logger.Named("goproxy")}
|
||||
proxy.CertStore = &certCache{certs: make(map[string]*tls.Certificate)}
|
||||
|
||||
// Configure upstream proxy for chaining if specified
|
||||
if opts.UpstreamProxy != "" {
|
||||
upstreamURL, err := url.Parse(opts.UpstreamProxy)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid UpstreamProxy URL %q: %w", opts.UpstreamProxy, err)
|
||||
}
|
||||
logger.Info(ctx, "configuring upstream proxy", slog.F("upstream", upstreamURL.Host))
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
// Add upstream proxy CA to trusted roots if provided
|
||||
if len(opts.UpstreamProxyCACert) > 0 {
|
||||
rootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
rootCAs = x509.NewCertPool()
|
||||
}
|
||||
if !rootCAs.AppendCertsFromPEM(opts.UpstreamProxyCACert) {
|
||||
return nil, xerrors.Errorf("failed to parse upstream proxy CA certificate")
|
||||
}
|
||||
tlsConfig.RootCAs = rootCAs
|
||||
logger.Info(ctx, "configured upstream proxy CA certificate")
|
||||
}
|
||||
|
||||
// Configure HTTP transport to use upstream proxy
|
||||
proxy.Tr = &http.Transport{
|
||||
Proxy: http.ProxyURL(upstreamURL),
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
|
||||
// Configure CONNECT requests to go through upstream proxy
|
||||
proxy.ConnectDial = proxy.NewConnectDialToProxy(opts.UpstreamProxy)
|
||||
}
|
||||
|
||||
// Custom MITM handler that extracts auth and rejects unauthenticated requests.
|
||||
// The token is stored in ctx.UserData which goproxy propagates to subsequent
|
||||
// request contexts for decrypted requests within this MITM session.
|
||||
mitmWithAuth := goproxy.FuncHttpsHandler(func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
|
||||
// Restrict CONNECT to standard HTTP/HTTPS ports to prevent request smuggling.
|
||||
// Attackers could use non-standard ports to bypass security controls.
|
||||
port := extractPort(host)
|
||||
if port != "443" && port != "80" {
|
||||
srv.logger.Warn(srv.ctx, "rejecting connect to non-standard port",
|
||||
slog.F("host", host),
|
||||
slog.F("port", port),
|
||||
)
|
||||
return goproxy.RejectConnect, host
|
||||
}
|
||||
|
||||
proxyAuth := ctx.Req.Header.Get("Proxy-Authorization")
|
||||
coderToken := extractCoderTokenFromProxyAuth(proxyAuth)
|
||||
|
||||
// Reject unauthenticated or invalid auth requests - proxy is a protected service
|
||||
if coderToken == "" {
|
||||
hasAuth := proxyAuth != ""
|
||||
srv.logger.Warn(srv.ctx, "rejecting connect request",
|
||||
slog.F("host", host),
|
||||
slog.F("reason", map[bool]string{true: "invalid_auth", false: "missing_auth"}[hasAuth]),
|
||||
)
|
||||
return goproxy.RejectConnect, host
|
||||
}
|
||||
|
||||
// Store token in UserData - goproxy copies this to subsequent request contexts
|
||||
ctx.UserData = coderToken
|
||||
|
||||
return goproxy.MitmConnect, host
|
||||
})
|
||||
|
||||
// Apply MITM only to allowlisted AI provider hosts
|
||||
proxy.OnRequest(goproxy.ReqHostIs(
|
||||
"api.anthropic.com:443",
|
||||
"api.openai.com:443",
|
||||
"ampcode.com:443",
|
||||
)).HandleConnect(mitmWithAuth)
|
||||
|
||||
// Request handler for decrypted HTTPS traffic
|
||||
proxy.OnRequest().DoFunc(srv.requestHandler)
|
||||
|
||||
// Response handler
|
||||
proxy.OnResponse().DoFunc(srv.responseHandler)
|
||||
|
||||
srv.proxy = proxy
|
||||
|
||||
// Start HTTP server in background
|
||||
srv.httpServer = &http.Server{
|
||||
Addr: opts.ListenAddr,
|
||||
Handler: proxy,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
logger.Info(ctx, "starting AI proxy", slog.F("addr", opts.ListenAddr))
|
||||
if err := srv.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error(ctx, "proxy server error", slog.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// loadMitmCertificate loads the CA certificate for MITM into goproxy.
|
||||
func loadMitmCertificate(certFile, keyFile string) error {
|
||||
tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to load x509 keypair: %w", err)
|
||||
}
|
||||
|
||||
x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to parse certificate: %w", err)
|
||||
}
|
||||
|
||||
goproxy.GoproxyCa = tls.Certificate{
|
||||
Certificate: tlsCert.Certificate,
|
||||
PrivateKey: tlsCert.PrivateKey,
|
||||
Leaf: x509Cert,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractCoderTokenFromProxyAuth extracts the Coder session token from the
|
||||
// Proxy-Authorization header. The token is expected to be in the password
|
||||
// field of basic auth: "Basic base64(ignored:token)"
|
||||
// Returns empty string if no valid token is found.
|
||||
func extractCoderTokenFromProxyAuth(proxyAuth string) string {
|
||||
if proxyAuth == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Expected format: "Basic base64(username:password)"
|
||||
// Auth scheme is case-insensitive per HTTP spec
|
||||
parts := strings.Fields(proxyAuth)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") {
|
||||
return ""
|
||||
}
|
||||
|
||||
decoded, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Format: "username:password" - we use password as the Coder token
|
||||
credentials := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(credentials) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return credentials[1]
|
||||
}
|
||||
|
||||
// extractPort extracts the port from a host:port string.
|
||||
// Returns "443" as the default if no port is specified (standard HTTPS).
|
||||
func extractPort(host string) string {
|
||||
if i := strings.LastIndex(host, ":"); i != -1 {
|
||||
return host[i+1:]
|
||||
}
|
||||
return "443" // Default HTTPS port
|
||||
}
|
||||
|
||||
// canonicalHost strips the port from a host:port string and lowercases it.
|
||||
func canonicalHost(h string) string {
|
||||
if i := strings.IndexByte(h, ':'); i != -1 {
|
||||
h = h[:i]
|
||||
}
|
||||
return strings.ToLower(h)
|
||||
}
|
||||
|
||||
// providerFromHost maps the request host to the aibridge provider name.
|
||||
// All requests through the proxy for known AI providers are routed through aibridge.
|
||||
// Unknown hosts return empty string and are passed through directly without aibridge.
|
||||
// Uses exact host matching consistent with the MITM allowlist.
|
||||
func providerFromHost(host string) string {
|
||||
h := canonicalHost(host)
|
||||
switch h {
|
||||
case "api.anthropic.com":
|
||||
return "anthropic"
|
||||
case "api.openai.com":
|
||||
return "openai"
|
||||
case "ampcode.com":
|
||||
return "amp"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// mapPathForProvider converts the original request path to the aibridge path format.
|
||||
// Returns empty string if the path should not be routed through aibridge.
|
||||
func mapPathForProvider(provider, originalPath string) string {
|
||||
switch provider {
|
||||
case "amp":
|
||||
// Only intercept AI provider routes
|
||||
// Original: /api/provider/anthropic/v1/messages
|
||||
// aibridge expects: /amp/v1/messages
|
||||
const ampPrefix = "/api/provider/anthropic"
|
||||
if strings.HasPrefix(originalPath, ampPrefix) {
|
||||
return "/amp" + strings.TrimPrefix(originalPath, ampPrefix)
|
||||
}
|
||||
// Other Amp routes (e.g., /api/internal) should not go through aibridge
|
||||
return ""
|
||||
case "anthropic":
|
||||
return "/anthropic" + originalPath
|
||||
case "openai":
|
||||
return "/openai" + originalPath
|
||||
default:
|
||||
return "/" + provider + originalPath
|
||||
}
|
||||
}
|
||||
|
||||
// requestHandler intercepts HTTP requests after MITM decryption.
|
||||
// LLM requests are rewritten to aibridge, with the Coder session token
|
||||
// (from ctx.UserData, set during CONNECT) injected as "Authorization: Bearer <token>".
|
||||
func (srv *Server) requestHandler(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
|
||||
// Get token from UserData (set during CONNECT, propagated by goproxy)
|
||||
// Presence of token indicates request was decrypted via MITM
|
||||
coderToken, _ := ctx.UserData.(string)
|
||||
decrypted := coderToken != ""
|
||||
|
||||
// Check if this request is for a supported AI provider.
|
||||
provider := providerFromHost(req.Host)
|
||||
if provider == "" {
|
||||
srv.logger.Debug(srv.ctx, "passthrough request to unknown host",
|
||||
slog.F("host", req.Host),
|
||||
slog.F("method", req.Method),
|
||||
slog.F("path", req.URL.Path),
|
||||
slog.F("decrypted", decrypted),
|
||||
)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Map the original path to the aibridge path
|
||||
originalPath := req.URL.Path
|
||||
aibridgePath := mapPathForProvider(provider, originalPath)
|
||||
|
||||
// If path doesn't map to an aibridge route, pass through directly
|
||||
if aibridgePath == "" {
|
||||
srv.logger.Debug(srv.ctx, "passthrough request to non-aibridge path",
|
||||
slog.F("host", req.Host),
|
||||
slog.F("method", req.Method),
|
||||
slog.F("path", originalPath),
|
||||
slog.F("provider", provider),
|
||||
slog.F("decrypted", decrypted),
|
||||
)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Reject unauthenticated requests
|
||||
if coderToken == "" {
|
||||
srv.logger.Warn(srv.ctx, "rejecting unauthenticated request",
|
||||
slog.F("host", req.Host),
|
||||
slog.F("path", originalPath),
|
||||
slog.F("decrypted", decrypted),
|
||||
)
|
||||
resp := goproxy.NewResponse(req, goproxy.ContentTypeText, http.StatusProxyAuthRequired, "Proxy authentication required")
|
||||
resp.Header.Set("Proxy-Authenticate", `Basic realm="Coder AI proxy"`)
|
||||
return req, resp
|
||||
}
|
||||
|
||||
// Rewrite URL to point to aibridged (shallow copy of pre-parsed URL)
|
||||
newURL := *srv.coderAccessURL
|
||||
newURL.Path = "/api/v2/aibridge" + aibridgePath
|
||||
newURL.RawQuery = req.URL.RawQuery
|
||||
|
||||
req.URL = &newURL
|
||||
req.Host = newURL.Host
|
||||
|
||||
// Set Authorization header for coder's aibridge authentication
|
||||
req.Header.Set("Authorization", "Bearer "+coderToken)
|
||||
|
||||
srv.logger.Info(srv.ctx, "proxying decrypted request to aibridge",
|
||||
slog.F("method", req.Method),
|
||||
slog.F("provider", provider),
|
||||
slog.F("path", originalPath),
|
||||
)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// responseHandler handles responses from upstream.
|
||||
func (srv *Server) responseHandler(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response {
|
||||
// Check for proxy errors (connection failures, TLS errors, etc.)
|
||||
if ctx.Error != nil {
|
||||
srv.logger.Error(srv.ctx, "upstream request failed",
|
||||
slog.F("error", ctx.Error.Error()),
|
||||
slog.F("url", ctx.Req.URL.String()),
|
||||
slog.F("method", ctx.Req.Method),
|
||||
)
|
||||
return resp
|
||||
}
|
||||
|
||||
req := ctx.Req
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
contentEncoding := resp.Header.Get("Content-Encoding")
|
||||
|
||||
// Skip logging for compressed or streaming responses to preserve streaming semantics
|
||||
if contentEncoding == "" && strings.Contains(contentType, "text") &&
|
||||
!strings.HasPrefix(contentType, "text/event-stream") {
|
||||
// Read the response body
|
||||
var bodyBytes []byte
|
||||
if resp.Body != nil {
|
||||
bodyBytes, _ = io.ReadAll(resp.Body)
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusAccepted {
|
||||
srv.logger.Info(srv.ctx, "received response",
|
||||
slog.F("url", req.URL.String()),
|
||||
slog.F("method", req.Method),
|
||||
slog.F("path", req.URL.Path),
|
||||
slog.F("response_status", resp.StatusCode),
|
||||
slog.F("response_body", string(bodyBytes)),
|
||||
)
|
||||
} else {
|
||||
srv.logger.Warn(srv.ctx, "received response",
|
||||
slog.F("url", req.URL.String()),
|
||||
slog.F("method", req.Method),
|
||||
slog.F("path", req.URL.Path),
|
||||
slog.F("response_status", resp.StatusCode),
|
||||
slog.F("response_body", string(bodyBytes)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// Close shuts down the proxy server.
|
||||
// Note: existing MITM'd connections may persist briefly after shutdown due to
|
||||
// goproxy's hijack-based design - Shutdown only manages connections that net/http
|
||||
// is aware of.
|
||||
func (srv *Server) Close() error {
|
||||
if srv.httpServer == nil {
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return srv.httpServer.Shutdown(ctx)
|
||||
}
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
|
||||
func newAIBridgeDaemon(coderAPI *coderd.API) (*aibridged.Server, error) {
|
||||
ctx := context.Background()
|
||||
coderAPI.Logger.Debug(ctx, "starting in-memory aibridge daemon")
|
||||
coderAPI.Logger.Info(ctx, "starting in-memory aibridge daemon")
|
||||
|
||||
logger := coderAPI.Logger.Named("aibridged")
|
||||
|
||||
@@ -32,6 +32,10 @@ func newAIBridgeDaemon(coderAPI *coderd.API) (*aibridged.Server, error) {
|
||||
BaseURL: coderAPI.DeploymentValues.AI.BridgeConfig.Anthropic.BaseURL.String(),
|
||||
Key: coderAPI.DeploymentValues.AI.BridgeConfig.Anthropic.Key.String(),
|
||||
}, getBedrockConfig(coderAPI.DeploymentValues.AI.BridgeConfig.Bedrock)),
|
||||
// TODO(ssncferreira): add provider to aibridge project
|
||||
aibridged.NewAmpProvider(aibridged.AmpConfig{
|
||||
BaseURL: "https://ampcode.com/api/provider/anthropic",
|
||||
}),
|
||||
}
|
||||
|
||||
reg := prometheus.WrapRegistererWithPrefix("coder_aibridged_", coderAPI.PrometheusRegistry)
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/aiproxy"
|
||||
"github.com/coder/coder/v2/enterprise/coderd"
|
||||
)
|
||||
|
||||
func newAIProxy(coderAPI *coderd.API) (*aiproxy.Server, error) {
|
||||
ctx := context.Background()
|
||||
coderAPI.Logger.Info(ctx, "starting in-memory AI proxy")
|
||||
|
||||
logger := coderAPI.Logger.Named("aiproxy")
|
||||
|
||||
// Load upstream proxy CA certificate if specified
|
||||
var upstreamCACert []byte
|
||||
if caPath := os.Getenv("CODER_AI_PROXY_UPSTREAM_CA"); caPath != "" {
|
||||
var err error
|
||||
upstreamCACert, err = os.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Info(ctx, "loaded upstream proxy CA certificate", "path", caPath)
|
||||
}
|
||||
|
||||
// TODO: Make these configurable via deployment values
|
||||
// For now, expect certs in current working directory
|
||||
srv, err := aiproxy.New(ctx, logger, aiproxy.Options{
|
||||
ListenAddr: ":8888",
|
||||
CertFile: filepath.Join(".", "mitm.crt"), // This should be set to mitm-cross-signed.crt if CODER_AI_PROXY_UPSTREAM is set.
|
||||
KeyFile: filepath.Join(".", "mitm.key"),
|
||||
CoderAccessURL: coderAPI.AccessURL.String(),
|
||||
UpstreamProxy: os.Getenv("CODER_AI_PROXY_UPSTREAM"),
|
||||
UpstreamProxyCACert: upstreamCACert,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return srv, nil
|
||||
}
|
||||
@@ -14,9 +14,11 @@ import (
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
agplcoderd "github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged"
|
||||
"github.com/coder/coder/v2/enterprise/aiproxy"
|
||||
"github.com/coder/coder/v2/enterprise/audit"
|
||||
"github.com/coder/coder/v2/enterprise/audit/backends"
|
||||
"github.com/coder/coder/v2/enterprise/coderd"
|
||||
@@ -27,8 +29,6 @@ import (
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
agplcoderd "github.com/coder/coder/v2/coderd"
|
||||
)
|
||||
|
||||
func (r *RootCmd) Server(_ func()) *serpent.Command {
|
||||
@@ -165,6 +165,18 @@ func (r *RootCmd) Server(_ func()) *serpent.Command {
|
||||
closers.Add(aibridgeDaemon)
|
||||
}
|
||||
|
||||
// In-memory AI proxy
|
||||
var aiProxyServer *aiproxy.Server
|
||||
// TODO: add options.DeploymentValues.AI.BridgeConfig.Enabled
|
||||
aiProxyServer, err = newAIProxy(api)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("create aiproxy: %w", err)
|
||||
}
|
||||
|
||||
closers.Add(aiProxyServer)
|
||||
|
||||
_ = aiProxyServer
|
||||
|
||||
return api.AGPL, closers, nil
|
||||
})
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
|
||||
"github.com/elazarl/goproxy"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// @Summary Get AI proxy CA certificate
|
||||
// @ID get-ai-proxy-ca-cert
|
||||
// @Security CoderSessionToken
|
||||
// @Produce application/x-pem-file
|
||||
// @Tags Enterprise
|
||||
// @Success 200 {file} binary "PEM-encoded CA certificate"
|
||||
// @Router /aiproxy/ca-cert [get]
|
||||
func (api *API) aiproxyCACert(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
ca := goproxy.GoproxyCa
|
||||
if len(ca.Certificate) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "AI proxy CA certificate not configured",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: ca.Certificate[0],
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "application/x-pem-file")
|
||||
rw.Header().Set("Content-Disposition", "attachment; filename=aiproxy-ca.crt")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := pem.Encode(rw, pemBlock); err != nil {
|
||||
api.Logger.Error(ctx, "failed to encode CA certificate")
|
||||
}
|
||||
}
|
||||
@@ -236,6 +236,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
r.Route("/aibridge", aibridgeHandler(api, apiKeyMiddleware))
|
||||
})
|
||||
|
||||
api.AGPL.APIHandler.Group(func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/aiproxy/ca-cert", api.aiproxyCACert)
|
||||
})
|
||||
|
||||
api.AGPL.APIHandler.Group(func(r chi.Router) {
|
||||
r.Get("/entitlements", api.serveEntitlements)
|
||||
// /regions overrides the AGPL /regions endpoint
|
||||
|
||||
@@ -482,6 +482,7 @@ require (
|
||||
github.com/coder/preview v1.0.4
|
||||
github.com/danieljoos/wincred v1.2.3
|
||||
github.com/dgraph-io/ristretto/v2 v2.3.0
|
||||
github.com/elazarl/goproxy v1.7.2
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/go-git/go-git/v5 v5.16.2
|
||||
github.com/icholy/replace v0.6.0
|
||||
@@ -560,3 +561,5 @@ require (
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
k8s.io/utils v0.0.0-20241210054802-24370beab758 // indirect
|
||||
)
|
||||
|
||||
replace github.com/coder/aibridge => github.com/coder/aibridge v0.3.1-0.20251205180200-daa2e2422f44
|
||||
|
||||
@@ -919,8 +919,8 @@ github.com/cncf/xds/go v0.0.0-20251022180443-0feb69152e9f h1:Y8xYupdHxryycyPlc9Y
|
||||
github.com/cncf/xds/go v0.0.0-20251022180443-0feb69152e9f/go.mod h1:HlzOvOjVBOfTGSRXRyY0OiCS/3J1akRGQQpRO/7zyF4=
|
||||
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225 h1:tRIViZ5JRmzdOEo5wUWngaGEFBG8OaE1o2GIHN5ujJ8=
|
||||
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225/go.mod h1:rNLVpYgEVeu1Zk29K64z6Od8RBP9DwqCu9OfCzh8MR4=
|
||||
github.com/coder/aibridge v0.3.0 h1:z5coky9A5uXOr+zjgmsynal8PVYBMmxE9u1vcIzs4t8=
|
||||
github.com/coder/aibridge v0.3.0/go.mod h1:ENnl6VhU8Qot5OuVYqs7V4vXII11oKBWgWKrgIJbRAs=
|
||||
github.com/coder/aibridge v0.3.1-0.20251205180200-daa2e2422f44 h1:5zfvl4AyQV6FIyL8PLwWC/fNborVpgaAoOw582LLmxk=
|
||||
github.com/coder/aibridge v0.3.1-0.20251205180200-daa2e2422f44/go.mod h1:ENnl6VhU8Qot5OuVYqs7V4vXII11oKBWgWKrgIJbRAs=
|
||||
github.com/coder/aisdk-go v0.0.9 h1:Vzo/k2qwVGLTR10ESDeP2Ecek1SdPfZlEjtTfMveiVo=
|
||||
github.com/coder/aisdk-go v0.0.9/go.mod h1:KF6/Vkono0FJJOtWtveh5j7yfNrSctVTpwgweYWSp5M=
|
||||
github.com/coder/boundary v1.0.1-0.20250925154134-55a44f2a7945 h1:hDUf02kTX8EGR3+5B+v5KdYvORs4YNfDPci0zCs+pC0=
|
||||
@@ -1041,6 +1041,8 @@ github.com/elastic/go-sysinfo v1.15.1 h1:zBmTnFEXxIQ3iwcQuk7MzaUotmKRp3OabbbWM8T
|
||||
github.com/elastic/go-sysinfo v1.15.1/go.mod h1:jPSuTgXG+dhhh0GKIyI2Cso+w5lPJ5PvVqKlL8LV/Hk=
|
||||
github.com/elastic/go-windows v1.0.0 h1:qLURgZFkkrYyTTkvYpsZIgf83AUsdIHfvlJaqaZ7aSY=
|
||||
github.com/elastic/go-windows v1.0.0/go.mod h1:TsU0Nrp7/y3+VwE82FoZF8gC/XFg/Elz6CcloAxnPgU=
|
||||
github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
|
||||
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
||||
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ=
|
||||
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
|
||||
github.com/emersion/go-smtp v0.21.2 h1:OLDgvZKuofk4em9fT5tFG5j4jE1/hXnX75UMvcrL4AA=
|
||||
|
||||
Reference in New Issue
Block a user