fix(codersdk): propagate HTTPClient to websocket.Dial for TLS relay (#22642)
## Problem

In multi-replica Coder deployments, the chat relay WebSocket between replicas fails with HTTP 401 (or TLS handshake errors). The subscriber replica cannot relay `message_part` events from the worker replica.

**Root cause:** `codersdk.Client.Dial()` does not pass `c.HTTPClient` to `websocket.DialOptions.HTTPClient`. The websocket library (`github.com/coder/websocket`) falls back to `http.DefaultClient`, which lacks the mesh TLS configuration needed for inter-replica communication. The relay code in `enterprise/coderd/chatd/chatd.go` correctly sets `sdkClient.HTTPClient = cfg.ReplicaHTTPClient` (which has mesh TLS certs), but that client was never used for the actual WebSocket handshake.

## Fix

One-line fix in `codersdk/client.go`: propagate `c.HTTPClient` to `opts.HTTPClient` when the caller hasn't already set one.

## Test

Added `TestChatStreamRelay/RelayWithTLSAndCookieAuth` which:

- Sets up two replicas with TLS certificates (simulating mesh TLS in production)
- Authenticates via cookies (simulating browser WebSocket behavior)
- Verifies `message_part` events relay across replicas over TLS

This test times out without the fix because the WebSocket handshake fails with `x509: certificate signed by unknown authority` (`http.DefaultClient` rejects self-signed certs).

## Related

Follow-up to #22635, which fixed the `redirectToAccessURL` middleware bypassing 307 redirects for relay requests. That fix changed the error from HTTP 200 to HTTP 401, exposing this deeper issue.
This commit is contained in:
@@ -368,6 +368,13 @@ func (c *Client) Dial(ctx context.Context, path string, opts *websocket.DialOpti
|
||||
if opts == nil {
|
||||
opts = &websocket.DialOptions{}
|
||||
}
|
||||
// Propagate the client's HTTP client to the websocket dialer
|
||||
// so that custom TLS configurations (e.g. mesh TLS between
|
||||
// replicas) are used for the handshake request. Without this,
|
||||
// the websocket library falls back to http.DefaultClient.
|
||||
if opts.HTTPClient == nil {
|
||||
opts.HTTPClient = c.HTTPClient
|
||||
}
|
||||
c.SessionTokenProvider.SetDialOption(opts)
|
||||
|
||||
conn, resp, err := websocket.Dial(ctx, u.String(), opts)
|
||||
|
||||
@@ -2,6 +2,9 @@ package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
@@ -166,6 +169,204 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
close(streamingChunks)
|
||||
})
|
||||
|
||||
// This test verifies that the relay WebSocket dial works when replicas
|
||||
// use TLS (mesh certificates) and the original request authenticates
|
||||
// via cookies only (as browsers do for WebSocket upgrades, since
|
||||
// browsers cannot set custom headers on WebSocket connections).
|
||||
//
|
||||
// The bug this guards against: codersdk.Client.Dial() did not propagate
|
||||
// c.HTTPClient to websocket.DialOptions.HTTPClient, so the websocket
|
||||
// library fell back to http.DefaultClient. With TLS between replicas,
|
||||
// http.DefaultClient lacks the required TLS config, causing a 401
|
||||
// (or TLS handshake failure) when the relay subscriber replica
|
||||
// dials the worker replica.
|
||||
t.Run("RelayWithTLSAndCookieAuth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
certificates := []tls.Certificate{testutil.GenerateTLSCertificate(t, "localhost")}
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
firstClient, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
TLSCertificates: certificates,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureHighAvailability: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
secondClient, _ := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
TLSCertificates: certificates,
|
||||
},
|
||||
DontAddLicense: true,
|
||||
DontAddFirstUser: true,
|
||||
})
|
||||
|
||||
// Authenticate the second client with a session cookie, simulating
|
||||
// browser WebSocket behavior. Browsers cannot set custom
|
||||
// headers (like Coder-Session-Token) on WebSocket upgrades;
|
||||
// they rely on cookies for authentication.
|
||||
//
|
||||
// Note that secondClient.SetSessionToken() is also called further
|
||||
// down so that regular API calls authenticate, so requests from
|
||||
// this client carry both the cookie and the session token header.
|
||||
//nolint:gocritic // Test uses owner client session token for cookie-based auth.
|
||||
sessionToken := firstClient.SessionToken()
|
||||
// Set session token via cookie on the second client's HTTP
|
||||
// jar so that HTTP requests authenticate, but the WebSocket
|
||||
// relay between replicas only gets cookie-based auth forwarded.
|
||||
cookieJar := secondClient.HTTPClient.Jar
|
||||
if cookieJar == nil {
|
||||
var jarErr error
|
||||
cookieJar, jarErr = cookiejar.New(nil)
|
||||
require.NoError(t, jarErr)
|
||||
secondClient.HTTPClient.Jar = cookieJar
|
||||
}
|
||||
cookieJar.SetCookies(secondClient.URL, []*http.Cookie{{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: sessionToken,
|
||||
}})
|
||||
|
||||
// Also set the session token header so regular API calls work
|
||||
// (e.g. Replicas(), CreateChatProvider()). The relay code
|
||||
// extracts credentials from the original request's headers.
|
||||
// Those include Cookie, but the Coder-Session-Token header
|
||||
// is not present on browser WebSocket requests.
|
||||
secondClient.SetSessionToken(sessionToken)
|
||||
|
||||
// Verify we have two replicas.
|
||||
replicas, err := secondClient.Replicas(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, replicas, 2)
|
||||
firstReplicaID := replicaIDForClientURL(t, firstClient.URL, replicas)
|
||||
secondReplicaID := replicaIDForClientURL(t, secondClient.URL, replicas)
|
||||
|
||||
streamingChunks := make(chan chattest.OpenAIChunk, 8)
|
||||
chatStreamStarted := make(chan struct{}, 1)
|
||||
openai := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if req.Stream {
|
||||
select {
|
||||
case chatStreamStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return chattest.OpenAIResponse{StreamingChunks: streamingChunks}
|
||||
}
|
||||
return chattest.OpenAINonStreamingResponse("ok")
|
||||
})
|
||||
|
||||
//nolint:gocritic // Test uses owner client to configure chat providers.
|
||||
provider, err := firstClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test",
|
||||
BaseURL: openai,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := firstClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "GPT-4",
|
||||
ContextLimit: &[]int64{1000}[0],
|
||||
CompressionThreshold: &[]int32{70}[0],
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a chat on the first replica.
|
||||
chat, err := firstClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "Test chat for TLS relay",
|
||||
}},
|
||||
ModelConfigID: &model.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ChatStatusPending, chat.Status)
|
||||
|
||||
var runningChat database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
current, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
if current.Status != database.ChatStatusRunning || !current.WorkerID.Valid {
|
||||
return false
|
||||
}
|
||||
runningChat = current
|
||||
return true
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
var localClient *codersdk.Client
|
||||
var relayClient *codersdk.Client
|
||||
switch runningChat.WorkerID.UUID {
|
||||
case firstReplicaID:
|
||||
localClient = firstClient
|
||||
relayClient = secondClient
|
||||
case secondReplicaID:
|
||||
localClient = secondClient
|
||||
relayClient = firstClient
|
||||
default:
|
||||
require.FailNowf(
|
||||
t,
|
||||
"worker replica was not recognized",
|
||||
"worker %s was not one of %s or %s",
|
||||
runningChat.WorkerID.UUID,
|
||||
firstReplicaID,
|
||||
secondReplicaID,
|
||||
)
|
||||
}
|
||||
|
||||
// Subscribe on the worker replica to start the stream.
|
||||
firstEvents, firstStream, err := localClient.StreamChat(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer firstStream.Close()
|
||||
|
||||
select {
|
||||
case <-chatStreamStarted:
|
||||
case <-ctx.Done():
|
||||
require.FailNowf(
|
||||
t,
|
||||
"timed out waiting for OpenAI stream request",
|
||||
"chat stream request did not start before context deadline: %v",
|
||||
ctx.Err(),
|
||||
)
|
||||
}
|
||||
|
||||
// Send a chunk on the worker.
|
||||
firstChunkText := "tls-relay-part-one"
|
||||
streamingChunks <- chattest.OpenAITextChunks(firstChunkText)[0]
|
||||
firstEvent := waitForStreamTextPart(ctx, t, firstEvents, firstChunkText)
|
||||
require.Equal(t, "assistant", firstEvent.MessagePart.Role)
|
||||
|
||||
// Subscribe from the non-worker replica. This triggers the
|
||||
// relay dial to the worker over TLS. With the bug, this
|
||||
// fails because Dial() does not propagate HTTPClient (with
|
||||
// the TLS config) to the websocket library.
|
||||
secondEvents, secondStream, err := relayClient.StreamChat(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer secondStream.Close()
|
||||
|
||||
// The relay should deliver the already-sent chunk as a
|
||||
// snapshot event.
|
||||
secondSnapshotEvent := waitForStreamTextPart(ctx, t, secondEvents, firstChunkText)
|
||||
require.Equal(t, "assistant", secondSnapshotEvent.MessagePart.Role)
|
||||
|
||||
// Send another chunk and verify it flows through the relay.
|
||||
secondChunkText := "tls-relay-part-two"
|
||||
streamingChunks <- chattest.OpenAITextChunks(secondChunkText)[0]
|
||||
waitForStreamTextPart(ctx, t, firstEvents, secondChunkText)
|
||||
waitForStreamTextPart(ctx, t, secondEvents, secondChunkText)
|
||||
|
||||
close(streamingChunks)
|
||||
})
|
||||
|
||||
t.Run("RelaySnapshotIncludesBufferedParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
Reference in New Issue
Block a user