feat(coderd/x/chatd/chatdebug): add recorder, transport, and redaction
Change-Id: Ibbc67a85ba78201c0778ccb5c8675b15e90b1cdf Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
@@ -387,9 +387,9 @@ func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
|
||||
// The handle starts with zero status; simulate what the wrapper does
|
||||
// when consumer breaks after finish.
|
||||
h.finish(ctx2, StatusCompleted, nil, nil, nil, nil)
|
||||
h.mu.Lock()
|
||||
// finish uses sync.Once, so after it returns the status is safely
|
||||
// set and readable in this single-goroutine test.
|
||||
require.Equal(t, StatusCompleted, h.status)
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// TestDebugModel_StreamInterruptedBeforeFinish verifies that when a consumer
|
||||
|
||||
@@ -0,0 +1,277 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// RecorderOptions identifies the chat/model context for debug recording.
|
||||
type RecorderOptions struct {
|
||||
ChatID uuid.UUID
|
||||
OwnerID uuid.UUID
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// WrapModel returns model unchanged when debug recording is disabled, or a
|
||||
// debug wrapper when a service is available.
|
||||
func WrapModel(
|
||||
model fantasy.LanguageModel,
|
||||
svc *Service,
|
||||
opts RecorderOptions,
|
||||
) fantasy.LanguageModel {
|
||||
if model == nil {
|
||||
panic("chatdebug: nil LanguageModel")
|
||||
}
|
||||
if svc == nil {
|
||||
return model
|
||||
}
|
||||
return &debugModel{inner: model, svc: svc, opts: opts}
|
||||
}
|
||||
|
||||
type attemptSink struct {
|
||||
mu sync.Mutex
|
||||
attempts []Attempt
|
||||
attemptCounter atomic.Int32
|
||||
}
|
||||
|
||||
func (s *attemptSink) nextAttemptNumber() int {
|
||||
if s == nil {
|
||||
panic("chatdebug: nil attemptSink")
|
||||
}
|
||||
return int(s.attemptCounter.Add(1))
|
||||
}
|
||||
|
||||
func (s *attemptSink) record(a Attempt) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.attempts = append(s.attempts, a)
|
||||
}
|
||||
|
||||
func (s *attemptSink) snapshot() []Attempt {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
attempts := make([]Attempt, len(s.attempts))
|
||||
copy(attempts, s.attempts)
|
||||
return attempts
|
||||
}
|
||||
|
||||
type attemptSinkKey struct{}
|
||||
|
||||
func withAttemptSink(ctx context.Context, sink *attemptSink) context.Context {
|
||||
if sink == nil {
|
||||
panic("chatdebug: nil attemptSink")
|
||||
}
|
||||
return context.WithValue(ctx, attemptSinkKey{}, sink)
|
||||
}
|
||||
|
||||
func attemptSinkFromContext(ctx context.Context) *attemptSink {
|
||||
sink, _ := ctx.Value(attemptSinkKey{}).(*attemptSink)
|
||||
return sink
|
||||
}
|
||||
|
||||
var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32
|
||||
|
||||
func nextStepNumber(runID uuid.UUID) int32 {
|
||||
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter, ok := val.(*atomic.Int32)
|
||||
if !ok {
|
||||
panic("chatdebug: invalid step counter type")
|
||||
}
|
||||
return counter.Add(1)
|
||||
}
|
||||
|
||||
// CleanupStepCounter removes per-run step counter and reference count
|
||||
// state. This is used by tests and later stacked branches that have a
|
||||
// real run lifecycle.
|
||||
func CleanupStepCounter(runID uuid.UUID) {
|
||||
stepCounters.Delete(runID)
|
||||
runRefCounts.Delete(runID)
|
||||
}
|
||||
|
||||
const stepFinalizeTimeout = 5 * time.Second
|
||||
|
||||
func stepFinalizeContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
panic("chatdebug: nil context")
|
||||
}
|
||||
return context.WithTimeout(context.WithoutCancel(ctx), stepFinalizeTimeout)
|
||||
}
|
||||
|
||||
func syncStepCounter(runID uuid.UUID, stepNumber int32) {
|
||||
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter, ok := val.(*atomic.Int32)
|
||||
if !ok {
|
||||
panic("chatdebug: invalid step counter type")
|
||||
}
|
||||
for {
|
||||
current := counter.Load()
|
||||
if current >= stepNumber {
|
||||
return
|
||||
}
|
||||
if counter.CompareAndSwap(current, stepNumber) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type stepHandle struct {
|
||||
stepCtx *StepContext
|
||||
sink *attemptSink
|
||||
svc *Service
|
||||
opts RecorderOptions
|
||||
once sync.Once
|
||||
mu sync.Mutex
|
||||
status Status
|
||||
response any
|
||||
usage any
|
||||
err any
|
||||
metadata any
|
||||
}
|
||||
|
||||
// beginStep validates preconditions, creates a debug step, and returns a
|
||||
// handle plus an enriched context carrying StepContext and attemptSink.
|
||||
// Returns (nil, original ctx) when debug recording should be skipped.
|
||||
func beginStep(
|
||||
ctx context.Context,
|
||||
svc *Service,
|
||||
opts RecorderOptions,
|
||||
op Operation,
|
||||
normalizedReq any,
|
||||
) (*stepHandle, context.Context) {
|
||||
if svc == nil {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
rc, ok := RunFromContext(ctx)
|
||||
if !ok || rc.RunID == uuid.Nil {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
chatID := opts.ChatID
|
||||
if chatID == uuid.Nil {
|
||||
chatID = rc.ChatID
|
||||
}
|
||||
if !svc.IsEnabled(ctx, chatID, opts.OwnerID) {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
holder, reuseStep := reuseHolderFromContext(ctx)
|
||||
if reuseStep {
|
||||
holder.mu.Lock()
|
||||
defer holder.mu.Unlock()
|
||||
// Only reuse the cached handle if it belongs to the same run.
|
||||
// A different RunContext means a new logical run, so we must
|
||||
// create a fresh step to avoid cross-run attribution.
|
||||
if holder.handle != nil && holder.handle.stepCtx.RunID == rc.RunID {
|
||||
enriched := ContextWithStep(ctx, holder.handle.stepCtx)
|
||||
enriched = withAttemptSink(enriched, holder.handle.sink)
|
||||
return holder.handle, enriched
|
||||
}
|
||||
}
|
||||
|
||||
stepNum := nextStepNumber(rc.RunID)
|
||||
step, err := svc.CreateStep(ctx, CreateStepParams{
|
||||
RunID: rc.RunID,
|
||||
ChatID: chatID,
|
||||
StepNumber: stepNum,
|
||||
Operation: op,
|
||||
Status: StatusInProgress,
|
||||
HistoryTipMessageID: rc.HistoryTipMessageID,
|
||||
NormalizedRequest: normalizedReq,
|
||||
})
|
||||
if err != nil {
|
||||
svc.log.Warn(ctx, "failed to create chat debug step",
|
||||
slog.Error(err),
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("run_id", rc.RunID),
|
||||
slog.F("operation", op),
|
||||
)
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
syncStepCounter(rc.RunID, step.StepNumber)
|
||||
actualStepNumber := step.StepNumber
|
||||
if actualStepNumber == 0 {
|
||||
actualStepNumber = stepNum
|
||||
}
|
||||
|
||||
sc := &StepContext{
|
||||
StepID: step.ID,
|
||||
RunID: rc.RunID,
|
||||
ChatID: chatID,
|
||||
StepNumber: actualStepNumber,
|
||||
Operation: op,
|
||||
HistoryTipMessageID: rc.HistoryTipMessageID,
|
||||
}
|
||||
handle := &stepHandle{stepCtx: sc, sink: &attemptSink{}, svc: svc, opts: opts}
|
||||
enriched := ContextWithStep(ctx, handle.stepCtx)
|
||||
enriched = withAttemptSink(enriched, handle.sink)
|
||||
if reuseStep {
|
||||
holder.handle = handle
|
||||
}
|
||||
|
||||
return handle, enriched
|
||||
}
|
||||
|
||||
// finish updates the debug step with final status and data.
|
||||
// sync.Once prevents data races when concurrent callers (e.g.
|
||||
// retried stream wrappers sharing a reuse handle) both attempt
|
||||
// to finalize the same step. Only the first finish call takes
|
||||
// effect.
|
||||
func (h *stepHandle) finish(
|
||||
ctx context.Context,
|
||||
status Status,
|
||||
response any,
|
||||
usage any,
|
||||
errPayload any,
|
||||
metadata any,
|
||||
) {
|
||||
if h == nil || h.stepCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.once.Do(func() {
|
||||
h.mu.Lock()
|
||||
h.status = status
|
||||
h.response = response
|
||||
h.usage = usage
|
||||
h.err = errPayload
|
||||
h.metadata = metadata
|
||||
h.mu.Unlock()
|
||||
if h.svc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
updateCtx, cancel := stepFinalizeContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
if _, updateErr := h.svc.UpdateStep(updateCtx, UpdateStepParams{
|
||||
ID: h.stepCtx.StepID,
|
||||
ChatID: h.stepCtx.ChatID,
|
||||
Status: status,
|
||||
NormalizedResponse: response,
|
||||
Usage: usage,
|
||||
Attempts: h.sink.snapshot(),
|
||||
Error: errPayload,
|
||||
Metadata: metadata,
|
||||
FinishedAt: time.Now(),
|
||||
}); updateErr != nil {
|
||||
h.svc.log.Warn(updateCtx, "failed to finalize chat debug step",
|
||||
slog.Error(updateErr),
|
||||
slog.F("step_id", h.stepCtx.StepID),
|
||||
slog.F("chat_id", h.stepCtx.ChatID),
|
||||
slog.F("status", status),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
)
|
||||
|
||||
func TestAttemptSink_ThreadSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const n = 256
|
||||
|
||||
sink := &attemptSink{}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
|
||||
for i := range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
sink.record(Attempt{Number: i + 1, ResponseStatus: 200 + i})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, n)
|
||||
|
||||
numbers := make([]int, 0, n)
|
||||
statuses := make([]int, 0, n)
|
||||
for _, attempt := range attempts {
|
||||
numbers = append(numbers, attempt.Number)
|
||||
statuses = append(statuses, attempt.ResponseStatus)
|
||||
}
|
||||
sort.Ints(numbers)
|
||||
sort.Ints(statuses)
|
||||
|
||||
for i := range n {
|
||||
require.Equal(t, i+1, numbers[i])
|
||||
require.Equal(t, 200+i, statuses[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttemptSinkContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
require.Nil(t, attemptSinkFromContext(ctx))
|
||||
|
||||
sink := &attemptSink{}
|
||||
ctx = withAttemptSink(ctx, sink)
|
||||
require.Same(t, sink, attemptSinkFromContext(ctx))
|
||||
}
|
||||
|
||||
func TestWrapModel_NilModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Panics(t, func() {
|
||||
WrapModel(nil, &Service{}, RecorderOptions{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestWrapModel_NilService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
|
||||
wrapped := WrapModel(model, nil, RecorderOptions{})
|
||||
require.Same(t, model, wrapped)
|
||||
}
|
||||
|
||||
func TestNextStepNumber_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const n = 256
|
||||
|
||||
runID := uuid.New()
|
||||
results := make([]int, n)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
|
||||
for i := range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i] = int(nextStepNumber(runID))
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
sort.Ints(results)
|
||||
for i := range n {
|
||||
require.Equal(t, i+1, results[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStepFinalizeContext_StripsCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
baseCtx, cancelBase := context.WithCancel(context.Background())
|
||||
cancelBase()
|
||||
require.ErrorIs(t, baseCtx.Err(), context.Canceled)
|
||||
|
||||
finalizeCtx, cancelFinalize := stepFinalizeContext(baseCtx)
|
||||
defer cancelFinalize()
|
||||
|
||||
require.NoError(t, finalizeCtx.Err())
|
||||
_, hasDeadline := finalizeCtx.Deadline()
|
||||
require.True(t, hasDeadline)
|
||||
}
|
||||
|
||||
func TestSyncStepCounter_AdvancesCounter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
syncStepCounter(runID, 7)
|
||||
require.Equal(t, int32(8), nextStepNumber(runID))
|
||||
}
|
||||
|
||||
func TestStepHandleFinish_NilHandle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var handle *stepHandle
|
||||
handle.finish(context.Background(), StatusCompleted, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func TestBeginStep_NilService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
handle, enriched := beginStep(ctx, nil, RecorderOptions{}, OperationGenerate, nil)
|
||||
require.Nil(t, handle)
|
||||
require.Nil(t, attemptSinkFromContext(enriched))
|
||||
_, ok := StepFromContext(enriched)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestBeginStep_FallsBackToRunChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
runChatID := uuid.New()
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID})
|
||||
|
||||
handle, enriched := beginStep(ctx, &Service{}, RecorderOptions{}, OperationGenerate, nil)
|
||||
require.NotNil(t, handle)
|
||||
require.Equal(t, runChatID, handle.stepCtx.ChatID)
|
||||
|
||||
stepCtx, ok := StepFromContext(enriched)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runChatID, stepCtx.ChatID)
|
||||
}
|
||||
|
||||
func TestWrapModel_ReturnsDebugModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
|
||||
wrapped := WrapModel(model, &Service{}, RecorderOptions{})
|
||||
|
||||
require.NotSame(t, model, wrapped)
|
||||
require.IsType(t, &debugModel{}, wrapped)
|
||||
require.Implements(t, (*fantasy.LanguageModel)(nil), wrapped)
|
||||
require.Equal(t, model.Provider(), wrapped.Provider())
|
||||
require.Equal(t, model.Model(), wrapped.Model())
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// RedactedValue replaces sensitive values in debug payloads.
|
||||
const RedactedValue = "[REDACTED]"
|
||||
|
||||
var sensitiveHeaderNames = map[string]struct{}{
|
||||
"authorization": {},
|
||||
"x-api-key": {},
|
||||
"api-key": {},
|
||||
"proxy-authorization": {},
|
||||
"cookie": {},
|
||||
"set-cookie": {},
|
||||
}
|
||||
|
||||
// sensitiveJSONKeyFragments triggers redaction for JSON keys containing
|
||||
// these substrings. Notably, "token" is intentionally absent because it
|
||||
// false-positively redacts LLM token-usage fields (input_tokens,
|
||||
// output_tokens, prompt_tokens, completion_tokens, reasoning_tokens,
|
||||
// cache_creation_input_tokens, cache_read_input_tokens, etc.). Auth-
|
||||
// related token fields are caught by the exact-match set below.
|
||||
var sensitiveJSONKeyFragments = []string{
|
||||
"secret",
|
||||
"password",
|
||||
"authorization",
|
||||
"credential",
|
||||
}
|
||||
|
||||
// sensitiveJSONKeyExact matches auth-related token/key field names
|
||||
// without false-positiving on LLM usage counters. Includes both
|
||||
// snake_case originals and their camelCase-lowered equivalents
|
||||
// (e.g. "accessToken" → "accesstoken") so that providers using
|
||||
// either convention are caught.
|
||||
var sensitiveJSONKeyExact = map[string]struct{}{
|
||||
"token": {},
|
||||
"access_token": {},
|
||||
"accesstoken": {},
|
||||
"refresh_token": {},
|
||||
"refreshtoken": {},
|
||||
"id_token": {},
|
||||
"idtoken": {},
|
||||
"api_token": {},
|
||||
"apitoken": {},
|
||||
"api_key": {},
|
||||
"apikey": {},
|
||||
"api-key": {},
|
||||
"x-api-key": {},
|
||||
"auth_token": {},
|
||||
"authtoken": {},
|
||||
"bearer_token": {},
|
||||
"bearertoken": {},
|
||||
"session_token": {},
|
||||
"sessiontoken": {},
|
||||
"security_token": {},
|
||||
"securitytoken": {},
|
||||
"private_key": {},
|
||||
"privatekey": {},
|
||||
"signing_key": {},
|
||||
"signingkey": {},
|
||||
"secret_key": {},
|
||||
"secretkey": {},
|
||||
}
|
||||
|
||||
// RedactHeaders returns a flattened copy of h with sensitive values redacted.
|
||||
func RedactHeaders(h http.Header) map[string]string {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
redacted := make(map[string]string, len(h))
|
||||
for name, values := range h {
|
||||
if isSensitiveName(name) {
|
||||
redacted[name] = RedactedValue
|
||||
continue
|
||||
}
|
||||
redacted[name] = strings.Join(values, ", ")
|
||||
}
|
||||
return redacted
|
||||
}
|
||||
|
||||
// RedactJSONSecrets redacts sensitive JSON values by key name. When
|
||||
// the input is not valid JSON (truncated body, HTML error page, etc.)
|
||||
// the raw bytes are replaced entirely with a diagnostic placeholder
|
||||
// to avoid leaking credentials from malformed payloads.
|
||||
func RedactJSONSecrets(data []byte) []byte {
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader(data))
|
||||
decoder.UseNumber()
|
||||
|
||||
var value any
|
||||
if err := decoder.Decode(&value); err != nil {
|
||||
// Cannot parse: replace entirely to prevent credential leaks
|
||||
// from non-JSON error responses (HTML pages, partial bodies).
|
||||
return []byte(`{"error":"chatdebug: body is not valid JSON, redacted for safety"}`)
|
||||
}
|
||||
if err := consumeJSONEOF(decoder); err != nil {
|
||||
return []byte(`{"error":"chatdebug: body contains extra JSON values, redacted for safety"}`)
|
||||
}
|
||||
|
||||
redacted, changed := redactJSONValue(value)
|
||||
if !changed {
|
||||
return data
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(redacted)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return encoded
|
||||
}
|
||||
|
||||
func consumeJSONEOF(decoder *json.Decoder) error {
|
||||
var extra any
|
||||
err := decoder.Decode(&extra)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if err == nil {
|
||||
return xerrors.New("chatdebug: extra JSON values")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var safeRateLimitHeaderNames = map[string]struct{}{
|
||||
"anthropic-ratelimit-requests-limit": {},
|
||||
"anthropic-ratelimit-requests-remaining": {},
|
||||
"anthropic-ratelimit-requests-reset": {},
|
||||
"anthropic-ratelimit-tokens-limit": {},
|
||||
"anthropic-ratelimit-tokens-remaining": {},
|
||||
"anthropic-ratelimit-tokens-reset": {},
|
||||
"x-ratelimit-limit-requests": {},
|
||||
"x-ratelimit-limit-tokens": {},
|
||||
"x-ratelimit-remaining-requests": {},
|
||||
"x-ratelimit-remaining-tokens": {},
|
||||
"x-ratelimit-reset-requests": {},
|
||||
"x-ratelimit-reset-tokens": {},
|
||||
}
|
||||
|
||||
// isSensitiveName reports whether a name (header or query parameter)
|
||||
// looks like a credential-carrying key. Exact-match headers are
|
||||
// checked first, then the rate-limit allowlist, then substring
|
||||
// patterns for API keys and auth tokens.
|
||||
func isSensitiveName(name string) bool {
|
||||
lowerName := strings.ToLower(name)
|
||||
if _, ok := sensitiveHeaderNames[lowerName]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := safeRateLimitHeaderNames[lowerName]; ok {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(lowerName, "api-key") ||
|
||||
strings.Contains(lowerName, "api_key") ||
|
||||
strings.Contains(lowerName, "apikey") {
|
||||
return true
|
||||
}
|
||||
// Catch any header containing "token" (e.g. Token, X-Token,
|
||||
// X-Auth-Token). Safe rate-limit headers like
|
||||
// x-ratelimit-remaining-tokens are already allowlisted above
|
||||
// and will not reach this point.
|
||||
if strings.Contains(lowerName, "token") {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(lowerName, "secret") ||
|
||||
strings.Contains(lowerName, "bearer")
|
||||
}
|
||||
|
||||
func isSensitiveJSONKey(key string) bool {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if _, ok := sensitiveJSONKeyExact[lowerKey]; ok {
|
||||
return true
|
||||
}
|
||||
for _, fragment := range sensitiveJSONKeyFragments {
|
||||
if strings.Contains(lowerKey, fragment) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func redactJSONValue(value any) (any, bool) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
changed := false
|
||||
for key, child := range typed {
|
||||
if isSensitiveJSONKey(key) {
|
||||
if current, ok := child.(string); ok && current == RedactedValue {
|
||||
continue
|
||||
}
|
||||
typed[key] = RedactedValue
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
|
||||
redactedChild, childChanged := redactJSONValue(child)
|
||||
if childChanged {
|
||||
typed[key] = redactedChild
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return typed, changed
|
||||
case []any:
|
||||
changed := false
|
||||
for i, child := range typed {
|
||||
redactedChild, childChanged := redactJSONValue(child)
|
||||
if childChanged {
|
||||
typed[i] = redactedChild
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return typed, changed
|
||||
default:
|
||||
return value, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
)
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Nil(t, chatdebug.RedactHeaders(nil))
|
||||
})
|
||||
|
||||
t.Run("empty header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
redacted := chatdebug.RedactHeaders(http.Header{})
|
||||
require.NotNil(t, redacted)
|
||||
require.Empty(t, redacted)
|
||||
})
|
||||
|
||||
t.Run("authorization redacted and others preserved", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Authorization": {"Bearer secret-token"},
|
||||
"Accept": {"application/json"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
|
||||
require.Equal(t, "application/json", redacted["Accept"])
|
||||
})
|
||||
|
||||
t.Run("multi-value headers are flattened", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Accept": {"application/json", "text/plain"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, "application/json, text/plain", redacted["Accept"])
|
||||
})
|
||||
|
||||
t.Run("header name matching is case insensitive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lowerAuthorization := "authorization"
|
||||
upperAuthorization := "AUTHORIZATION"
|
||||
headers := http.Header{
|
||||
lowerAuthorization: {"lower"},
|
||||
upperAuthorization: {"upper"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted[lowerAuthorization])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted[upperAuthorization])
|
||||
})
|
||||
|
||||
t.Run("token and secret substrings are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
traceHeader := "X-Trace-ID"
|
||||
headers := http.Header{
|
||||
"X-Auth-Token": {"abc"},
|
||||
"X-Custom-Secret": {"def"},
|
||||
"X-Bearer": {"ghi"},
|
||||
traceHeader: {"trace"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Auth-Token"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Bearer"])
|
||||
require.Equal(t, "trace", redacted[traceHeader])
|
||||
})
|
||||
|
||||
t.Run("known safe rate limit headers containing token are not redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Anthropic-Ratelimit-Tokens-Limit": {"1000000"},
|
||||
"Anthropic-Ratelimit-Tokens-Remaining": {"999000"},
|
||||
"Anthropic-Ratelimit-Tokens-Reset": {"2026-03-31T08:55:26Z"},
|
||||
"X-RateLimit-Limit-Tokens": {"120000"},
|
||||
"X-RateLimit-Remaining-Tokens": {"119500"},
|
||||
"X-RateLimit-Reset-Tokens": {"12ms"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, "1000000", redacted["Anthropic-Ratelimit-Tokens-Limit"])
|
||||
require.Equal(t, "999000", redacted["Anthropic-Ratelimit-Tokens-Remaining"])
|
||||
require.Equal(t, "2026-03-31T08:55:26Z", redacted["Anthropic-Ratelimit-Tokens-Reset"])
|
||||
require.Equal(t, "120000", redacted["X-RateLimit-Limit-Tokens"])
|
||||
require.Equal(t, "119500", redacted["X-RateLimit-Remaining-Tokens"])
|
||||
require.Equal(t, "12ms", redacted["X-RateLimit-Reset-Tokens"])
|
||||
})
|
||||
|
||||
t.Run("non-standard headers with api-key pattern are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"X-Custom-Api-Key": {"secret-key"},
|
||||
"X-Custom-Secret": {"secret-val"},
|
||||
"X-Custom-Session-Token": {"session-id"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Api-Key"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Session-Token"])
|
||||
})
|
||||
|
||||
t.Run("rate limit headers with token in name are preserved", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Rate-limit headers containing "token" should NOT be redacted
|
||||
// because they carry usage/limit counts, not credentials.
|
||||
headers := http.Header{
|
||||
"X-Ratelimit-Limit-Tokens": {"1000000"},
|
||||
"X-Ratelimit-Remaining-Tokens": {"999000"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, "1000000", redacted["X-Ratelimit-Limit-Tokens"])
|
||||
require.Equal(t, "999000", redacted["X-Ratelimit-Remaining-Tokens"])
|
||||
})
|
||||
|
||||
t.Run("original header is not modified", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Authorization": {"Bearer keep-me"},
|
||||
"X-Test": {"value"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
redacted["X-Test"] = "changed"
|
||||
|
||||
require.Equal(t, []string{"Bearer keep-me"}, headers["Authorization"])
|
||||
require.Equal(t, []string{"value"}, headers["X-Test"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
|
||||
})
|
||||
t.Run("api-key header variants are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"X-Goog-Api-Key": {"secret"},
|
||||
"X-Api_Key": {"other-secret"},
|
||||
"X-Safe": {"ok"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Goog-Api-Key"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Api_Key"])
|
||||
require.Equal(t, "ok", redacted["X-Safe"])
|
||||
})
|
||||
|
||||
t.Run("plain token headers are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Headers like "Token" or "X-Token" should be redacted
|
||||
// even without auth/session/access qualifiers.
|
||||
headers := http.Header{
|
||||
"Token": {"my-secret-token"},
|
||||
"X-Token": {"another-secret"},
|
||||
"X-Safe": {"ok"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["Token"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Token"])
|
||||
require.Equal(t, "ok", redacted["X-Safe"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedactJSONSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("redacts top level secret fields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"api_key":"abc","token":"def","password":"ghi","safe":"ok"}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"api_key":"[REDACTED]","token":"[REDACTED]","password":"[REDACTED]","safe":"ok"}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("redacts security_token exact key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"security_token":"s3cret","securityToken":"tok","safe":"ok"}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"security_token":"[REDACTED]","securityToken":"[REDACTED]","safe":"ok"}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("preserves LLM token usage fields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"input_tokens":100,"output_tokens":50,"prompt_tokens":80,"completion_tokens":20,"reasoning_tokens":10,"cache_creation_input_tokens":5,"cache_read_input_tokens":3,"total_tokens":150,"max_tokens":4096,"max_output_tokens":2048}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
// All usage/limit fields should be preserved, not redacted.
|
||||
require.Equal(t, input, redacted)
|
||||
})
|
||||
|
||||
t.Run("redacts nested objects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"outer":{"nested_secret":"abc","safe":1},"keep":true}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"outer":{"nested_secret":"[REDACTED]","safe":1},"keep":true}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("redacts arrays of objects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`[{"token":"abc"},{"value":1,"credentials":{"access_key":"def"}}]`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `[{"token":"[REDACTED]"},{"value":1,"credentials":"[REDACTED]"}]`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("concatenated JSON is replaced with diagnostic", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"token":"abc"}{"safe":"ok"}`)
|
||||
result := chatdebug.RedactJSONSecrets(input)
|
||||
require.Contains(t, string(result), "extra JSON values")
|
||||
})
|
||||
|
||||
t.Run("non JSON input is replaced with diagnostic", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte("not json")
|
||||
result := chatdebug.RedactJSONSecrets(input)
|
||||
require.Contains(t, string(result), "not valid JSON")
|
||||
})
|
||||
|
||||
t.Run("empty input is unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte{}
|
||||
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
|
||||
})
|
||||
|
||||
t.Run("JSON without sensitive keys is unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"safe":"ok","nested":{"value":1}}`)
|
||||
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
|
||||
})
|
||||
|
||||
t.Run("key matching is case insensitive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"API_KEY":"abc","Token":"def","PASSWORD":"ghi"}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"API_KEY":"[REDACTED]","Token":"[REDACTED]","PASSWORD":"[REDACTED]"}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("camelCase token field names are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Providers may use camelCase (e.g. accessToken, refreshToken).
|
||||
// These should be redacted even though they don't match the
|
||||
// snake_case originals exactly.
|
||||
input := []byte(`{"accessToken":"abc","refreshToken":"def","authToken":"ghi","input_tokens":100,"output_tokens":50}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"accessToken":"[REDACTED]","refreshToken":"[REDACTED]","authToken":"[REDACTED]","input_tokens":100,"output_tokens":50}`, string(redacted))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestBeginStepReuseStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("reuses handle under ReuseStep", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
svc := NewService(nil, testutil.Logger(t), nil)
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
ctx = ReuseStep(ctx)
|
||||
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
|
||||
|
||||
firstHandle, firstEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
secondHandle, secondEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
|
||||
require.NotNil(t, firstHandle)
|
||||
require.Same(t, firstHandle, secondHandle)
|
||||
require.Same(t, firstHandle.stepCtx, secondHandle.stepCtx)
|
||||
require.Same(t, firstHandle.sink, secondHandle.sink)
|
||||
require.Equal(t, runID, firstHandle.stepCtx.RunID)
|
||||
require.Equal(t, chatID, firstHandle.stepCtx.ChatID)
|
||||
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
|
||||
require.Equal(t, OperationStream, firstHandle.stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, firstHandle.stepCtx.StepID)
|
||||
|
||||
firstStepCtx, ok := StepFromContext(firstEnriched)
|
||||
require.True(t, ok)
|
||||
secondStepCtx, ok := StepFromContext(secondEnriched)
|
||||
require.True(t, ok)
|
||||
require.Same(t, firstStepCtx, secondStepCtx)
|
||||
require.Same(t, firstHandle.stepCtx, firstStepCtx)
|
||||
require.Same(t, attemptSinkFromContext(firstEnriched), attemptSinkFromContext(secondEnriched))
|
||||
})
|
||||
|
||||
t.Run("creates new handles without ReuseStep", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
svc := NewService(nil, testutil.Logger(t), nil)
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
|
||||
|
||||
firstHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
secondHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
|
||||
require.NotNil(t, firstHandle)
|
||||
require.NotNil(t, secondHandle)
|
||||
require.NotSame(t, firstHandle, secondHandle)
|
||||
require.NotSame(t, firstHandle.sink, secondHandle.sink)
|
||||
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
|
||||
require.Equal(t, int32(2), secondHandle.stepCtx.StepNumber)
|
||||
require.NotEqual(t, firstHandle.stepCtx.StepID, secondHandle.stepCtx.StepID)
|
||||
})
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -15,43 +16,76 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
// This branch-02 compatibility shim forward-declares recorder, service,
|
||||
// and summary symbols that land in later stacked branches. Delete this
|
||||
// file once recorder.go, service.go, and summary.go are available here.
|
||||
|
||||
// RecorderOptions identifies the chat/model context for debug recording.
|
||||
type RecorderOptions struct {
|
||||
ChatID uuid.UUID
|
||||
OwnerID uuid.UUID
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
// This compatibility shim forward-declares service and summary symbols
|
||||
// that land in later stacked branches. Delete this file once service.go
|
||||
// and summary.go are available here.
|
||||
|
||||
// Service is a placeholder for the later chat debug persistence service.
|
||||
type Service struct{}
|
||||
|
||||
// NewService constructs the branch-02 placeholder chat debug service.
|
||||
func NewService(_ database.Store, _ slog.Logger, _ pubsub.Pubsub) *Service {
|
||||
return &Service{}
|
||||
type Service struct {
|
||||
log slog.Logger
|
||||
}
|
||||
|
||||
type attemptSink struct{}
|
||||
|
||||
type attemptSinkKey struct{}
|
||||
|
||||
func withAttemptSink(ctx context.Context, sink *attemptSink) context.Context {
|
||||
if sink == nil {
|
||||
panic("chatdebug: nil attemptSink")
|
||||
}
|
||||
return context.WithValue(ctx, attemptSinkKey{}, sink)
|
||||
// CreateStepParams identifies the data recorded when a debug step starts.
|
||||
type CreateStepParams struct {
|
||||
RunID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
StepNumber int32
|
||||
Operation Operation
|
||||
Status Status
|
||||
HistoryTipMessageID int64
|
||||
NormalizedRequest any
|
||||
}
|
||||
|
||||
func attemptSinkFromContext(ctx context.Context) *attemptSink {
|
||||
sink, _ := ctx.Value(attemptSinkKey{}).(*attemptSink)
|
||||
return sink
|
||||
// UpdateStepParams identifies the data recorded when a debug step finishes.
|
||||
type UpdateStepParams struct {
|
||||
ID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
Status Status
|
||||
NormalizedResponse any
|
||||
Usage any
|
||||
Attempts []Attempt
|
||||
Error any
|
||||
Metadata any
|
||||
FinishedAt time.Time
|
||||
}
|
||||
|
||||
var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32
|
||||
// NewService constructs the placeholder chat debug service.
|
||||
func NewService(_ database.Store, log slog.Logger, _ pubsub.Pubsub) *Service {
|
||||
return &Service{log: log}
|
||||
}
|
||||
|
||||
// IsEnabled reports whether debug recording is enabled for a chat owner.
|
||||
func (*Service) IsEnabled(context.Context, uuid.UUID, uuid.UUID) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateStep synthesizes a debug step so recorder tests can exercise the
|
||||
// wrapper without requiring the later persistence service implementation.
|
||||
func (*Service) CreateStep(
|
||||
_ context.Context,
|
||||
params CreateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
return database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: params.RunID,
|
||||
ChatID: params.ChatID,
|
||||
StepNumber: params.StepNumber,
|
||||
Operation: string(params.Operation),
|
||||
Status: string(params.Status),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateStep accepts final step state once recording completes.
|
||||
func (*Service) UpdateStep(
|
||||
_ context.Context,
|
||||
params UpdateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
return database.ChatDebugStep{
|
||||
ID: params.ID,
|
||||
ChatID: params.ChatID,
|
||||
Status: string(params.Status),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runRefCounts tracks how many live RunContext instances reference each
|
||||
// RunID. Cleanup of shared state (step counters) is deferred until the
|
||||
@@ -78,125 +112,6 @@ func releaseRunRef(runID uuid.UUID) {
|
||||
}
|
||||
}
|
||||
|
||||
func nextStepNumber(runID uuid.UUID) int32 {
|
||||
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter, ok := val.(*atomic.Int32)
|
||||
if !ok {
|
||||
panic("chatdebug: invalid step counter type")
|
||||
}
|
||||
return counter.Add(1)
|
||||
}
|
||||
|
||||
// CleanupStepCounter removes per-run step counter and reference count
|
||||
// state. This is used by tests and later stacked branches that have a
|
||||
// real run lifecycle.
|
||||
func CleanupStepCounter(runID uuid.UUID) {
|
||||
stepCounters.Delete(runID)
|
||||
runRefCounts.Delete(runID)
|
||||
}
|
||||
|
||||
type stepHandle struct {
|
||||
stepCtx *StepContext
|
||||
sink *attemptSink
|
||||
mu sync.Mutex
|
||||
status Status
|
||||
response any
|
||||
usage any
|
||||
err any
|
||||
metadata any
|
||||
}
|
||||
|
||||
func beginStep(
|
||||
ctx context.Context,
|
||||
svc *Service,
|
||||
opts RecorderOptions,
|
||||
op Operation,
|
||||
_ any,
|
||||
) (*stepHandle, context.Context) {
|
||||
if svc == nil {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
rc, ok := RunFromContext(ctx)
|
||||
if !ok || rc.RunID == uuid.Nil {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
if holder, reuseStep := reuseHolderFromContext(ctx); reuseStep {
|
||||
holder.mu.Lock()
|
||||
defer holder.mu.Unlock()
|
||||
// Only reuse the cached handle if it belongs to the same run.
|
||||
// A different RunContext means a new logical run, so we must
|
||||
// create a fresh step to avoid cross-run attribution.
|
||||
if holder.handle != nil && holder.handle.stepCtx.RunID == rc.RunID {
|
||||
enriched := ContextWithStep(ctx, holder.handle.stepCtx)
|
||||
enriched = withAttemptSink(enriched, holder.handle.sink)
|
||||
return holder.handle, enriched
|
||||
}
|
||||
|
||||
handle, enriched := newStepHandle(ctx, rc, opts, op)
|
||||
holder.handle = handle
|
||||
return handle, enriched
|
||||
}
|
||||
|
||||
return newStepHandle(ctx, rc, opts, op)
|
||||
}
|
||||
|
||||
func newStepHandle(
|
||||
ctx context.Context,
|
||||
rc *RunContext,
|
||||
opts RecorderOptions,
|
||||
op Operation,
|
||||
) (*stepHandle, context.Context) {
|
||||
if rc == nil || rc.RunID == uuid.Nil {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
chatID := opts.ChatID
|
||||
if chatID == uuid.Nil {
|
||||
chatID = rc.ChatID
|
||||
}
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{
|
||||
StepID: uuid.New(),
|
||||
RunID: rc.RunID,
|
||||
ChatID: chatID,
|
||||
StepNumber: nextStepNumber(rc.RunID),
|
||||
Operation: op,
|
||||
HistoryTipMessageID: rc.HistoryTipMessageID,
|
||||
},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
enriched := ContextWithStep(ctx, handle.stepCtx)
|
||||
enriched = withAttemptSink(enriched, handle.sink)
|
||||
return handle, enriched
|
||||
}
|
||||
|
||||
func (h *stepHandle) finish(
|
||||
_ context.Context,
|
||||
status Status,
|
||||
response any,
|
||||
usage any,
|
||||
err any,
|
||||
metadata any,
|
||||
) {
|
||||
if h == nil || h.stepCtx == nil {
|
||||
return
|
||||
}
|
||||
// Guard with a mutex so concurrent callers (e.g. retried stream
|
||||
// wrappers sharing a reused handle) don't race. Unlike sync.Once,
|
||||
// later retries are allowed to overwrite earlier failure results so
|
||||
// the step reflects the final outcome.
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.status = status
|
||||
h.response = response
|
||||
h.usage = usage
|
||||
h.err = err
|
||||
h.metadata = metadata
|
||||
}
|
||||
|
||||
// whitespaceRun matches one or more consecutive whitespace characters.
|
||||
var whitespaceRun = regexp.MustCompile(`\s+`)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -19,15 +18,6 @@ func TestBeginStep_SkipsNilRunID(t *testing.T) {
|
||||
require.Equal(t, ctx, enriched)
|
||||
}
|
||||
|
||||
func TestNewStepHandle_SkipsNilRunID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
handle, enriched := newStepHandle(ctx, &RunContext{ChatID: uuid.New()}, RecorderOptions{ChatID: uuid.New()}, OperationGenerate)
|
||||
require.Nil(t, handle)
|
||||
require.Equal(t, ctx, enriched)
|
||||
}
|
||||
|
||||
func TestTruncateLabel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -50,7 +40,6 @@ func TestTruncateLabel(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := TruncateLabel(tc.input, tc.maxLen)
|
||||
@@ -66,25 +55,3 @@ func maxInt(a, b int) int {
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// RedactedValue replaces sensitive values in debug payloads.
|
||||
const RedactedValue = "[REDACTED]"
|
||||
|
||||
// RecordingTransport is the branch-02 placeholder HTTP recording transport.
|
||||
type RecordingTransport struct {
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = (*RecordingTransport)(nil)
|
||||
|
||||
func (t *RecordingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
panic("chatdebug: nil request")
|
||||
}
|
||||
|
||||
base := t.Base
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return base.RoundTrip(req)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,382 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// attemptStatusCompleted is the status recorded when a response body
|
||||
// is fully read without transport-level errors.
|
||||
const attemptStatusCompleted = "completed"
|
||||
|
||||
// attemptStatusFailed is the status recorded when a transport error
|
||||
// or body read error occurs.
|
||||
const attemptStatusFailed = "failed"
|
||||
|
||||
// maxRecordedRequestBodyBytes caps in-memory request capture when GetBody
|
||||
// is available.
|
||||
const maxRecordedRequestBodyBytes = 50_000
|
||||
|
||||
// maxRecordedResponseBodyBytes caps in-memory response capture.
|
||||
const maxRecordedResponseBodyBytes = 50_000
|
||||
|
||||
// RecordingTransport captures HTTP request/response data for debug steps.
|
||||
// When the request context carries an attemptSink, it records each round
|
||||
// trip. Otherwise it delegates directly.
|
||||
type RecordingTransport struct {
|
||||
// Base is the underlying transport. nil defaults to http.DefaultTransport.
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = (*RecordingTransport)(nil)
|
||||
|
||||
func (t *RecordingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
panic("chatdebug: nil request")
|
||||
}
|
||||
|
||||
base := t.Base
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
|
||||
sink := attemptSinkFromContext(req.Context())
|
||||
if sink == nil {
|
||||
return base.RoundTrip(req)
|
||||
}
|
||||
|
||||
requestHeaders := RedactHeaders(req.Header)
|
||||
|
||||
// Capture method and URL/path from the request.
|
||||
method := req.Method
|
||||
reqURL := ""
|
||||
reqPath := ""
|
||||
if req.URL != nil {
|
||||
reqURL = redactURL(req.URL)
|
||||
reqPath = req.URL.Path
|
||||
}
|
||||
|
||||
requestBody, err := captureRequestBody(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attemptNumber := sink.nextAttemptNumber()
|
||||
|
||||
startedAt := time.Now()
|
||||
resp, err := base.RoundTrip(req)
|
||||
finishedAt := time.Now()
|
||||
durationMs := finishedAt.Sub(startedAt).Milliseconds()
|
||||
if err != nil {
|
||||
sink.record(Attempt{
|
||||
Number: attemptNumber,
|
||||
Status: attemptStatusFailed,
|
||||
Method: method,
|
||||
URL: reqURL,
|
||||
Path: reqPath,
|
||||
StartedAt: startedAt.UTC().Format(time.RFC3339Nano),
|
||||
FinishedAt: finishedAt.UTC().Format(time.RFC3339Nano),
|
||||
RequestHeaders: requestHeaders,
|
||||
RequestBody: requestBody,
|
||||
Error: err.Error(),
|
||||
DurationMs: durationMs,
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
respHeaders := RedactHeaders(resp.Header)
|
||||
resp.Body = &recordingBody{
|
||||
inner: resp.Body,
|
||||
sink: sink,
|
||||
startedAt: startedAt,
|
||||
contentLength: resp.ContentLength,
|
||||
base: Attempt{
|
||||
Number: attemptNumber,
|
||||
Method: method,
|
||||
URL: reqURL,
|
||||
Path: reqPath,
|
||||
RequestHeaders: requestHeaders,
|
||||
RequestBody: requestBody,
|
||||
ResponseStatus: resp.StatusCode,
|
||||
ResponseHeaders: respHeaders,
|
||||
DurationMs: durationMs,
|
||||
},
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func redactURL(u *url.URL) string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
clone := *u
|
||||
clone.User = nil
|
||||
q := clone.Query()
|
||||
for key, values := range q {
|
||||
if isSensitiveName(key) || isSensitiveJSONKey(key) {
|
||||
for i := range values {
|
||||
values[i] = RedactedValue
|
||||
}
|
||||
q[key] = values
|
||||
}
|
||||
}
|
||||
clone.RawQuery = q.Encode()
|
||||
return clone.String()
|
||||
}
|
||||
|
||||
func captureRequestBody(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if req.GetBody != nil {
|
||||
clone, err := req.GetBody()
|
||||
if err == nil {
|
||||
defer clone.Close()
|
||||
limited, err := io.ReadAll(io.LimitReader(clone, maxRecordedRequestBodyBytes+1))
|
||||
if err == nil {
|
||||
if len(limited) > maxRecordedRequestBodyBytes {
|
||||
return []byte("[TRUNCATED]"), nil
|
||||
}
|
||||
return RedactJSONSecrets(limited), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Without GetBody we cannot safely capture the request body without
|
||||
// fully consuming a potentially large or streaming body before the
|
||||
// request is sent. Skip capture in that case to keep debug logging
|
||||
// lightweight and non-invasive.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type recordingBody struct {
|
||||
inner io.ReadCloser
|
||||
contentLength int64
|
||||
sink *attemptSink
|
||||
base Attempt
|
||||
startedAt time.Time
|
||||
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
truncated bool
|
||||
sawEOF bool
|
||||
bytesRead int64
|
||||
|
||||
recordOnce sync.Once
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (r *recordingBody) Read(p []byte) (int, error) {
|
||||
n, err := r.inner.Read(p)
|
||||
|
||||
r.mu.Lock()
|
||||
r.bytesRead += int64(n)
|
||||
if n > 0 && !r.truncated {
|
||||
remaining := maxRecordedResponseBodyBytes - r.buf.Len()
|
||||
if remaining > 0 {
|
||||
toWrite := n
|
||||
if toWrite > remaining {
|
||||
toWrite = remaining
|
||||
r.truncated = true
|
||||
}
|
||||
_, _ = r.buf.Write(p[:toWrite])
|
||||
} else {
|
||||
r.truncated = true
|
||||
}
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
r.sawEOF = true
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
r.record(err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *recordingBody) Close() error {
|
||||
r.mu.Lock()
|
||||
sawEOF := r.sawEOF
|
||||
bytesRead := r.bytesRead
|
||||
contentLength := r.contentLength
|
||||
truncated := r.truncated
|
||||
responseBody := append([]byte(nil), r.buf.Bytes()...)
|
||||
r.mu.Unlock()
|
||||
|
||||
contentType := r.base.ResponseHeaders["Content-Type"]
|
||||
shouldDrainUnknownLengthJSON := contentLength < 0 &&
|
||||
!sawEOF &&
|
||||
bytesRead > 0 &&
|
||||
!truncated &&
|
||||
isCompleteUnknownLengthJSONBody(contentType, responseBody)
|
||||
|
||||
// Always close the inner reader first so that stalled chunked
|
||||
// bodies cannot block drainToEOF indefinitely. Once inner is
|
||||
// closed, reads return immediately with an error or EOF.
|
||||
var closeErr error
|
||||
r.closeOnce.Do(func() {
|
||||
closeErr = r.inner.Close()
|
||||
})
|
||||
if closeErr != nil {
|
||||
r.record(closeErr)
|
||||
return closeErr
|
||||
}
|
||||
|
||||
// Drain remaining bytes that may already be buffered inside the
|
||||
// HTTP transport after close. Because inner is closed, this
|
||||
// finishes immediately rather than blocking on the network.
|
||||
if shouldDrainUnknownLengthJSON {
|
||||
// Best-effort drain; ignore errors since inner is closed.
|
||||
_ = r.drainToEOF()
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
sawEOF = r.sawEOF
|
||||
bytesRead = r.bytesRead
|
||||
contentLength = r.contentLength
|
||||
truncated = r.truncated
|
||||
responseBody = append([]byte(nil), r.buf.Bytes()...)
|
||||
r.mu.Unlock()
|
||||
|
||||
switch {
|
||||
// Only check JSON completeness when the recording buffer is
|
||||
// not truncated. A truncated buffer is an incomplete prefix
|
||||
// of the body, so the completeness check would false-positive.
|
||||
case sawEOF && !truncated && contentLength < 0 && isJSONLikeContentType(contentType) && !isCompleteUnknownLengthJSONBody(contentType, responseBody):
|
||||
r.record(io.ErrUnexpectedEOF)
|
||||
case sawEOF:
|
||||
r.record(io.EOF)
|
||||
case responseHasNoBody(r.base.Method, r.base.ResponseStatus):
|
||||
r.record(nil)
|
||||
case contentLength >= 0 && bytesRead >= contentLength:
|
||||
r.record(nil)
|
||||
case contentLength < 0 && !truncated && isCompleteUnknownLengthJSONBody(contentType, responseBody):
|
||||
r.record(nil)
|
||||
default:
|
||||
r.record(io.ErrUnexpectedEOF)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseHasNoBody(method string, statusCode int) bool {
|
||||
if method == http.MethodHead {
|
||||
return true
|
||||
}
|
||||
return statusCode == http.StatusNoContent ||
|
||||
statusCode == http.StatusNotModified ||
|
||||
(statusCode >= 100 && statusCode < 200)
|
||||
}
|
||||
|
||||
func isJSONLikeContentType(contentType string) bool {
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
mediaType = strings.TrimSpace(strings.Split(contentType, ";")[0])
|
||||
}
|
||||
return mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")
|
||||
}
|
||||
|
||||
// maxDrainBytes caps how many trailing bytes drainToEOF will consume.
|
||||
// This prevents Close() from blocking indefinitely on a misbehaving
|
||||
// or extremely large chunked body.
|
||||
const maxDrainBytes = 64 * 1024 // 64 KB
|
||||
|
||||
func (r *recordingBody) drainToEOF() error {
|
||||
buf := make([]byte, 4*1024)
|
||||
var drained int64
|
||||
for {
|
||||
n, err := r.inner.Read(buf)
|
||||
|
||||
r.mu.Lock()
|
||||
r.bytesRead += int64(n)
|
||||
drained += int64(n)
|
||||
if n > 0 && !r.truncated {
|
||||
remaining := maxRecordedResponseBodyBytes - r.buf.Len()
|
||||
if remaining > 0 {
|
||||
toWrite := n
|
||||
if toWrite > remaining {
|
||||
toWrite = remaining
|
||||
r.truncated = true
|
||||
}
|
||||
_, _ = r.buf.Write(buf[:toWrite])
|
||||
} else {
|
||||
r.truncated = true
|
||||
}
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
r.sawEOF = true
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Safety valve: stop draining after maxDrainBytes to prevent
|
||||
// Close() from blocking indefinitely on a chunked body.
|
||||
if drained >= maxDrainBytes {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isCompleteUnknownLengthJSONBody(contentType string, body []byte) bool {
|
||||
if !isJSONLikeContentType(contentType) {
|
||||
return false
|
||||
}
|
||||
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader(trimmed))
|
||||
var value any
|
||||
if err := decoder.Decode(&value); err != nil {
|
||||
return false
|
||||
}
|
||||
var extra any
|
||||
return errors.Is(decoder.Decode(&extra), io.EOF)
|
||||
}
|
||||
|
||||
func (r *recordingBody) record(err error) {
|
||||
r.recordOnce.Do(func() {
|
||||
finishedAt := time.Now()
|
||||
|
||||
r.mu.Lock()
|
||||
truncated := r.truncated
|
||||
responseBody := append([]byte(nil), r.buf.Bytes()...)
|
||||
base := r.base
|
||||
startedAt := r.startedAt
|
||||
r.mu.Unlock()
|
||||
|
||||
if truncated {
|
||||
base.ResponseBody = []byte("[TRUNCATED]")
|
||||
} else {
|
||||
base.ResponseBody = RedactJSONSecrets(responseBody)
|
||||
}
|
||||
base.StartedAt = startedAt.UTC().Format(time.RFC3339Nano)
|
||||
base.FinishedAt = finishedAt.UTC().Format(time.RFC3339Nano)
|
||||
// Recompute duration to include body read time.
|
||||
base.DurationMs = finishedAt.Sub(startedAt).Milliseconds()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
base.Error = err.Error()
|
||||
base.Status = attemptStatusFailed
|
||||
} else {
|
||||
base.Status = attemptStatusCompleted
|
||||
}
|
||||
r.sink.record(base)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,737 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func newTestSinkContext(t *testing.T) (context.Context, *attemptSink) {
|
||||
t.Helper()
|
||||
|
||||
sink := &attemptSink{}
|
||||
return withAttemptSink(context.Background(), sink), sink
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
type scriptedReadCloser struct {
|
||||
chunks [][]byte
|
||||
index int
|
||||
offset int // byte offset within current chunk
|
||||
}
|
||||
|
||||
func (r *scriptedReadCloser) Read(p []byte) (int, error) {
|
||||
if r.index >= len(r.chunks) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
chunk := r.chunks[r.index]
|
||||
remaining := chunk[r.offset:]
|
||||
n := copy(p, remaining)
|
||||
r.offset += n
|
||||
if r.offset >= len(chunk) {
|
||||
r.index++
|
||||
r.offset = 0
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (*scriptedReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRecordingTransport_NoSink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotMethod := make(chan string, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
gotMethod <- req.Method
|
||||
_, _ = rw.Write([]byte("ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "ok", string(body))
|
||||
require.Equal(t, http.MethodGet, <-gotMethod)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CaptureRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const requestBody = `{"message":"hello","api_key":"super-secret"}`
|
||||
|
||||
type receivedRequest struct {
|
||||
authorization string
|
||||
body []byte
|
||||
}
|
||||
gotRequest := make(chan receivedRequest, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
gotRequest <- receivedRequest{
|
||||
authorization: req.Header.Get("Authorization"),
|
||||
body: body,
|
||||
}
|
||||
_, _ = rw.Write([]byte(`{"ok":true}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
server.URL,
|
||||
strings.NewReader(requestBody),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer top-secret")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, 1, attempts[0].Number)
|
||||
require.Equal(t, RedactedValue, attempts[0].RequestHeaders["Authorization"])
|
||||
require.Equal(t, "application/json", attempts[0].RequestHeaders["Content-Type"])
|
||||
require.JSONEq(t, `{"message":"hello","api_key":"[REDACTED]"}`, string(attempts[0].RequestBody))
|
||||
|
||||
received := <-gotRequest
|
||||
require.JSONEq(t, requestBody, string(received.body))
|
||||
require.Equal(t, "Bearer top-secret", received.authorization)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_RedactsSensitiveQueryParameters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL+`?api_key=secret&safe=ok`, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Contains(t, attempts[0].URL, "api_key=%5BREDACTED%5D")
|
||||
require.Contains(t, attempts[0].URL, "safe=ok")
|
||||
}
|
||||
|
||||
func TestRecordingTransport_TruncatesLargeRequestBodies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = io.Copy(io.Discard, req.Body)
|
||||
_, _ = rw.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
large := strings.Repeat("x", maxRecordedRequestBodyBytes+1024)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(large))
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, []byte("[TRUNCATED]"), attempts[0].RequestBody)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_StripsURLUserinfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.Replace(server.URL, "http://", "http://user:secret@", 1)+`?api_key=secret`, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.NotContains(t, attempts[0].URL, "user:secret")
|
||||
require.Contains(t, attempts[0].URL, "api_key=%5BREDACTED%5D")
|
||||
}
|
||||
|
||||
func TestRecordingTransport_SkipsNonReplayableRequestBodyCapture(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const requestBody = `{"message":"hello"}`
|
||||
gotRequest := make(chan []byte, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
gotRequest <- body
|
||||
_, _ = rw.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, io.NopCloser(strings.NewReader(requestBody)))
|
||||
require.NoError(t, err)
|
||||
req.GetBody = nil
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
require.JSONEq(t, requestBody, string(<-gotRequest))
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Nil(t, attempts[0].RequestBody)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CaptureResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("X-API-Key", "response-secret")
|
||||
rw.Header().Set("X-Trace-ID", "trace-123")
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, http.StatusCreated, attempts[0].ResponseStatus)
|
||||
require.Equal(t, RedactedValue, attempts[0].ResponseHeaders["X-Api-Key"])
|
||||
require.Equal(t, "trace-123", attempts[0].ResponseHeaders["X-Trace-Id"])
|
||||
require.JSONEq(t, `{"token":"[REDACTED]","safe":"ok"}`, string(attempts[0].ResponseBody))
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CaptureResponseOnEOFWithoutClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Header().Set("X-API-Key", "response-secret")
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, http.StatusAccepted, attempts[0].ResponseStatus)
|
||||
require.Equal(t, "application/json", attempts[0].ResponseHeaders["Content-Type"])
|
||||
require.Equal(t, RedactedValue, attempts[0].ResponseHeaders["X-Api-Key"])
|
||||
require.JSONEq(t, `{"token":"[REDACTED]","safe":"ok"}`, string(attempts[0].ResponseBody))
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
func TestRecordingTransport_StreamingBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
flusher, ok := rw.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
_, _ = rw.Write([]byte(`{"safe":"stream",`))
|
||||
flusher.Flush()
|
||||
_, _ = rw.Write([]byte(`"token":"chunk-secret"}`))
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
var body strings.Builder
|
||||
for {
|
||||
n, readErr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
_, writeErr := body.Write(buf[:n])
|
||||
require.NoError(t, writeErr)
|
||||
}
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
require.NoError(t, readErr)
|
||||
}
|
||||
require.NoError(t, resp.Body.Close())
|
||||
require.JSONEq(t, `{"safe":"stream","token":"chunk-secret"}`, body.String())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.JSONEq(t, `{"safe":"stream","token":"[REDACTED]"}`, string(attempts[0].ResponseBody))
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderConsumesContentLengthSucceeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Empty(t, attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderConsumesUnknownLengthJSONSucceeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Empty(t, attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderConsumesUnknownLengthJSONWithTrailingDocumentMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte("{\"token\":\"response-secret\",\"safe\":\"ok\"}{\"token\":\"second\"}")}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderConsumesUnknownLengthNDJSONMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/x-ndjson"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte("{\"token\":\"response-secret\",\"safe\":\"ok\"}\n{\"token\":\"second\"}\n")}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderDrainsUnknownLengthSucceeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
_, err = io.Copy(io.Discard, resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Empty(t, attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseWithoutReadingHeadResponseSucceeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises no-body close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"ignored":true}`)}},
|
||||
ContentLength: 13,
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Empty(t, attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseWithoutReadingUnknownLengthMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_PrematureCloseUnknownLengthMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
_, err = resp.Body.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_PrematureCloseMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
_, err = resp.Body.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_TruncatesLargeResponses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte(strings.Repeat("x", maxRecordedResponseBodyBytes+1024)))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, []byte("[TRUNCATED]"), attempts[0].ResponseBody)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_TransportError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, xerrors.New("transport exploded")
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
"http://example.invalid",
|
||||
strings.NewReader(`{"password":"secret","safe":"ok"}`),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer top-secret")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
require.Nil(t, resp)
|
||||
require.EqualError(t, err, "Post \"http://example.invalid\": transport exploded")
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, 1, attempts[0].Number)
|
||||
require.Equal(t, RedactedValue, attempts[0].RequestHeaders["Authorization"])
|
||||
require.JSONEq(t, `{"password":"[REDACTED]","safe":"ok"}`, string(attempts[0].RequestBody))
|
||||
require.Zero(t, attempts[0].ResponseStatus)
|
||||
require.Equal(t, "transport exploded", attempts[0].Error)
|
||||
require.GreaterOrEqual(t, attempts[0].DurationMs, int64(0))
|
||||
}
|
||||
|
||||
func TestRecordingTransport_NilBase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte("ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &http.Client{Transport: &RecordingTransport{}}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ok", string(body))
|
||||
}
|
||||
Reference in New Issue
Block a user