feat(coderd/x/chatd/chatdebug): add recorder, transport, and redaction

Change-Id: Ibbc67a85ba78201c0778ccb5c8675b15e90b1cdf
Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
Thomas Kosiewski
2026-04-08 22:31:29 +00:00
parent a47f4fec56
commit f7e0d8de31
10 changed files with 2213 additions and 183 deletions
@@ -387,9 +387,9 @@ func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
// The handle starts with zero status; simulate what the wrapper does
// when consumer breaks after finish.
h.finish(ctx2, StatusCompleted, nil, nil, nil, nil)
h.mu.Lock()
// finish uses sync.Once, so after it returns the status is safely
// set and readable in this single-goroutine test.
require.Equal(t, StatusCompleted, h.status)
h.mu.Unlock()
}
// TestDebugModel_StreamInterruptedBeforeFinish verifies that when a consumer
+277
View File
@@ -0,0 +1,277 @@
package chatdebug
import (
"context"
"sync"
"sync/atomic"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"cdr.dev/slog/v3"
)
// RecorderOptions identifies the chat/model context for debug recording.
type RecorderOptions struct {
ChatID uuid.UUID
OwnerID uuid.UUID
Provider string
Model string
}
// WrapModel returns model unchanged when debug recording is disabled, or a
// debug wrapper when a service is available.
func WrapModel(
model fantasy.LanguageModel,
svc *Service,
opts RecorderOptions,
) fantasy.LanguageModel {
if model == nil {
panic("chatdebug: nil LanguageModel")
}
if svc == nil {
return model
}
return &debugModel{inner: model, svc: svc, opts: opts}
}
type attemptSink struct {
mu sync.Mutex
attempts []Attempt
attemptCounter atomic.Int32
}
func (s *attemptSink) nextAttemptNumber() int {
if s == nil {
panic("chatdebug: nil attemptSink")
}
return int(s.attemptCounter.Add(1))
}
func (s *attemptSink) record(a Attempt) {
s.mu.Lock()
defer s.mu.Unlock()
s.attempts = append(s.attempts, a)
}
func (s *attemptSink) snapshot() []Attempt {
s.mu.Lock()
defer s.mu.Unlock()
attempts := make([]Attempt, len(s.attempts))
copy(attempts, s.attempts)
return attempts
}
type attemptSinkKey struct{}
func withAttemptSink(ctx context.Context, sink *attemptSink) context.Context {
if sink == nil {
panic("chatdebug: nil attemptSink")
}
return context.WithValue(ctx, attemptSinkKey{}, sink)
}
func attemptSinkFromContext(ctx context.Context) *attemptSink {
sink, _ := ctx.Value(attemptSinkKey{}).(*attemptSink)
return sink
}
var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32
func nextStepNumber(runID uuid.UUID) int32 {
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
counter, ok := val.(*atomic.Int32)
if !ok {
panic("chatdebug: invalid step counter type")
}
return counter.Add(1)
}
// CleanupStepCounter removes per-run step counter and reference count
// state. This is used by tests and later stacked branches that have a
// real run lifecycle.
func CleanupStepCounter(runID uuid.UUID) {
stepCounters.Delete(runID)
runRefCounts.Delete(runID)
}
const stepFinalizeTimeout = 5 * time.Second
func stepFinalizeContext(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
panic("chatdebug: nil context")
}
return context.WithTimeout(context.WithoutCancel(ctx), stepFinalizeTimeout)
}
func syncStepCounter(runID uuid.UUID, stepNumber int32) {
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
counter, ok := val.(*atomic.Int32)
if !ok {
panic("chatdebug: invalid step counter type")
}
for {
current := counter.Load()
if current >= stepNumber {
return
}
if counter.CompareAndSwap(current, stepNumber) {
return
}
}
}
type stepHandle struct {
stepCtx *StepContext
sink *attemptSink
svc *Service
opts RecorderOptions
once sync.Once
mu sync.Mutex
status Status
response any
usage any
err any
metadata any
}
// beginStep validates preconditions, creates a debug step, and returns a
// handle plus an enriched context carrying StepContext and attemptSink.
// Returns (nil, original ctx) when debug recording should be skipped.
func beginStep(
ctx context.Context,
svc *Service,
opts RecorderOptions,
op Operation,
normalizedReq any,
) (*stepHandle, context.Context) {
if svc == nil {
return nil, ctx
}
rc, ok := RunFromContext(ctx)
if !ok || rc.RunID == uuid.Nil {
return nil, ctx
}
chatID := opts.ChatID
if chatID == uuid.Nil {
chatID = rc.ChatID
}
if !svc.IsEnabled(ctx, chatID, opts.OwnerID) {
return nil, ctx
}
holder, reuseStep := reuseHolderFromContext(ctx)
if reuseStep {
holder.mu.Lock()
defer holder.mu.Unlock()
// Only reuse the cached handle if it belongs to the same run.
// A different RunContext means a new logical run, so we must
// create a fresh step to avoid cross-run attribution.
if holder.handle != nil && holder.handle.stepCtx.RunID == rc.RunID {
enriched := ContextWithStep(ctx, holder.handle.stepCtx)
enriched = withAttemptSink(enriched, holder.handle.sink)
return holder.handle, enriched
}
}
stepNum := nextStepNumber(rc.RunID)
step, err := svc.CreateStep(ctx, CreateStepParams{
RunID: rc.RunID,
ChatID: chatID,
StepNumber: stepNum,
Operation: op,
Status: StatusInProgress,
HistoryTipMessageID: rc.HistoryTipMessageID,
NormalizedRequest: normalizedReq,
})
if err != nil {
svc.log.Warn(ctx, "failed to create chat debug step",
slog.Error(err),
slog.F("chat_id", chatID),
slog.F("run_id", rc.RunID),
slog.F("operation", op),
)
return nil, ctx
}
syncStepCounter(rc.RunID, step.StepNumber)
actualStepNumber := step.StepNumber
if actualStepNumber == 0 {
actualStepNumber = stepNum
}
sc := &StepContext{
StepID: step.ID,
RunID: rc.RunID,
ChatID: chatID,
StepNumber: actualStepNumber,
Operation: op,
HistoryTipMessageID: rc.HistoryTipMessageID,
}
handle := &stepHandle{stepCtx: sc, sink: &attemptSink{}, svc: svc, opts: opts}
enriched := ContextWithStep(ctx, handle.stepCtx)
enriched = withAttemptSink(enriched, handle.sink)
if reuseStep {
holder.handle = handle
}
return handle, enriched
}
// finish updates the debug step with final status and data.
// sync.Once prevents data races when concurrent callers (e.g.
// retried stream wrappers sharing a reuse handle) both attempt
// to finalize the same step. Only the first finish call takes
// effect.
func (h *stepHandle) finish(
ctx context.Context,
status Status,
response any,
usage any,
errPayload any,
metadata any,
) {
if h == nil || h.stepCtx == nil {
return
}
h.once.Do(func() {
h.mu.Lock()
h.status = status
h.response = response
h.usage = usage
h.err = errPayload
h.metadata = metadata
h.mu.Unlock()
if h.svc == nil {
return
}
updateCtx, cancel := stepFinalizeContext(ctx)
defer cancel()
if _, updateErr := h.svc.UpdateStep(updateCtx, UpdateStepParams{
ID: h.stepCtx.StepID,
ChatID: h.stepCtx.ChatID,
Status: status,
NormalizedResponse: response,
Usage: usage,
Attempts: h.sink.snapshot(),
Error: errPayload,
Metadata: metadata,
FinishedAt: time.Now(),
}); updateErr != nil {
h.svc.log.Warn(updateCtx, "failed to finalize chat debug step",
slog.Error(updateErr),
slog.F("step_id", h.stepCtx.StepID),
slog.F("chat_id", h.stepCtx.ChatID),
slog.F("status", status),
)
}
})
}
+174
View File
@@ -0,0 +1,174 @@
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
import (
"context"
"sort"
"sync"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
)
func TestAttemptSink_ThreadSafe(t *testing.T) {
t.Parallel()
const n = 256
sink := &attemptSink{}
var wg sync.WaitGroup
wg.Add(n)
for i := range n {
go func() {
defer wg.Done()
sink.record(Attempt{Number: i + 1, ResponseStatus: 200 + i})
}()
}
wg.Wait()
attempts := sink.snapshot()
require.Len(t, attempts, n)
numbers := make([]int, 0, n)
statuses := make([]int, 0, n)
for _, attempt := range attempts {
numbers = append(numbers, attempt.Number)
statuses = append(statuses, attempt.ResponseStatus)
}
sort.Ints(numbers)
sort.Ints(statuses)
for i := range n {
require.Equal(t, i+1, numbers[i])
require.Equal(t, 200+i, statuses[i])
}
}
func TestAttemptSinkContext(t *testing.T) {
t.Parallel()
ctx := context.Background()
require.Nil(t, attemptSinkFromContext(ctx))
sink := &attemptSink{}
ctx = withAttemptSink(ctx, sink)
require.Same(t, sink, attemptSinkFromContext(ctx))
}
func TestWrapModel_NilModel(t *testing.T) {
t.Parallel()
require.Panics(t, func() {
WrapModel(nil, &Service{}, RecorderOptions{})
})
}
func TestWrapModel_NilService(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
wrapped := WrapModel(model, nil, RecorderOptions{})
require.Same(t, model, wrapped)
}
func TestNextStepNumber_Concurrent(t *testing.T) {
t.Parallel()
const n = 256
runID := uuid.New()
results := make([]int, n)
var wg sync.WaitGroup
wg.Add(n)
for i := range n {
go func() {
defer wg.Done()
results[i] = int(nextStepNumber(runID))
}()
}
wg.Wait()
sort.Ints(results)
for i := range n {
require.Equal(t, i+1, results[i])
}
}
func TestStepFinalizeContext_StripsCancellation(t *testing.T) {
t.Parallel()
baseCtx, cancelBase := context.WithCancel(context.Background())
cancelBase()
require.ErrorIs(t, baseCtx.Err(), context.Canceled)
finalizeCtx, cancelFinalize := stepFinalizeContext(baseCtx)
defer cancelFinalize()
require.NoError(t, finalizeCtx.Err())
_, hasDeadline := finalizeCtx.Deadline()
require.True(t, hasDeadline)
}
func TestSyncStepCounter_AdvancesCounter(t *testing.T) {
t.Parallel()
runID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
syncStepCounter(runID, 7)
require.Equal(t, int32(8), nextStepNumber(runID))
}
func TestStepHandleFinish_NilHandle(t *testing.T) {
t.Parallel()
var handle *stepHandle
handle.finish(context.Background(), StatusCompleted, nil, nil, nil, nil)
}
func TestBeginStep_NilService(t *testing.T) {
t.Parallel()
ctx := context.Background()
handle, enriched := beginStep(ctx, nil, RecorderOptions{}, OperationGenerate, nil)
require.Nil(t, handle)
require.Nil(t, attemptSinkFromContext(enriched))
_, ok := StepFromContext(enriched)
require.False(t, ok)
}
func TestBeginStep_FallsBackToRunChatID(t *testing.T) {
t.Parallel()
runID := uuid.New()
runChatID := uuid.New()
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID})
handle, enriched := beginStep(ctx, &Service{}, RecorderOptions{}, OperationGenerate, nil)
require.NotNil(t, handle)
require.Equal(t, runChatID, handle.stepCtx.ChatID)
stepCtx, ok := StepFromContext(enriched)
require.True(t, ok)
require.Equal(t, runChatID, stepCtx.ChatID)
}
func TestWrapModel_ReturnsDebugModel(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
wrapped := WrapModel(model, &Service{}, RecorderOptions{})
require.NotSame(t, model, wrapped)
require.IsType(t, &debugModel{}, wrapped)
require.Implements(t, (*fantasy.LanguageModel)(nil), wrapped)
require.Equal(t, model.Provider(), wrapped.Provider())
require.Equal(t, model.Model(), wrapped.Model())
}
+227
View File
@@ -0,0 +1,227 @@
package chatdebug
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"golang.org/x/xerrors"
)
// RedactedValue replaces sensitive values in debug payloads.
const RedactedValue = "[REDACTED]"
var sensitiveHeaderNames = map[string]struct{}{
"authorization": {},
"x-api-key": {},
"api-key": {},
"proxy-authorization": {},
"cookie": {},
"set-cookie": {},
}
// sensitiveJSONKeyFragments triggers redaction for JSON keys containing
// these substrings. Notably, "token" is intentionally absent because it
// false-positively redacts LLM token-usage fields (input_tokens,
// output_tokens, prompt_tokens, completion_tokens, reasoning_tokens,
// cache_creation_input_tokens, cache_read_input_tokens, etc.). Auth-
// related token fields are caught by the exact-match set below.
var sensitiveJSONKeyFragments = []string{
"secret",
"password",
"authorization",
"credential",
}
// sensitiveJSONKeyExact matches auth-related token/key field names
// without false-positiving on LLM usage counters. Includes both
// snake_case originals and their camelCase-lowered equivalents
// (e.g. "accessToken" → "accesstoken") so that providers using
// either convention are caught.
var sensitiveJSONKeyExact = map[string]struct{}{
"token": {},
"access_token": {},
"accesstoken": {},
"refresh_token": {},
"refreshtoken": {},
"id_token": {},
"idtoken": {},
"api_token": {},
"apitoken": {},
"api_key": {},
"apikey": {},
"api-key": {},
"x-api-key": {},
"auth_token": {},
"authtoken": {},
"bearer_token": {},
"bearertoken": {},
"session_token": {},
"sessiontoken": {},
"security_token": {},
"securitytoken": {},
"private_key": {},
"privatekey": {},
"signing_key": {},
"signingkey": {},
"secret_key": {},
"secretkey": {},
}
// RedactHeaders returns a flattened copy of h with sensitive values redacted.
func RedactHeaders(h http.Header) map[string]string {
if h == nil {
return nil
}
redacted := make(map[string]string, len(h))
for name, values := range h {
if isSensitiveName(name) {
redacted[name] = RedactedValue
continue
}
redacted[name] = strings.Join(values, ", ")
}
return redacted
}
// RedactJSONSecrets redacts sensitive JSON values by key name. When
// the input is not valid JSON (truncated body, HTML error page, etc.)
// the raw bytes are replaced entirely with a diagnostic placeholder
// to avoid leaking credentials from malformed payloads.
func RedactJSONSecrets(data []byte) []byte {
if len(data) == 0 {
return data
}
decoder := json.NewDecoder(bytes.NewReader(data))
decoder.UseNumber()
var value any
if err := decoder.Decode(&value); err != nil {
// Cannot parse: replace entirely to prevent credential leaks
// from non-JSON error responses (HTML pages, partial bodies).
return []byte(`{"error":"chatdebug: body is not valid JSON, redacted for safety"}`)
}
if err := consumeJSONEOF(decoder); err != nil {
return []byte(`{"error":"chatdebug: body contains extra JSON values, redacted for safety"}`)
}
redacted, changed := redactJSONValue(value)
if !changed {
return data
}
encoded, err := json.Marshal(redacted)
if err != nil {
return data
}
return encoded
}
func consumeJSONEOF(decoder *json.Decoder) error {
var extra any
err := decoder.Decode(&extra)
if errors.Is(err, io.EOF) {
return nil
}
if err == nil {
return xerrors.New("chatdebug: extra JSON values")
}
return err
}
var safeRateLimitHeaderNames = map[string]struct{}{
"anthropic-ratelimit-requests-limit": {},
"anthropic-ratelimit-requests-remaining": {},
"anthropic-ratelimit-requests-reset": {},
"anthropic-ratelimit-tokens-limit": {},
"anthropic-ratelimit-tokens-remaining": {},
"anthropic-ratelimit-tokens-reset": {},
"x-ratelimit-limit-requests": {},
"x-ratelimit-limit-tokens": {},
"x-ratelimit-remaining-requests": {},
"x-ratelimit-remaining-tokens": {},
"x-ratelimit-reset-requests": {},
"x-ratelimit-reset-tokens": {},
}
// isSensitiveName reports whether a name (header or query parameter)
// looks like a credential-carrying key. Exact-match headers are
// checked first, then the rate-limit allowlist, then substring
// patterns for API keys and auth tokens.
func isSensitiveName(name string) bool {
lowerName := strings.ToLower(name)
if _, ok := sensitiveHeaderNames[lowerName]; ok {
return true
}
if _, ok := safeRateLimitHeaderNames[lowerName]; ok {
return false
}
if strings.Contains(lowerName, "api-key") ||
strings.Contains(lowerName, "api_key") ||
strings.Contains(lowerName, "apikey") {
return true
}
// Catch any header containing "token" (e.g. Token, X-Token,
// X-Auth-Token). Safe rate-limit headers like
// x-ratelimit-remaining-tokens are already allowlisted above
// and will not reach this point.
if strings.Contains(lowerName, "token") {
return true
}
return strings.Contains(lowerName, "secret") ||
strings.Contains(lowerName, "bearer")
}
func isSensitiveJSONKey(key string) bool {
lowerKey := strings.ToLower(key)
if _, ok := sensitiveJSONKeyExact[lowerKey]; ok {
return true
}
for _, fragment := range sensitiveJSONKeyFragments {
if strings.Contains(lowerKey, fragment) {
return true
}
}
return false
}
func redactJSONValue(value any) (any, bool) {
switch typed := value.(type) {
case map[string]any:
changed := false
for key, child := range typed {
if isSensitiveJSONKey(key) {
if current, ok := child.(string); ok && current == RedactedValue {
continue
}
typed[key] = RedactedValue
changed = true
continue
}
redactedChild, childChanged := redactJSONValue(child)
if childChanged {
typed[key] = redactedChild
changed = true
}
}
return typed, changed
case []any:
changed := false
for i, child := range typed {
redactedChild, childChanged := redactJSONValue(child)
if childChanged {
typed[i] = redactedChild
changed = true
}
}
return typed, changed
default:
return value, false
}
}
+277
View File
@@ -0,0 +1,277 @@
package chatdebug_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
)
func TestRedactHeaders(t *testing.T) {
t.Parallel()
t.Run("nil input", func(t *testing.T) {
t.Parallel()
require.Nil(t, chatdebug.RedactHeaders(nil))
})
t.Run("empty header", func(t *testing.T) {
t.Parallel()
redacted := chatdebug.RedactHeaders(http.Header{})
require.NotNil(t, redacted)
require.Empty(t, redacted)
})
t.Run("authorization redacted and others preserved", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"Authorization": {"Bearer secret-token"},
"Accept": {"application/json"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
require.Equal(t, "application/json", redacted["Accept"])
})
t.Run("multi-value headers are flattened", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"Accept": {"application/json", "text/plain"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, "application/json, text/plain", redacted["Accept"])
})
t.Run("header name matching is case insensitive", func(t *testing.T) {
t.Parallel()
lowerAuthorization := "authorization"
upperAuthorization := "AUTHORIZATION"
headers := http.Header{
lowerAuthorization: {"lower"},
upperAuthorization: {"upper"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted[lowerAuthorization])
require.Equal(t, chatdebug.RedactedValue, redacted[upperAuthorization])
})
t.Run("token and secret substrings are redacted", func(t *testing.T) {
t.Parallel()
traceHeader := "X-Trace-ID"
headers := http.Header{
"X-Auth-Token": {"abc"},
"X-Custom-Secret": {"def"},
"X-Bearer": {"ghi"},
traceHeader: {"trace"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["X-Auth-Token"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Bearer"])
require.Equal(t, "trace", redacted[traceHeader])
})
t.Run("known safe rate limit headers containing token are not redacted", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"Anthropic-Ratelimit-Tokens-Limit": {"1000000"},
"Anthropic-Ratelimit-Tokens-Remaining": {"999000"},
"Anthropic-Ratelimit-Tokens-Reset": {"2026-03-31T08:55:26Z"},
"X-RateLimit-Limit-Tokens": {"120000"},
"X-RateLimit-Remaining-Tokens": {"119500"},
"X-RateLimit-Reset-Tokens": {"12ms"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, "1000000", redacted["Anthropic-Ratelimit-Tokens-Limit"])
require.Equal(t, "999000", redacted["Anthropic-Ratelimit-Tokens-Remaining"])
require.Equal(t, "2026-03-31T08:55:26Z", redacted["Anthropic-Ratelimit-Tokens-Reset"])
require.Equal(t, "120000", redacted["X-RateLimit-Limit-Tokens"])
require.Equal(t, "119500", redacted["X-RateLimit-Remaining-Tokens"])
require.Equal(t, "12ms", redacted["X-RateLimit-Reset-Tokens"])
})
t.Run("non-standard headers with api-key pattern are redacted", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"X-Custom-Api-Key": {"secret-key"},
"X-Custom-Secret": {"secret-val"},
"X-Custom-Session-Token": {"session-id"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Api-Key"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Session-Token"])
})
t.Run("rate limit headers with token in name are preserved", func(t *testing.T) {
t.Parallel()
// Rate-limit headers containing "token" should NOT be redacted
// because they carry usage/limit counts, not credentials.
headers := http.Header{
"X-Ratelimit-Limit-Tokens": {"1000000"},
"X-Ratelimit-Remaining-Tokens": {"999000"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, "1000000", redacted["X-Ratelimit-Limit-Tokens"])
require.Equal(t, "999000", redacted["X-Ratelimit-Remaining-Tokens"])
})
t.Run("original header is not modified", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"Authorization": {"Bearer keep-me"},
"X-Test": {"value"},
}
redacted := chatdebug.RedactHeaders(headers)
redacted["X-Test"] = "changed"
require.Equal(t, []string{"Bearer keep-me"}, headers["Authorization"])
require.Equal(t, []string{"value"}, headers["X-Test"])
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
})
t.Run("api-key header variants are redacted", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"X-Goog-Api-Key": {"secret"},
"X-Api_Key": {"other-secret"},
"X-Safe": {"ok"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["X-Goog-Api-Key"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Api_Key"])
require.Equal(t, "ok", redacted["X-Safe"])
})
t.Run("plain token headers are redacted", func(t *testing.T) {
t.Parallel()
// Headers like "Token" or "X-Token" should be redacted
// even without auth/session/access qualifiers.
headers := http.Header{
"Token": {"my-secret-token"},
"X-Token": {"another-secret"},
"X-Safe": {"ok"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["Token"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Token"])
require.Equal(t, "ok", redacted["X-Safe"])
})
}
func TestRedactJSONSecrets(t *testing.T) {
t.Parallel()
t.Run("redacts top level secret fields", func(t *testing.T) {
t.Parallel()
input := []byte(`{"api_key":"abc","token":"def","password":"ghi","safe":"ok"}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"api_key":"[REDACTED]","token":"[REDACTED]","password":"[REDACTED]","safe":"ok"}`, string(redacted))
})
t.Run("redacts security_token exact key", func(t *testing.T) {
t.Parallel()
input := []byte(`{"security_token":"s3cret","securityToken":"tok","safe":"ok"}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"security_token":"[REDACTED]","securityToken":"[REDACTED]","safe":"ok"}`, string(redacted))
})
t.Run("preserves LLM token usage fields", func(t *testing.T) {
t.Parallel()
input := []byte(`{"input_tokens":100,"output_tokens":50,"prompt_tokens":80,"completion_tokens":20,"reasoning_tokens":10,"cache_creation_input_tokens":5,"cache_read_input_tokens":3,"total_tokens":150,"max_tokens":4096,"max_output_tokens":2048}`)
redacted := chatdebug.RedactJSONSecrets(input)
// All usage/limit fields should be preserved, not redacted.
require.Equal(t, input, redacted)
})
t.Run("redacts nested objects", func(t *testing.T) {
t.Parallel()
input := []byte(`{"outer":{"nested_secret":"abc","safe":1},"keep":true}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"outer":{"nested_secret":"[REDACTED]","safe":1},"keep":true}`, string(redacted))
})
t.Run("redacts arrays of objects", func(t *testing.T) {
t.Parallel()
input := []byte(`[{"token":"abc"},{"value":1,"credentials":{"access_key":"def"}}]`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `[{"token":"[REDACTED]"},{"value":1,"credentials":"[REDACTED]"}]`, string(redacted))
})
t.Run("concatenated JSON is replaced with diagnostic", func(t *testing.T) {
t.Parallel()
input := []byte(`{"token":"abc"}{"safe":"ok"}`)
result := chatdebug.RedactJSONSecrets(input)
require.Contains(t, string(result), "extra JSON values")
})
t.Run("non JSON input is replaced with diagnostic", func(t *testing.T) {
t.Parallel()
input := []byte("not json")
result := chatdebug.RedactJSONSecrets(input)
require.Contains(t, string(result), "not valid JSON")
})
t.Run("empty input is unchanged", func(t *testing.T) {
t.Parallel()
input := []byte{}
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
})
t.Run("JSON without sensitive keys is unchanged", func(t *testing.T) {
t.Parallel()
input := []byte(`{"safe":"ok","nested":{"value":1}}`)
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
})
t.Run("key matching is case insensitive", func(t *testing.T) {
t.Parallel()
input := []byte(`{"API_KEY":"abc","Token":"def","PASSWORD":"ghi"}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"API_KEY":"[REDACTED]","Token":"[REDACTED]","PASSWORD":"[REDACTED]"}`, string(redacted))
})
t.Run("camelCase token field names are redacted", func(t *testing.T) {
t.Parallel()
// Providers may use camelCase (e.g. accessToken, refreshToken).
// These should be redacted even though they don't match the
// snake_case originals exactly.
input := []byte(`{"accessToken":"abc","refreshToken":"def","authToken":"ghi","input_tokens":100,"output_tokens":50}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"accessToken":"[REDACTED]","refreshToken":"[REDACTED]","authToken":"[REDACTED]","input_tokens":100,"output_tokens":50}`, string(redacted))
})
}
@@ -0,0 +1,74 @@
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/testutil"
)
func TestBeginStepReuseStep(t *testing.T) {
t.Parallel()
t.Run("reuses handle under ReuseStep", func(t *testing.T) {
t.Parallel()
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
svc := NewService(nil, testutil.Logger(t), nil)
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
ctx = ReuseStep(ctx)
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
firstHandle, firstEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
secondHandle, secondEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
require.NotNil(t, firstHandle)
require.Same(t, firstHandle, secondHandle)
require.Same(t, firstHandle.stepCtx, secondHandle.stepCtx)
require.Same(t, firstHandle.sink, secondHandle.sink)
require.Equal(t, runID, firstHandle.stepCtx.RunID)
require.Equal(t, chatID, firstHandle.stepCtx.ChatID)
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
require.Equal(t, OperationStream, firstHandle.stepCtx.Operation)
require.NotEqual(t, uuid.Nil, firstHandle.stepCtx.StepID)
firstStepCtx, ok := StepFromContext(firstEnriched)
require.True(t, ok)
secondStepCtx, ok := StepFromContext(secondEnriched)
require.True(t, ok)
require.Same(t, firstStepCtx, secondStepCtx)
require.Same(t, firstHandle.stepCtx, firstStepCtx)
require.Same(t, attemptSinkFromContext(firstEnriched), attemptSinkFromContext(secondEnriched))
})
t.Run("creates new handles without ReuseStep", func(t *testing.T) {
t.Parallel()
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
svc := NewService(nil, testutil.Logger(t), nil)
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
firstHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
secondHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
require.NotNil(t, firstHandle)
require.NotNil(t, secondHandle)
require.NotSame(t, firstHandle, secondHandle)
require.NotSame(t, firstHandle.sink, secondHandle.sink)
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
require.Equal(t, int32(2), secondHandle.stepCtx.StepNumber)
require.NotEqual(t, firstHandle.stepCtx.StepID, secondHandle.stepCtx.StepID)
})
}
+63 -148
View File
@@ -6,6 +6,7 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"unicode/utf8"
"github.com/google/uuid"
@@ -15,43 +16,76 @@ import (
"github.com/coder/coder/v2/coderd/database/pubsub"
)
// This branch-02 compatibility shim forward-declares recorder, service,
// and summary symbols that land in later stacked branches. Delete this
// file once recorder.go, service.go, and summary.go are available here.
// RecorderOptions identifies the chat/model context for debug recording.
type RecorderOptions struct {
ChatID uuid.UUID
OwnerID uuid.UUID
Provider string
Model string
}
// This compatibility shim forward-declares service and summary symbols
// that land in later stacked branches. Delete this file once service.go
// and summary.go are available here.
// Service is a placeholder for the later chat debug persistence service.
type Service struct{}
// NewService constructs the branch-02 placeholder chat debug service.
func NewService(_ database.Store, _ slog.Logger, _ pubsub.Pubsub) *Service {
return &Service{}
type Service struct {
log slog.Logger
}
type attemptSink struct{}
type attemptSinkKey struct{}
func withAttemptSink(ctx context.Context, sink *attemptSink) context.Context {
if sink == nil {
panic("chatdebug: nil attemptSink")
}
return context.WithValue(ctx, attemptSinkKey{}, sink)
// CreateStepParams identifies the data recorded when a debug step starts.
type CreateStepParams struct {
RunID uuid.UUID
ChatID uuid.UUID
StepNumber int32
Operation Operation
Status Status
HistoryTipMessageID int64
NormalizedRequest any
}
func attemptSinkFromContext(ctx context.Context) *attemptSink {
sink, _ := ctx.Value(attemptSinkKey{}).(*attemptSink)
return sink
// UpdateStepParams identifies the data recorded when a debug step finishes.
type UpdateStepParams struct {
ID uuid.UUID
ChatID uuid.UUID
Status Status
NormalizedResponse any
Usage any
Attempts []Attempt
Error any
Metadata any
FinishedAt time.Time
}
var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32
// NewService constructs the placeholder chat debug service.
func NewService(_ database.Store, log slog.Logger, _ pubsub.Pubsub) *Service {
return &Service{log: log}
}
// IsEnabled reports whether debug recording is enabled for a chat owner.
func (*Service) IsEnabled(context.Context, uuid.UUID, uuid.UUID) bool {
return true
}
// CreateStep synthesizes a debug step so recorder tests can exercise the
// wrapper without requiring the later persistence service implementation.
func (*Service) CreateStep(
_ context.Context,
params CreateStepParams,
) (database.ChatDebugStep, error) {
return database.ChatDebugStep{
ID: uuid.New(),
RunID: params.RunID,
ChatID: params.ChatID,
StepNumber: params.StepNumber,
Operation: string(params.Operation),
Status: string(params.Status),
}, nil
}
// UpdateStep accepts final step state once recording completes.
func (*Service) UpdateStep(
_ context.Context,
params UpdateStepParams,
) (database.ChatDebugStep, error) {
return database.ChatDebugStep{
ID: params.ID,
ChatID: params.ChatID,
Status: string(params.Status),
}, nil
}
// runRefCounts tracks how many live RunContext instances reference each
// RunID. Cleanup of shared state (step counters) is deferred until the
@@ -78,125 +112,6 @@ func releaseRunRef(runID uuid.UUID) {
}
}
func nextStepNumber(runID uuid.UUID) int32 {
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
counter, ok := val.(*atomic.Int32)
if !ok {
panic("chatdebug: invalid step counter type")
}
return counter.Add(1)
}
// CleanupStepCounter removes per-run step counter and reference count
// state. This is used by tests and later stacked branches that have a
// real run lifecycle.
func CleanupStepCounter(runID uuid.UUID) {
stepCounters.Delete(runID)
runRefCounts.Delete(runID)
}
type stepHandle struct {
stepCtx *StepContext
sink *attemptSink
mu sync.Mutex
status Status
response any
usage any
err any
metadata any
}
func beginStep(
ctx context.Context,
svc *Service,
opts RecorderOptions,
op Operation,
_ any,
) (*stepHandle, context.Context) {
if svc == nil {
return nil, ctx
}
rc, ok := RunFromContext(ctx)
if !ok || rc.RunID == uuid.Nil {
return nil, ctx
}
if holder, reuseStep := reuseHolderFromContext(ctx); reuseStep {
holder.mu.Lock()
defer holder.mu.Unlock()
// Only reuse the cached handle if it belongs to the same run.
// A different RunContext means a new logical run, so we must
// create a fresh step to avoid cross-run attribution.
if holder.handle != nil && holder.handle.stepCtx.RunID == rc.RunID {
enriched := ContextWithStep(ctx, holder.handle.stepCtx)
enriched = withAttemptSink(enriched, holder.handle.sink)
return holder.handle, enriched
}
handle, enriched := newStepHandle(ctx, rc, opts, op)
holder.handle = handle
return handle, enriched
}
return newStepHandle(ctx, rc, opts, op)
}
func newStepHandle(
ctx context.Context,
rc *RunContext,
opts RecorderOptions,
op Operation,
) (*stepHandle, context.Context) {
if rc == nil || rc.RunID == uuid.Nil {
return nil, ctx
}
chatID := opts.ChatID
if chatID == uuid.Nil {
chatID = rc.ChatID
}
handle := &stepHandle{
stepCtx: &StepContext{
StepID: uuid.New(),
RunID: rc.RunID,
ChatID: chatID,
StepNumber: nextStepNumber(rc.RunID),
Operation: op,
HistoryTipMessageID: rc.HistoryTipMessageID,
},
sink: &attemptSink{},
}
enriched := ContextWithStep(ctx, handle.stepCtx)
enriched = withAttemptSink(enriched, handle.sink)
return handle, enriched
}
func (h *stepHandle) finish(
_ context.Context,
status Status,
response any,
usage any,
err any,
metadata any,
) {
if h == nil || h.stepCtx == nil {
return
}
// Guard with a mutex so concurrent callers (e.g. retried stream
// wrappers sharing a reused handle) don't race. Unlike sync.Once,
// later retries are allowed to overwrite earlier failure results so
// the step reflects the final outcome.
h.mu.Lock()
defer h.mu.Unlock()
h.status = status
h.response = response
h.usage = usage
h.err = err
h.metadata = metadata
}
// whitespaceRun matches one or more consecutive whitespace characters.
var whitespaceRun = regexp.MustCompile(`\s+`)
@@ -2,7 +2,6 @@ package chatdebug
import (
"context"
"net/http"
"testing"
"unicode/utf8"
@@ -19,15 +18,6 @@ func TestBeginStep_SkipsNilRunID(t *testing.T) {
require.Equal(t, ctx, enriched)
}
func TestNewStepHandle_SkipsNilRunID(t *testing.T) {
t.Parallel()
ctx := context.Background()
handle, enriched := newStepHandle(ctx, &RunContext{ChatID: uuid.New()}, RecorderOptions{ChatID: uuid.New()}, OperationGenerate)
require.Nil(t, handle)
require.Equal(t, ctx, enriched)
}
func TestTruncateLabel(t *testing.T) {
t.Parallel()
@@ -50,7 +40,6 @@ func TestTruncateLabel(t *testing.T) {
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := TruncateLabel(tc.input, tc.maxLen)
@@ -66,25 +55,3 @@ func maxInt(a, b int) int {
}
return b
}
// RedactedValue replaces sensitive values in debug payloads.
const RedactedValue = "[REDACTED]"
// RecordingTransport is the branch-02 placeholder HTTP recording transport.
type RecordingTransport struct {
Base http.RoundTripper
}
var _ http.RoundTripper = (*RecordingTransport)(nil)
func (t *RecordingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req == nil {
panic("chatdebug: nil request")
}
base := t.Base
if base == nil {
base = http.DefaultTransport
}
return base.RoundTrip(req)
}
+382
View File
@@ -0,0 +1,382 @@
package chatdebug
import (
"bytes"
"encoding/json"
"errors"
"io"
"mime"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// attemptStatusCompleted is the status recorded when a response body
// is fully read without transport-level errors.
const attemptStatusCompleted = "completed"
// attemptStatusFailed is the status recorded when a transport error
// or body read error occurs.
const attemptStatusFailed = "failed"
// maxRecordedRequestBodyBytes caps in-memory request capture when GetBody
// is available.
const maxRecordedRequestBodyBytes = 50_000
// maxRecordedResponseBodyBytes caps in-memory response capture.
const maxRecordedResponseBodyBytes = 50_000
// RecordingTransport captures HTTP request/response data for debug steps.
// When the request context carries an attemptSink, it records each round
// trip. Otherwise it delegates directly.
type RecordingTransport struct {
// Base is the underlying transport. nil defaults to http.DefaultTransport.
Base http.RoundTripper
}
var _ http.RoundTripper = (*RecordingTransport)(nil)
func (t *RecordingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req == nil {
panic("chatdebug: nil request")
}
base := t.Base
if base == nil {
base = http.DefaultTransport
}
sink := attemptSinkFromContext(req.Context())
if sink == nil {
return base.RoundTrip(req)
}
requestHeaders := RedactHeaders(req.Header)
// Capture method and URL/path from the request.
method := req.Method
reqURL := ""
reqPath := ""
if req.URL != nil {
reqURL = redactURL(req.URL)
reqPath = req.URL.Path
}
requestBody, err := captureRequestBody(req)
if err != nil {
return nil, err
}
attemptNumber := sink.nextAttemptNumber()
startedAt := time.Now()
resp, err := base.RoundTrip(req)
finishedAt := time.Now()
durationMs := finishedAt.Sub(startedAt).Milliseconds()
if err != nil {
sink.record(Attempt{
Number: attemptNumber,
Status: attemptStatusFailed,
Method: method,
URL: reqURL,
Path: reqPath,
StartedAt: startedAt.UTC().Format(time.RFC3339Nano),
FinishedAt: finishedAt.UTC().Format(time.RFC3339Nano),
RequestHeaders: requestHeaders,
RequestBody: requestBody,
Error: err.Error(),
DurationMs: durationMs,
})
return nil, err
}
respHeaders := RedactHeaders(resp.Header)
resp.Body = &recordingBody{
inner: resp.Body,
sink: sink,
startedAt: startedAt,
contentLength: resp.ContentLength,
base: Attempt{
Number: attemptNumber,
Method: method,
URL: reqURL,
Path: reqPath,
RequestHeaders: requestHeaders,
RequestBody: requestBody,
ResponseStatus: resp.StatusCode,
ResponseHeaders: respHeaders,
DurationMs: durationMs,
},
}
return resp, nil
}
func redactURL(u *url.URL) string {
if u == nil {
return ""
}
clone := *u
clone.User = nil
q := clone.Query()
for key, values := range q {
if isSensitiveName(key) || isSensitiveJSONKey(key) {
for i := range values {
values[i] = RedactedValue
}
q[key] = values
}
}
clone.RawQuery = q.Encode()
return clone.String()
}
func captureRequestBody(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
}
if req.GetBody != nil {
clone, err := req.GetBody()
if err == nil {
defer clone.Close()
limited, err := io.ReadAll(io.LimitReader(clone, maxRecordedRequestBodyBytes+1))
if err == nil {
if len(limited) > maxRecordedRequestBodyBytes {
return []byte("[TRUNCATED]"), nil
}
return RedactJSONSecrets(limited), nil
}
}
}
// Without GetBody we cannot safely capture the request body without
// fully consuming a potentially large or streaming body before the
// request is sent. Skip capture in that case to keep debug logging
// lightweight and non-invasive.
return nil, nil
}
type recordingBody struct {
inner io.ReadCloser
contentLength int64
sink *attemptSink
base Attempt
startedAt time.Time
mu sync.Mutex
buf bytes.Buffer
truncated bool
sawEOF bool
bytesRead int64
recordOnce sync.Once
closeOnce sync.Once
}
func (r *recordingBody) Read(p []byte) (int, error) {
n, err := r.inner.Read(p)
r.mu.Lock()
r.bytesRead += int64(n)
if n > 0 && !r.truncated {
remaining := maxRecordedResponseBodyBytes - r.buf.Len()
if remaining > 0 {
toWrite := n
if toWrite > remaining {
toWrite = remaining
r.truncated = true
}
_, _ = r.buf.Write(p[:toWrite])
} else {
r.truncated = true
}
}
if errors.Is(err, io.EOF) {
r.sawEOF = true
}
r.mu.Unlock()
if err != nil {
r.record(err)
}
return n, err
}
func (r *recordingBody) Close() error {
r.mu.Lock()
sawEOF := r.sawEOF
bytesRead := r.bytesRead
contentLength := r.contentLength
truncated := r.truncated
responseBody := append([]byte(nil), r.buf.Bytes()...)
r.mu.Unlock()
contentType := r.base.ResponseHeaders["Content-Type"]
shouldDrainUnknownLengthJSON := contentLength < 0 &&
!sawEOF &&
bytesRead > 0 &&
!truncated &&
isCompleteUnknownLengthJSONBody(contentType, responseBody)
// Always close the inner reader first so that stalled chunked
// bodies cannot block drainToEOF indefinitely. Once inner is
// closed, reads return immediately with an error or EOF.
var closeErr error
r.closeOnce.Do(func() {
closeErr = r.inner.Close()
})
if closeErr != nil {
r.record(closeErr)
return closeErr
}
// Drain remaining bytes that may already be buffered inside the
// HTTP transport after close. Because inner is closed, this
// finishes immediately rather than blocking on the network.
if shouldDrainUnknownLengthJSON {
// Best-effort drain; ignore errors since inner is closed.
_ = r.drainToEOF()
}
r.mu.Lock()
sawEOF = r.sawEOF
bytesRead = r.bytesRead
contentLength = r.contentLength
truncated = r.truncated
responseBody = append([]byte(nil), r.buf.Bytes()...)
r.mu.Unlock()
switch {
// Only check JSON completeness when the recording buffer is
// not truncated. A truncated buffer is an incomplete prefix
// of the body, so the completeness check would false-positive.
case sawEOF && !truncated && contentLength < 0 && isJSONLikeContentType(contentType) && !isCompleteUnknownLengthJSONBody(contentType, responseBody):
r.record(io.ErrUnexpectedEOF)
case sawEOF:
r.record(io.EOF)
case responseHasNoBody(r.base.Method, r.base.ResponseStatus):
r.record(nil)
case contentLength >= 0 && bytesRead >= contentLength:
r.record(nil)
case contentLength < 0 && !truncated && isCompleteUnknownLengthJSONBody(contentType, responseBody):
r.record(nil)
default:
r.record(io.ErrUnexpectedEOF)
}
return nil
}
func responseHasNoBody(method string, statusCode int) bool {
if method == http.MethodHead {
return true
}
return statusCode == http.StatusNoContent ||
statusCode == http.StatusNotModified ||
(statusCode >= 100 && statusCode < 200)
}
func isJSONLikeContentType(contentType string) bool {
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
mediaType = strings.TrimSpace(strings.Split(contentType, ";")[0])
}
return mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")
}
// maxDrainBytes caps how many trailing bytes drainToEOF will consume.
// This prevents Close() from blocking indefinitely on a misbehaving
// or extremely large chunked body.
const maxDrainBytes = 64 * 1024 // 64 KB
func (r *recordingBody) drainToEOF() error {
buf := make([]byte, 4*1024)
var drained int64
for {
n, err := r.inner.Read(buf)
r.mu.Lock()
r.bytesRead += int64(n)
drained += int64(n)
if n > 0 && !r.truncated {
remaining := maxRecordedResponseBodyBytes - r.buf.Len()
if remaining > 0 {
toWrite := n
if toWrite > remaining {
toWrite = remaining
r.truncated = true
}
_, _ = r.buf.Write(buf[:toWrite])
} else {
r.truncated = true
}
}
if errors.Is(err, io.EOF) {
r.sawEOF = true
}
r.mu.Unlock()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err
}
// Safety valve: stop draining after maxDrainBytes to prevent
// Close() from blocking indefinitely on a chunked body.
if drained >= maxDrainBytes {
return io.ErrUnexpectedEOF
}
}
}
func isCompleteUnknownLengthJSONBody(contentType string, body []byte) bool {
if !isJSONLikeContentType(contentType) {
return false
}
trimmed := bytes.TrimSpace(body)
if len(trimmed) == 0 {
return false
}
decoder := json.NewDecoder(bytes.NewReader(trimmed))
var value any
if err := decoder.Decode(&value); err != nil {
return false
}
var extra any
return errors.Is(decoder.Decode(&extra), io.EOF)
}
func (r *recordingBody) record(err error) {
r.recordOnce.Do(func() {
finishedAt := time.Now()
r.mu.Lock()
truncated := r.truncated
responseBody := append([]byte(nil), r.buf.Bytes()...)
base := r.base
startedAt := r.startedAt
r.mu.Unlock()
if truncated {
base.ResponseBody = []byte("[TRUNCATED]")
} else {
base.ResponseBody = RedactJSONSecrets(responseBody)
}
base.StartedAt = startedAt.UTC().Format(time.RFC3339Nano)
base.FinishedAt = finishedAt.UTC().Format(time.RFC3339Nano)
// Recompute duration to include body read time.
base.DurationMs = finishedAt.Sub(startedAt).Milliseconds()
if err != nil && !errors.Is(err, io.EOF) {
base.Error = err.Error()
base.Status = attemptStatusFailed
} else {
base.Status = attemptStatusCompleted
}
r.sink.record(base)
})
}
+737
View File
@@ -0,0 +1,737 @@
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
func newTestSinkContext(t *testing.T) (context.Context, *attemptSink) {
t.Helper()
sink := &attemptSink{}
return withAttemptSink(context.Background(), sink), sink
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
type scriptedReadCloser struct {
chunks [][]byte
index int
offset int // byte offset within current chunk
}
func (r *scriptedReadCloser) Read(p []byte) (int, error) {
if r.index >= len(r.chunks) {
return 0, io.EOF
}
chunk := r.chunks[r.index]
remaining := chunk[r.offset:]
n := copy(p, remaining)
r.offset += n
if r.offset >= len(chunk) {
r.index++
r.offset = 0
}
return n, nil
}
func (*scriptedReadCloser) Close() error {
return nil
}
func TestRecordingTransport_NoSink(t *testing.T) {
t.Parallel()
gotMethod := make(chan string, 1)
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
gotMethod <- req.Method
_, _ = rw.Write([]byte("ok"))
}))
defer server.Close()
client := &http.Client{
Transport: &RecordingTransport{Base: server.Client().Transport},
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "ok", string(body))
require.Equal(t, http.MethodGet, <-gotMethod)
}
func TestRecordingTransport_CaptureRequest(t *testing.T) {
t.Parallel()
const requestBody = `{"message":"hello","api_key":"super-secret"}`
type receivedRequest struct {
authorization string
body []byte
}
gotRequest := make(chan receivedRequest, 1)
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
gotRequest <- receivedRequest{
authorization: req.Header.Get("Authorization"),
body: body,
}
_, _ = rw.Write([]byte(`{"ok":true}`))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{Base: server.Client().Transport},
}
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
server.URL,
strings.NewReader(requestBody),
)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer top-secret")
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
require.NoError(t, err)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, 1, attempts[0].Number)
require.Equal(t, RedactedValue, attempts[0].RequestHeaders["Authorization"])
require.Equal(t, "application/json", attempts[0].RequestHeaders["Content-Type"])
require.JSONEq(t, `{"message":"hello","api_key":"[REDACTED]"}`, string(attempts[0].RequestBody))
received := <-gotRequest
require.JSONEq(t, requestBody, string(received.body))
require.Equal(t, "Bearer top-secret", received.authorization)
}
func TestRecordingTransport_RedactsSensitiveQueryParameters(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte(`ok`))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL+`?api_key=secret&safe=ok`, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Contains(t, attempts[0].URL, "api_key=%5BREDACTED%5D")
require.Contains(t, attempts[0].URL, "safe=ok")
}
func TestRecordingTransport_TruncatesLargeRequestBodies(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = io.Copy(io.Discard, req.Body)
_, _ = rw.Write([]byte(`ok`))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
large := strings.Repeat("x", maxRecordedRequestBodyBytes+1024)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(large))
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, []byte("[TRUNCATED]"), attempts[0].RequestBody)
}
func TestRecordingTransport_StripsURLUserinfo(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte(`ok`))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.Replace(server.URL, "http://", "http://user:secret@", 1)+`?api_key=secret`, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.NotContains(t, attempts[0].URL, "user:secret")
require.Contains(t, attempts[0].URL, "api_key=%5BREDACTED%5D")
}
func TestRecordingTransport_SkipsNonReplayableRequestBodyCapture(t *testing.T) {
t.Parallel()
const requestBody = `{"message":"hello"}`
gotRequest := make(chan []byte, 1)
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
gotRequest <- body
_, _ = rw.Write([]byte(`ok`))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, io.NopCloser(strings.NewReader(requestBody)))
require.NoError(t, err)
req.GetBody = nil
resp, err := client.Do(req)
require.NoError(t, err)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
require.JSONEq(t, requestBody, string(<-gotRequest))
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Nil(t, attempts[0].RequestBody)
}
func TestRecordingTransport_CaptureResponse(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("X-API-Key", "response-secret")
rw.Header().Set("X-Trace-ID", "trace-123")
rw.WriteHeader(http.StatusCreated)
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{Base: server.Client().Transport},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, http.StatusCreated, attempts[0].ResponseStatus)
require.Equal(t, RedactedValue, attempts[0].ResponseHeaders["X-Api-Key"])
require.Equal(t, "trace-123", attempts[0].ResponseHeaders["X-Trace-Id"])
require.JSONEq(t, `{"token":"[REDACTED]","safe":"ok"}`, string(attempts[0].ResponseBody))
}
func TestRecordingTransport_CaptureResponseOnEOFWithoutClose(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Content-Type", "application/json")
rw.Header().Set("X-API-Key", "response-secret")
rw.WriteHeader(http.StatusAccepted)
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{Base: server.Client().Transport},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, http.StatusAccepted, attempts[0].ResponseStatus)
require.Equal(t, "application/json", attempts[0].ResponseHeaders["Content-Type"])
require.Equal(t, RedactedValue, attempts[0].ResponseHeaders["X-Api-Key"])
require.JSONEq(t, `{"token":"[REDACTED]","safe":"ok"}`, string(attempts[0].ResponseBody))
require.NoError(t, resp.Body.Close())
}
func TestRecordingTransport_StreamingBody(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
flusher, ok := rw.(http.Flusher)
require.True(t, ok)
rw.Header().Set("Content-Type", "application/json")
_, _ = rw.Write([]byte(`{"safe":"stream",`))
flusher.Flush()
_, _ = rw.Write([]byte(`"token":"chunk-secret"}`))
flusher.Flush()
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{Base: server.Client().Transport},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
buf := make([]byte, 5)
var body strings.Builder
for {
n, readErr := resp.Body.Read(buf)
if n > 0 {
_, writeErr := body.Write(buf[:n])
require.NoError(t, writeErr)
}
if errors.Is(readErr, io.EOF) {
break
}
require.NoError(t, readErr)
}
require.NoError(t, resp.Body.Close())
require.JSONEq(t, `{"safe":"stream","token":"chunk-secret"}`, body.String())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.JSONEq(t, `{"safe":"stream","token":"[REDACTED]"}`, string(attempts[0].ResponseBody))
}
func TestRecordingTransport_CloseAfterDecoderConsumesContentLengthSucceeds(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Content-Type", "application/json")
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
var decoded map[string]string
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
require.Equal(t, "ok", decoded["safe"])
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
require.Empty(t, attempts[0].Error)
}
func TestRecordingTransport_CloseAfterDecoderConsumesUnknownLengthJSONSucceeds(t *testing.T) {
t.Parallel()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
ContentLength: -1,
}, nil
}),
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
var decoded map[string]string
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
require.Equal(t, "ok", decoded["safe"])
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
require.Empty(t, attempts[0].Error)
}
func TestRecordingTransport_CloseAfterDecoderConsumesUnknownLengthJSONWithTrailingDocumentMarksFailed(t *testing.T) {
t.Parallel()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: &scriptedReadCloser{chunks: [][]byte{[]byte("{\"token\":\"response-secret\",\"safe\":\"ok\"}{\"token\":\"second\"}")}},
ContentLength: -1,
}, nil
}),
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
var decoded map[string]string
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
require.Equal(t, "ok", decoded["safe"])
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusFailed, attempts[0].Status)
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
}
func TestRecordingTransport_CloseAfterDecoderConsumesUnknownLengthNDJSONMarksFailed(t *testing.T) {
t.Parallel()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/x-ndjson"}},
Body: &scriptedReadCloser{chunks: [][]byte{[]byte("{\"token\":\"response-secret\",\"safe\":\"ok\"}\n{\"token\":\"second\"}\n")}},
ContentLength: -1,
}, nil
}),
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
var decoded map[string]string
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
require.Equal(t, "ok", decoded["safe"])
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusFailed, attempts[0].Status)
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
}
func TestRecordingTransport_CloseAfterDecoderDrainsUnknownLengthSucceeds(t *testing.T) {
t.Parallel()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
ContentLength: -1,
}, nil
}),
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
var decoded map[string]string
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
require.Equal(t, "ok", decoded["safe"])
_, err = io.Copy(io.Discard, resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
require.Empty(t, attempts[0].Error)
}
func TestRecordingTransport_CloseWithoutReadingHeadResponseSucceeds(t *testing.T) {
t.Parallel()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{ //nolint:exhaustruct // Test response exercises no-body close semantics.
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"ignored":true}`)}},
ContentLength: 13,
Request: req,
}, nil
}),
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodHead, "http://example.invalid", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
require.Empty(t, attempts[0].Error)
}
func TestRecordingTransport_CloseWithoutReadingUnknownLengthMarksFailed(t *testing.T) {
t.Parallel()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
ContentLength: -1,
}, nil
}),
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusFailed, attempts[0].Status)
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
}
func TestRecordingTransport_PrematureCloseUnknownLengthMarksFailed(t *testing.T) {
t.Parallel()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
ContentLength: -1,
}, nil
}),
},
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
buf := make([]byte, 5)
_, err = resp.Body.Read(buf)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusFailed, attempts[0].Status)
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
}
func TestRecordingTransport_PrematureCloseMarksFailed(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
buf := make([]byte, 5)
_, err = resp.Body.Read(buf)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusFailed, attempts[0].Status)
}
func TestRecordingTransport_TruncatesLargeResponses(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte(strings.Repeat("x", maxRecordedResponseBodyBytes+1024)))
}))
defer server.Close()
ctx, sink := newTestSinkContext(t)
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.NoError(t, resp.Body.Close())
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, []byte("[TRUNCATED]"), attempts[0].ResponseBody)
}
func TestRecordingTransport_TransportError(t *testing.T) {
t.Parallel()
ctx, sink := newTestSinkContext(t)
client := &http.Client{
Transport: &RecordingTransport{
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return nil, xerrors.New("transport exploded")
}),
},
}
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
"http://example.invalid",
strings.NewReader(`{"password":"secret","safe":"ok"}`),
)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer top-secret")
resp, err := client.Do(req)
if resp != nil {
defer resp.Body.Close()
}
require.Nil(t, resp)
require.EqualError(t, err, "Post \"http://example.invalid\": transport exploded")
attempts := sink.snapshot()
require.Len(t, attempts, 1)
require.Equal(t, 1, attempts[0].Number)
require.Equal(t, RedactedValue, attempts[0].RequestHeaders["Authorization"])
require.JSONEq(t, `{"password":"[REDACTED]","safe":"ok"}`, string(attempts[0].RequestBody))
require.Zero(t, attempts[0].ResponseStatus)
require.Equal(t, "transport exploded", attempts[0].Error)
require.GreaterOrEqual(t, attempts[0].DurationMs, int64(0))
}
func TestRecordingTransport_NilBase(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte("ok"))
}))
defer server.Close()
client := &http.Client{Transport: &RecordingTransport{}}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "ok", string(body))
}