Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ee23cff369 |
@@ -0,0 +1,102 @@
|
||||
package aibridged
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/aibridge"
|
||||
)
|
||||
|
||||
var _ aibridge.Provider = &AmpProvider{}
|
||||
|
||||
const (
|
||||
ProviderAmp = "amp"
|
||||
ampRouteMessages = "/amp/v1/messages" // How aibridge identifies this route
|
||||
)
|
||||
|
||||
type AmpConfig struct {
|
||||
BaseURL string
|
||||
Key string
|
||||
}
|
||||
|
||||
type AmpProvider struct {
|
||||
cfg AmpConfig
|
||||
}
|
||||
|
||||
func NewAmpProvider(cfg AmpConfig) *AmpProvider {
|
||||
if cfg.BaseURL == "" {
|
||||
cfg.BaseURL = "https://ampcode.com/api/provider/anthropic"
|
||||
}
|
||||
return &AmpProvider{cfg: cfg}
|
||||
}
|
||||
|
||||
func (p *AmpProvider) Name() string {
|
||||
return ProviderAmp
|
||||
}
|
||||
|
||||
func (p *AmpProvider) BaseURL() string {
|
||||
return p.cfg.BaseURL
|
||||
}
|
||||
|
||||
// BridgedRoutes returns routes that will be intercepted.
|
||||
func (p *AmpProvider) BridgedRoutes() []string {
|
||||
return []string{ampRouteMessages}
|
||||
}
|
||||
|
||||
// PassthroughRoutes returns routes that are proxied directly.
|
||||
// TODO(ssncferreira): should these include internal routes to amp?
|
||||
func (p *AmpProvider) PassthroughRoutes() []string {
|
||||
return []string{
|
||||
"/v1/models",
|
||||
"/v1/models/",
|
||||
"/v1/messages/count_tokens",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AmpProvider) AuthHeader() string {
|
||||
return "X-Api-Key"
|
||||
}
|
||||
|
||||
func (p *AmpProvider) InjectAuthHeader(h *http.Header) {
|
||||
if h == nil || p.cfg.Key == "" {
|
||||
return
|
||||
}
|
||||
h.Set(p.AuthHeader(), p.cfg.Key)
|
||||
}
|
||||
|
||||
// CreateInterceptor creates an interceptor for the request.
|
||||
// Reuses Anthropic's interceptor since Amp uses the same API format.
|
||||
func (p *AmpProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request) (aibridge.Interceptor, error) {
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
id := uuid.New()
|
||||
|
||||
switch r.URL.Path {
|
||||
case ampRouteMessages:
|
||||
var req aibridge.MessageNewParamsWrapper
|
||||
if err := json.Unmarshal(payload, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal request: %w", err)
|
||||
}
|
||||
|
||||
// Reuse Anthropic interceptors as Amp uses the same API format
|
||||
ampCfg := aibridge.AnthropicConfig{
|
||||
BaseURL: p.cfg.BaseURL,
|
||||
Key: p.cfg.Key,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
return aibridge.NewAnthropicMessagesStreamingInterception(id, &req, ampCfg, nil), nil
|
||||
}
|
||||
|
||||
return aibridge.NewAnthropicMessagesBlockingInterception(id, &req, ampCfg, nil), nil
|
||||
}
|
||||
|
||||
return nil, aibridge.UnknownRoute
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
package aiproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/AdguardTeam/gomitmproxy"
|
||||
"github.com/AdguardTeam/gomitmproxy/mitm"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
logger slog.Logger
|
||||
proxy *gomitmproxy.Proxy
|
||||
coderAccessURL string
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
ListenAddr string
|
||||
CertFile string
|
||||
KeyFile string
|
||||
CoderAccessURL string
|
||||
}
|
||||
|
||||
func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) {
|
||||
logger.Info(ctx, "initializing AI proxy server")
|
||||
|
||||
mitmConfig, err := createMitmConfig(opts.CertFile, opts.KeyFile)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to create TLS proxy config: %w", err)
|
||||
}
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", opts.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen-address invalid: %w", err)
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
coderAccessURL: opts.CoderAccessURL,
|
||||
}
|
||||
|
||||
logger.Info(ctx, "starting AI proxy", slog.F("addr", addr.String()))
|
||||
proxy := gomitmproxy.NewProxy(gomitmproxy.Config{
|
||||
ListenAddr: addr,
|
||||
OnRequest: srv.requestHandler,
|
||||
OnResponse: srv.responseHandler,
|
||||
MITMConfig: mitmConfig,
|
||||
})
|
||||
|
||||
if err := proxy.Start(); err != nil {
|
||||
return nil, xerrors.Errorf("failed to start proxy: %w", err)
|
||||
}
|
||||
|
||||
srv.proxy = proxy
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// createMitmConfig creates the TLS MITM configuration.
|
||||
func createMitmConfig(certFile, keyFile string) (*mitm.Config, error) {
|
||||
tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to read x509 keypair: %w", err)
|
||||
}
|
||||
privateKey := tlsCert.PrivateKey.(*rsa.PrivateKey)
|
||||
|
||||
x509c, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse cert: %w", err)
|
||||
}
|
||||
|
||||
mitmConfig, err := mitm.NewConfig(x509c, privateKey, nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to create MITM config: %w", err)
|
||||
}
|
||||
|
||||
mitmConfig.SetValidity(time.Hour * 24 * 365) // 1 year validity
|
||||
mitmConfig.SetOrganization("coder aiproxy")
|
||||
return mitmConfig, nil
|
||||
}
|
||||
|
||||
// providerFromHost maps the request host to the aibridge provider name.
|
||||
// All requests through the proxy for known AI providers are routed through aibridge.
|
||||
// Unknown hosts return empty string and are passed through directly without aibridge
|
||||
// (TODO (ssncferreira): is this correct?)
|
||||
func providerFromHost(host string) string {
|
||||
switch {
|
||||
case strings.Contains(host, "anthropic.com"):
|
||||
return "anthropic"
|
||||
case strings.Contains(host, "openai.com"):
|
||||
return "openai"
|
||||
case strings.Contains(host, "ampcode.com"):
|
||||
return "amp"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// mapPathForProvider converts the original request path to the aibridge path format.
|
||||
// Returns empty string if the path should not be routed through aibridge.
|
||||
func mapPathForProvider(provider, originalPath string) string {
|
||||
switch provider {
|
||||
case "amp":
|
||||
// Only intercept AI provider routes
|
||||
// Original: /api/provider/anthropic/v1/messages
|
||||
// aibridge expects: /amp/v1/messages
|
||||
const ampPrefix = "/api/provider/anthropic"
|
||||
if strings.HasPrefix(originalPath, ampPrefix) {
|
||||
return "/amp" + strings.TrimPrefix(originalPath, ampPrefix)
|
||||
}
|
||||
// Other Amp routes (e.g., /api/internal) should not go through aibridge
|
||||
// TODO(ssncferreira): should these routes be added to the provider's PassthroughRoutes?
|
||||
return ""
|
||||
case "anthropic":
|
||||
return "/anthropic" + originalPath
|
||||
case "openai":
|
||||
return "/openai" + originalPath
|
||||
default:
|
||||
return "/" + provider + originalPath
|
||||
}
|
||||
}
|
||||
|
||||
// requestHandler handles incoming requests.
|
||||
func (srv *Server) requestHandler(session *gomitmproxy.Session) (*http.Request, *http.Response) {
|
||||
req := session.Request()
|
||||
|
||||
if req.Method == http.MethodConnect {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
srv.logger.Info(srv.ctx, "received request",
|
||||
slog.F("url", req.URL.String()),
|
||||
slog.F("method", req.Method),
|
||||
slog.F("host", req.Host),
|
||||
)
|
||||
|
||||
// Check if this request is for a supported AI provider.
|
||||
// All requests through the proxy for known AI providers are routed through aibridge.
|
||||
// Unknown hosts return empty string and are passed through directly without aibridge.
|
||||
provider := providerFromHost(req.Host)
|
||||
if provider == "" {
|
||||
srv.logger.Info(srv.ctx, "unknown provider, passthrough",
|
||||
slog.F("host", req.Host),
|
||||
)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Map the original path to the aibridge path
|
||||
originalPath := req.URL.Path
|
||||
aibridgePath := mapPathForProvider(provider, originalPath)
|
||||
|
||||
// If path doesn't map to an aibridge route, pass through directly
|
||||
if aibridgePath == "" {
|
||||
srv.logger.Info(srv.ctx, "path not handled by aibridge, passthrough",
|
||||
slog.F("host", req.Host),
|
||||
slog.F("path", originalPath),
|
||||
)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Rewrite URL to point to aibridged
|
||||
newURL, err := url.Parse(srv.coderAccessURL)
|
||||
if err != nil {
|
||||
srv.logger.Error(srv.ctx, "failed to parse coder access URL", slog.Error(err))
|
||||
return req, nil
|
||||
}
|
||||
|
||||
newURL.Path = "/api/v2/aibridge" + aibridgePath
|
||||
newURL.RawQuery = req.URL.RawQuery
|
||||
|
||||
srv.logger.Info(srv.ctx, "rewriting request to aibridged",
|
||||
slog.F("original_url", req.URL.String()),
|
||||
slog.F("new_url", newURL.String()),
|
||||
slog.F("provider", provider),
|
||||
)
|
||||
|
||||
req.URL = newURL
|
||||
req.Host = newURL.Host
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// responseHandler handles responses from upstream.
|
||||
// For now, just passes through.
|
||||
func (srv *Server) responseHandler(session *gomitmproxy.Session) *http.Response {
|
||||
req := session.Request()
|
||||
srv.logger.Info(srv.ctx, "received response",
|
||||
slog.F("url", req.URL.String()),
|
||||
slog.F("method", req.Method),
|
||||
)
|
||||
|
||||
return session.Response()
|
||||
}
|
||||
|
||||
// Close shuts down the proxy server.
|
||||
func (srv *Server) Close() error {
|
||||
if srv.proxy != nil {
|
||||
srv.proxy.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -17,7 +18,7 @@ import (
|
||||
|
||||
func newAIBridgeDaemon(coderAPI *coderd.API) (*aibridged.Server, error) {
|
||||
ctx := context.Background()
|
||||
coderAPI.Logger.Debug(ctx, "starting in-memory aibridge daemon")
|
||||
coderAPI.Logger.Info(ctx, "starting in-memory aibridge daemon")
|
||||
|
||||
logger := coderAPI.Logger.Named("aibridged")
|
||||
|
||||
@@ -31,6 +32,11 @@ func newAIBridgeDaemon(coderAPI *coderd.API) (*aibridged.Server, error) {
|
||||
BaseURL: coderAPI.DeploymentValues.AI.BridgeConfig.Anthropic.BaseURL.String(),
|
||||
Key: coderAPI.DeploymentValues.AI.BridgeConfig.Anthropic.Key.String(),
|
||||
}, getBedrockConfig(coderAPI.DeploymentValues.AI.BridgeConfig.Bedrock)),
|
||||
// TODO(ssncferreira): add provider to aibridge project
|
||||
aibridged.NewAmpProvider(aibridged.AmpConfig{
|
||||
BaseURL: "https://ampcode.com/api/provider/anthropic",
|
||||
Key: os.Getenv("AMP_API_KEY"), // TODO: add via deployment values
|
||||
}),
|
||||
}
|
||||
|
||||
reg := prometheus.WrapRegistererWithPrefix("coder_aibridged_", coderAPI.PrometheusRegistry)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/aiproxy"
|
||||
"github.com/coder/coder/v2/enterprise/coderd"
|
||||
)
|
||||
|
||||
func newAIProxy(coderAPI *coderd.API) (*aiproxy.Server, error) {
|
||||
ctx := context.Background()
|
||||
coderAPI.Logger.Debug(ctx, "starting in-memory AI proxy")
|
||||
|
||||
logger := coderAPI.Logger.Named("aiproxy")
|
||||
|
||||
// TODO: Make these configurable via deployment values
|
||||
// For now, expect certs in current working directory
|
||||
srv, err := aiproxy.New(ctx, logger, aiproxy.Options{
|
||||
ListenAddr: "127.0.0.1:8888",
|
||||
CertFile: filepath.Join(".", "mitm.crt"),
|
||||
KeyFile: filepath.Join(".", "mitm.key"),
|
||||
CoderAccessURL: coderAPI.AccessURL.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return srv, nil
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"net/url"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/aiproxy"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/types/key"
|
||||
@@ -165,6 +166,18 @@ func (r *RootCmd) Server(_ func()) *serpent.Command {
|
||||
closers.Add(aibridgeDaemon)
|
||||
}
|
||||
|
||||
// In-memory AI proxy
|
||||
var aiProxyServer *aiproxy.Server
|
||||
// TODO: add options.DeploymentValues.AI.BridgeConfig.Enabled
|
||||
aiProxyServer, err = newAIProxy(api)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("create aiproxy: %w", err)
|
||||
}
|
||||
|
||||
closers.Add(aiProxyServer)
|
||||
|
||||
_ = aiProxyServer
|
||||
|
||||
return api.AGPL, closers, nil
|
||||
})
|
||||
|
||||
|
||||
@@ -473,6 +473,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.1
|
||||
github.com/anthropics/anthropic-sdk-go v1.19.0
|
||||
github.com/brianvoe/gofakeit/v7 v7.9.0
|
||||
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225
|
||||
@@ -496,6 +497,7 @@ require (
|
||||
cloud.google.com/go/monitoring v1.24.2 // indirect
|
||||
cloud.google.com/go/storage v1.55.0 // indirect
|
||||
git.sr.ht/~jackmordaunt/go-toast v1.1.2 // indirect
|
||||
github.com/AdguardTeam/golibs v0.4.0 // indirect
|
||||
github.com/DataDog/datadog-agent/comp/core/tagger/origindetection v0.64.2 // indirect
|
||||
github.com/DataDog/datadog-agent/pkg/version v0.64.2 // indirect
|
||||
github.com/DataDog/dd-trace-go/v2 v2.0.0 // indirect
|
||||
|
||||
@@ -624,6 +624,10 @@ gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zum
|
||||
git.sr.ht/~jackmordaunt/go-toast v1.1.2 h1:/yrfI55LRt1M7H1vkaw+NaH1+L1CDxrqDltwm5euVuE=
|
||||
git.sr.ht/~jackmordaunt/go-toast v1.1.2/go.mod h1:jA4OqHKTQ4AFBdwrSnwnskUIIS3HYzlJSgdzCKqfavo=
|
||||
git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc=
|
||||
github.com/AdguardTeam/golibs v0.4.0 h1:4VX6LoOqFe9p9Gf55BeD8BvJD6M6RDYmgEiHrENE9KU=
|
||||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.1 h1:p9gr8Er1TYvf+7ic81Ax1sZ62UNCsMTZNbm7tC59S9o=
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.1/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69 h1:+tu3HOoMXB7RXEINRVIpxJCT+KdYiI7LAEAUrOw3dIU=
|
||||
@@ -1636,6 +1640,7 @@ github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0
|
||||
github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/niklasfasching/go-org v1.9.1 h1:/3s4uTPOF06pImGa2Yvlp24yKXZoTYM+nsIlMzfpg/0=
|
||||
github.com/niklasfasching/go-org v1.9.1/go.mod h1:ZAGFFkWvUQcpazmi/8nHqwvARpr1xpb+Es67oUGX/48=
|
||||
github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 h1:G7ERwszslrBzRxj//JalHPu/3yz+De2J+4aLtSRlHiY=
|
||||
@@ -2763,6 +2768,7 @@ gopkg.in/DataDog/dd-trace-go.v1 v1.74.0 h1:wScziU1ff6Bnyr8MEyxATPSLJdnLxKz3p6RsA
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.74.0/go.mod h1:ReNBsNfnsjVC7GsCe80zRcykL/n+nxvsNrg3NbjuleM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/cheggaaa/pb.v1 v1.0.27/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
|
||||
|
||||
Reference in New Issue
Block a user