feat(coderd): add chat debug service and summary aggregation
Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
@@ -24328,12 +24328,10 @@ func (q *sqlQuerier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UU
|
||||
|
||||
const getUserChatDebugLoggingEnabled = `-- name: GetUserChatDebugLoggingEnabled :one
|
||||
SELECT
|
||||
COALESCE((
|
||||
SELECT value = 'true'
|
||||
FROM user_configs
|
||||
WHERE user_id = $1
|
||||
AND key = 'chat_debug_logging_enabled'
|
||||
), false) :: boolean AS debug_logging_enabled
|
||||
value = 'true' AS debug_logging_enabled
|
||||
FROM user_configs
|
||||
WHERE user_id = $1
|
||||
AND key = 'chat_debug_logging_enabled'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
|
||||
@@ -215,12 +215,10 @@ DELETE FROM user_configs WHERE user_id = @user_id AND key = @key;
|
||||
|
||||
-- name: GetUserChatDebugLoggingEnabled :one
|
||||
SELECT
|
||||
COALESCE((
|
||||
SELECT value = 'true'
|
||||
FROM user_configs
|
||||
WHERE user_id = @user_id
|
||||
AND key = 'chat_debug_logging_enabled'
|
||||
), false) :: boolean AS debug_logging_enabled;
|
||||
value = 'true' AS debug_logging_enabled
|
||||
FROM user_configs
|
||||
WHERE user_id = @user_id
|
||||
AND key = 'chat_debug_logging_enabled';
|
||||
|
||||
-- name: UpsertUserChatDebugLoggingEnabled :exec
|
||||
INSERT INTO user_configs (user_id, key, value)
|
||||
|
||||
@@ -32,9 +32,8 @@ func TestContextWithRun_CleansUpStepCounterAfterGC(t *testing.T) {
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
func() {
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
handle, _ := beginStep(ctx, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
|
||||
require.NotNil(t, handle)
|
||||
_ = ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
require.Equal(t, int32(1), nextStepNumber(runID))
|
||||
_, ok := stepCounters.Load(runID)
|
||||
require.True(t, ok)
|
||||
}()
|
||||
@@ -96,11 +95,9 @@ func TestContextWithRun_CleansUpStepCounterOnGCAfterCancel(t *testing.T) {
|
||||
// context cancellation, allowing GC to trigger the cleanup.
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx = ContextWithRun(ctx, &RunContext{RunID: runID, ChatID: chatID})
|
||||
ContextWithRun(ctx, &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
handle, _ := beginStep(ctx, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
|
||||
require.NotNil(t, handle)
|
||||
require.Equal(t, int32(1), handle.stepCtx.StepNumber)
|
||||
require.Equal(t, int32(1), nextStepNumber(runID))
|
||||
|
||||
_, ok := stepCounters.Load(runID)
|
||||
require.True(t, ok)
|
||||
@@ -117,8 +114,5 @@ func TestContextWithRun_CleansUpStepCounterOnGCAfterCancel(t *testing.T) {
|
||||
return !ok
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
freshCtx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
freshHandle, _ := beginStep(freshCtx, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
|
||||
require.NotNil(t, freshHandle)
|
||||
require.Equal(t, int32(1), freshHandle.stepCtx.StepNumber)
|
||||
require.Equal(t, int32(1), nextStepNumber(runID))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -23,6 +25,126 @@ type testError struct{ message string }
|
||||
|
||||
func (e *testError) Error() string { return e.message }
|
||||
|
||||
func expectDebugLoggingEnabled(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
ownerID uuid.UUID,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
db.EXPECT().GetChatDebugLoggingEnabled(gomock.Any()).Return(true, nil)
|
||||
db.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), ownerID).Return(true, nil)
|
||||
}
|
||||
|
||||
func expectCreateStepNumberWithRequestValidity(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
runID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
stepNumber int32,
|
||||
op Operation,
|
||||
normalizedRequestValid bool,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
stepID := uuid.New()
|
||||
db.EXPECT().
|
||||
InsertChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatDebugStepParams{})).
|
||||
DoAndReturn(func(_ context.Context, params database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
require.Equal(t, runID, params.RunID)
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
require.Equal(t, stepNumber, params.StepNumber)
|
||||
require.Equal(t, string(op), params.Operation)
|
||||
require.Equal(t, string(StatusInProgress), params.Status)
|
||||
require.Equal(t, normalizedRequestValid, params.NormalizedRequest.Valid)
|
||||
|
||||
return database.ChatDebugStep{
|
||||
ID: stepID,
|
||||
RunID: runID,
|
||||
ChatID: chatID,
|
||||
StepNumber: params.StepNumber,
|
||||
Operation: params.Operation,
|
||||
Status: params.Status,
|
||||
}, nil
|
||||
})
|
||||
|
||||
// CreateStep now touches the parent run's updated_at to prevent
|
||||
// premature stale finalization.
|
||||
db.EXPECT().
|
||||
UpdateChatDebugRun(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{})).
|
||||
DoAndReturn(func(_ context.Context, params database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
require.Equal(t, runID, params.ID)
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
return database.ChatDebugRun{ID: runID, ChatID: chatID}, nil
|
||||
})
|
||||
|
||||
return stepID
|
||||
}
|
||||
|
||||
func expectCreateStepNumber(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
runID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
stepNumber int32,
|
||||
op Operation,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
return expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
stepNumber,
|
||||
op,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
func expectCreateStep(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
runID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
op Operation,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
return expectCreateStepNumber(t, db, runID, chatID, 1, op)
|
||||
}
|
||||
|
||||
func expectUpdateStep(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
stepID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
status Status,
|
||||
assertFn func(database.UpdateChatDebugStepParams),
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
db.EXPECT().
|
||||
UpdateChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatDebugStepParams{})).
|
||||
DoAndReturn(func(_ context.Context, params database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
require.Equal(t, stepID, params.ID)
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
require.True(t, params.Status.Valid)
|
||||
require.Equal(t, string(status), params.Status.String)
|
||||
require.True(t, params.FinishedAt.Valid)
|
||||
|
||||
if assertFn != nil {
|
||||
assertFn(params)
|
||||
}
|
||||
|
||||
return database.ChatDebugStep{
|
||||
ID: stepID,
|
||||
ChatID: chatID,
|
||||
Status: params.Status.String,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestDebugModel_Provider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -48,6 +170,7 @@ func TestDebugModel_Disabled(t *testing.T) {
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
respWant := &fantasy.Response{FinishReason: fantasy.FinishReasonStop}
|
||||
inner := &chattest.FakeModel{
|
||||
@@ -97,6 +220,30 @@ func TestDebugModel_Generate(t *testing.T) {
|
||||
Warnings: []fantasy.CallWarning{{Message: "warning"}},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
// Clean successes (no prior error) leave the error column
|
||||
// as SQL NULL rather than sending jsonClear.
|
||||
require.False(t, params.Error.Valid)
|
||||
require.False(t, params.Metadata.Valid)
|
||||
|
||||
// Verify actual JSON content so a broken tag or field
|
||||
// rename is caught rather than only checking .Valid.
|
||||
var usage fantasy.Usage
|
||||
require.NoError(t, json.Unmarshal(params.Usage.RawMessage, &usage))
|
||||
require.EqualValues(t, 10, usage.InputTokens)
|
||||
require.EqualValues(t, 4, usage.OutputTokens)
|
||||
require.EqualValues(t, 14, usage.TotalTokens)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(params.NormalizedResponse.RawMessage, &resp))
|
||||
require.Equal(t, "stop", resp["finish_reason"])
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateFn: func(ctx context.Context, got fantasy.Call) (*fantasy.Response, error) {
|
||||
@@ -149,6 +296,20 @@ func TestDebugModel_GeneratePersistsAttemptsWithoutResponseClose(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
|
||||
var attempts []Attempt
|
||||
require.NoError(t, json.Unmarshal(params.Attempts.RawMessage, &attempts))
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Equal(t, http.StatusCreated, attempts[0].ResponseStatus)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
@@ -197,6 +358,16 @@ func TestDebugModel_GenerateError(t *testing.T) {
|
||||
runID := uuid.New()
|
||||
wantErr := &testError{message: "boom"}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.False(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.False(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
@@ -252,6 +423,16 @@ func TestDebugModel_Stream(t *testing.T) {
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 8, OutputTokens: 3, TotalTokens: 11}},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
@@ -299,6 +480,18 @@ func TestDebugModel_StreamObject(t *testing.T) {
|
||||
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7}},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
// Clean successes (no prior error) leave the error column
|
||||
// as SQL NULL rather than sending jsonClear.
|
||||
require.False(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
@@ -340,12 +533,20 @@ func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
|
||||
}
|
||||
|
||||
// The mock expectation for UpdateStep with StatusCompleted is the
|
||||
// assertion: if the wrapper chose StatusInterrupted instead, the
|
||||
// mock would reject the call.
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, nil)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
@@ -354,7 +555,7 @@ func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
@@ -362,34 +563,14 @@ func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer reads the finish part then breaks — this should still be
|
||||
// considered a completed stream, not interrupted.
|
||||
var handle *stepHandle
|
||||
// Consumer reads the finish part then breaks — this should still
|
||||
// be considered a completed stream, not interrupted.
|
||||
for part := range seq {
|
||||
if part.Type == fantasy.StreamPartTypeFinish {
|
||||
break
|
||||
}
|
||||
}
|
||||
// The step handle is on the model's last beginStep call; verify
|
||||
// status via the internal handle state by calling beginStep directly.
|
||||
// Since the model wrapper already finalized the handle, just verify
|
||||
// we consumed something. The real assertion is that the finalize
|
||||
// path chose StatusCompleted (tested via handle.status below).
|
||||
_ = handle // handle is not directly accessible, but we can verify via a fresh step
|
||||
|
||||
// Verify by running a second stream where we inspect the handle.
|
||||
runID2 := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID2) })
|
||||
ctx2 := ContextWithRun(context.Background(), &RunContext{RunID: runID2, ChatID: chatID})
|
||||
|
||||
h, _ := beginStep(ctx2, svc, RecorderOptions{ChatID: chatID}, OperationStream, nil)
|
||||
require.NotNil(t, h)
|
||||
// The handle starts with zero status; simulate what the wrapper does
|
||||
// when consumer breaks after finish.
|
||||
h.finish(ctx2, StatusCompleted, nil, nil, nil, nil)
|
||||
// finish uses sync.Once, so after it returns the status is safely
|
||||
// set and readable in this single-goroutine test.
|
||||
require.Equal(t, StatusCompleted, h.status)
|
||||
// gomock verifies UpdateStep was called with StatusCompleted.
|
||||
}
|
||||
|
||||
// TestDebugModel_StreamInterruptedBeforeFinish verifies that when a consumer
|
||||
@@ -401,6 +582,7 @@ func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
@@ -408,8 +590,13 @@ func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) {
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}
|
||||
|
||||
// The mock expectation for UpdateStep with StatusInterrupted is the
|
||||
// assertion: breaking before the finish part means interrupted.
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, nil)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
var capturedHandle *stepHandle
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
@@ -417,7 +604,7 @@ func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) {
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
@@ -434,7 +621,7 @@ func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) {
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, count)
|
||||
_ = capturedHandle
|
||||
// gomock verifies UpdateStep was called with StatusInterrupted.
|
||||
}
|
||||
|
||||
func TestDebugModel_StreamRejectsNilSequence(t *testing.T) {
|
||||
@@ -443,7 +630,19 @@ func TestDebugModel_StreamRejectsNilSequence(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.False(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.False(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
@@ -453,9 +652,10 @@ func TestDebugModel_StreamRejectsNilSequence(t *testing.T) {
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
@@ -469,7 +669,19 @@ func TestDebugModel_StreamObjectRejectsNilSequence(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.False(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
@@ -479,9 +691,10 @@ func TestDebugModel_StreamObjectRejectsNilSequence(t *testing.T) {
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
|
||||
@@ -502,6 +715,16 @@ func TestDebugModel_StreamEarlyStop(t *testing.T) {
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "second"},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.False(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
|
||||
@@ -81,6 +81,31 @@ func attemptSinkFromContext(ctx context.Context) *attemptSink {
|
||||
|
||||
var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32
|
||||
|
||||
// runRefCounts tracks how many live RunContext instances reference each
|
||||
// RunID. Cleanup of shared state (step counters) is deferred until the
|
||||
// last RunContext for a given RunID is garbage collected.
|
||||
var runRefCounts sync.Map // map[uuid.UUID]*atomic.Int32
|
||||
|
||||
func trackRunRef(runID uuid.UUID) {
|
||||
val, _ := runRefCounts.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter := val.(*atomic.Int32)
|
||||
counter.Add(1)
|
||||
}
|
||||
|
||||
// releaseRunRef decrements the reference count for runID and cleans up
|
||||
// shared state when the last reference is released.
|
||||
func releaseRunRef(runID uuid.UUID) {
|
||||
val, ok := runRefCounts.Load(runID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
counter := val.(*atomic.Int32)
|
||||
if counter.Add(-1) <= 0 {
|
||||
runRefCounts.Delete(runID)
|
||||
stepCounters.Delete(runID)
|
||||
}
|
||||
}
|
||||
|
||||
func nextStepNumber(runID uuid.UUID) int32 {
|
||||
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter, ok := val.(*atomic.Int32)
|
||||
@@ -129,13 +154,16 @@ type stepHandle struct {
|
||||
sink *attemptSink
|
||||
svc *Service
|
||||
opts RecorderOptions
|
||||
once sync.Once
|
||||
mu sync.Mutex
|
||||
status Status
|
||||
response any
|
||||
usage any
|
||||
err any
|
||||
metadata any
|
||||
// hadError tracks whether a prior finalization wrote an error
|
||||
// payload. Used to decide whether a successful retry needs to
|
||||
// explicitly clear the error field via jsonClear.
|
||||
hadError bool
|
||||
}
|
||||
|
||||
// beginStep validates preconditions, creates a debug step, and returns a
|
||||
@@ -223,11 +251,11 @@ func beginStep(
|
||||
return handle, enriched
|
||||
}
|
||||
|
||||
// finish updates the debug step with final status and data.
|
||||
// sync.Once prevents data races when concurrent callers (e.g.
|
||||
// retried stream wrappers sharing a reuse handle) both attempt
|
||||
// to finalize the same step. Only the first finish call takes
|
||||
// effect.
|
||||
// finish updates the debug step with final status and data. A mutex
|
||||
// guards the write so concurrent callers (e.g. retried stream wrappers
|
||||
// sharing a reuse handle) don't race. Unlike sync.Once, later retries
|
||||
// are allowed to overwrite earlier failure results so the step reflects
|
||||
// the final outcome.
|
||||
func (h *stepHandle) finish(
|
||||
ctx context.Context,
|
||||
status Status,
|
||||
@@ -240,38 +268,52 @@ func (h *stepHandle) finish(
|
||||
return
|
||||
}
|
||||
|
||||
h.once.Do(func() {
|
||||
h.mu.Lock()
|
||||
h.status = status
|
||||
h.response = response
|
||||
h.usage = usage
|
||||
h.err = errPayload
|
||||
h.metadata = metadata
|
||||
h.mu.Unlock()
|
||||
if h.svc == nil {
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
updateCtx, cancel := stepFinalizeContext(ctx)
|
||||
defer cancel()
|
||||
h.status = status
|
||||
h.response = response
|
||||
h.usage = usage
|
||||
h.err = errPayload
|
||||
h.metadata = metadata
|
||||
if errPayload != nil {
|
||||
h.hadError = true
|
||||
}
|
||||
if h.svc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, updateErr := h.svc.UpdateStep(updateCtx, UpdateStepParams{
|
||||
ID: h.stepCtx.StepID,
|
||||
ChatID: h.stepCtx.ChatID,
|
||||
Status: status,
|
||||
NormalizedResponse: response,
|
||||
Usage: usage,
|
||||
Attempts: h.sink.snapshot(),
|
||||
Error: errPayload,
|
||||
Metadata: metadata,
|
||||
FinishedAt: time.Now(),
|
||||
}); updateErr != nil {
|
||||
h.svc.log.Warn(updateCtx, "failed to finalize chat debug step",
|
||||
slog.Error(updateErr),
|
||||
slog.F("step_id", h.stepCtx.StepID),
|
||||
slog.F("chat_id", h.stepCtx.ChatID),
|
||||
slog.F("status", status),
|
||||
)
|
||||
}
|
||||
})
|
||||
updateCtx, cancel := stepFinalizeContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
// When the step completes successfully after a prior failed
|
||||
// attempt, the error field must be explicitly cleared. A plain
|
||||
// nil would leave the COALESCE-based SQL untouched, so we send
|
||||
// jsonClear{} which serializes as a valid JSONB null. Only do
|
||||
// this when a prior error was actually recorded — otherwise
|
||||
// clean successes would get a spurious JSONB null that downstream
|
||||
// aggregation could misread as an error.
|
||||
errValue := errPayload
|
||||
if errValue == nil && status == StatusCompleted && h.hadError {
|
||||
errValue = jsonClear{}
|
||||
}
|
||||
|
||||
if _, updateErr := h.svc.UpdateStep(updateCtx, UpdateStepParams{
|
||||
ID: h.stepCtx.StepID,
|
||||
ChatID: h.stepCtx.ChatID,
|
||||
Status: status,
|
||||
NormalizedResponse: response,
|
||||
Usage: usage,
|
||||
Attempts: h.sink.snapshot(),
|
||||
Error: errValue,
|
||||
Metadata: metadata,
|
||||
FinishedAt: time.Now(),
|
||||
}); updateErr != nil {
|
||||
h.svc.log.Warn(updateCtx, "failed to finalize chat debug step",
|
||||
slog.Error(updateErr),
|
||||
slog.F("step_id", h.stepCtx.StepID),
|
||||
slog.F("chat_id", h.stepCtx.ChatID),
|
||||
slog.F("status", status),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,11 @@ import (
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestAttemptSink_ThreadSafe(t *testing.T) {
|
||||
@@ -147,11 +150,18 @@ func TestBeginStep_NilService(t *testing.T) {
|
||||
func TestBeginStep_FallsBackToRunChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
runID := uuid.New()
|
||||
runChatID := uuid.New()
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID})
|
||||
ownerID := uuid.New()
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(t, db, runID, runChatID, 1, OperationGenerate, false)
|
||||
|
||||
handle, enriched := beginStep(ctx, &Service{}, RecorderOptions{}, OperationGenerate, nil)
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID})
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
|
||||
handle, enriched := beginStep(ctx, svc, RecorderOptions{OwnerID: ownerID}, OperationGenerate, nil)
|
||||
require.NotNil(t, handle)
|
||||
require.Equal(t, runChatID, handle.stepCtx.ChatID)
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@@ -21,7 +23,21 @@ func TestBeginStepReuseStep(t *testing.T) {
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
svc := NewService(nil, testutil.Logger(t), nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
1,
|
||||
OperationStream,
|
||||
false,
|
||||
)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
ctx = ReuseStep(ctx)
|
||||
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
|
||||
@@ -56,7 +72,30 @@ func TestBeginStepReuseStep(t *testing.T) {
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
svc := NewService(nil, testutil.Logger(t), nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
1,
|
||||
OperationStream,
|
||||
false,
|
||||
)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
2,
|
||||
OperationStream,
|
||||
false,
|
||||
)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
|
||||
|
||||
|
||||
@@ -0,0 +1,539 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
// DefaultStaleThreshold is the fallback stale timeout for debug rows
|
||||
// when no caller-provided value is supplied.
|
||||
const DefaultStaleThreshold = 5 * time.Minute
|
||||
|
||||
// Service persists chat debug rows and fans out lightweight change events.
|
||||
type Service struct {
|
||||
db database.Store
|
||||
log slog.Logger
|
||||
pubsub pubsub.Pubsub
|
||||
alwaysEnable bool
|
||||
// staleAfterNanos stores the stale threshold as nanoseconds in an
|
||||
// atomic.Int64 so SetStaleAfter and FinalizeStale can be called
|
||||
// from concurrent goroutines without a data race.
|
||||
staleAfterNanos atomic.Int64
|
||||
}
|
||||
|
||||
// ServiceOption configures optional Service behavior.
|
||||
type ServiceOption func(*Service)
|
||||
|
||||
// WithStaleThreshold overrides the default stale-row finalization
|
||||
// threshold. Callers that already have a configurable in-flight chat
|
||||
// timeout (e.g. chatd's InFlightChatStaleAfter) should pass it here
|
||||
// so the two sweeps stay in sync.
|
||||
func WithStaleThreshold(d time.Duration) ServiceOption {
|
||||
return func(s *Service) {
|
||||
if d > 0 {
|
||||
s.staleAfterNanos.Store(d.Nanoseconds())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithAlwaysEnable forces debug logging on for every chat regardless
|
||||
// of the runtime admin and user opt-in settings. This is used for the
|
||||
// deployment-level serpent flag.
|
||||
func WithAlwaysEnable(always bool) ServiceOption {
|
||||
return func(s *Service) {
|
||||
s.alwaysEnable = always
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRunParams contains friendly inputs for creating a debug run.
|
||||
type CreateRunParams struct {
|
||||
ChatID uuid.UUID
|
||||
RootChatID uuid.UUID
|
||||
ParentChatID uuid.UUID
|
||||
ModelConfigID uuid.UUID
|
||||
TriggerMessageID int64
|
||||
HistoryTipMessageID int64
|
||||
Kind RunKind
|
||||
Status Status
|
||||
Provider string
|
||||
Model string
|
||||
Summary any
|
||||
}
|
||||
|
||||
// UpdateRunParams contains inputs for updating a debug run.
|
||||
// Zero-valued fields are treated as "keep the existing value" by the
|
||||
// COALESCE-based SQL query. Once a field is set it cannot be cleared
|
||||
// back to NULL — this is intentional for the write-once-finalize
|
||||
// lifecycle of debug rows.
|
||||
type UpdateRunParams struct {
|
||||
ID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
Status Status
|
||||
Summary any
|
||||
FinishedAt time.Time
|
||||
}
|
||||
|
||||
// CreateStepParams contains friendly inputs for creating a debug step.
|
||||
type CreateStepParams struct {
|
||||
RunID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
StepNumber int32
|
||||
Operation Operation
|
||||
Status Status
|
||||
HistoryTipMessageID int64
|
||||
NormalizedRequest any
|
||||
}
|
||||
|
||||
// UpdateStepParams contains optional inputs for updating a debug step.
|
||||
// Most payload fields are typed as any and serialized through nullJSON
|
||||
// because their shape varies by provider. The Attempts field uses a
|
||||
// concrete slice for compile-time safety where the schema is stable.
|
||||
// Zero-valued fields are treated as "keep the existing value" by the
|
||||
// COALESCE-based SQL query — once set, fields cannot be cleared back
|
||||
// to NULL. This is intentional for the write-once-finalize lifecycle
|
||||
// of debug rows.
|
||||
type UpdateStepParams struct {
|
||||
ID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
Status Status
|
||||
AssistantMessageID int64
|
||||
NormalizedResponse any
|
||||
Usage any
|
||||
Attempts []Attempt
|
||||
Error any
|
||||
Metadata any
|
||||
FinishedAt time.Time
|
||||
}
|
||||
|
||||
// NewService constructs a chat debug persistence service.
|
||||
func NewService(db database.Store, log slog.Logger, ps pubsub.Pubsub, opts ...ServiceOption) *Service {
|
||||
if db == nil {
|
||||
panic("chatdebug: nil database.Store")
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
db: db,
|
||||
log: log,
|
||||
pubsub: ps,
|
||||
}
|
||||
s.staleAfterNanos.Store(DefaultStaleThreshold.Nanoseconds())
|
||||
for _, opt := range opts {
|
||||
opt(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// SetStaleAfter overrides the in-flight stale threshold used when
|
||||
// finalizing abandoned debug rows. Zero or negative durations keep the
|
||||
// default threshold.
|
||||
func (s *Service) SetStaleAfter(staleAfter time.Duration) {
|
||||
if s == nil || staleAfter <= 0 {
|
||||
return
|
||||
}
|
||||
s.staleAfterNanos.Store(staleAfter.Nanoseconds())
|
||||
}
|
||||
|
||||
func chatdContext(ctx context.Context) context.Context {
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
|
||||
// chat debug persistence reads and writes.
|
||||
return dbauthz.AsChatd(ctx)
|
||||
}
|
||||
|
||||
// IsEnabled returns whether debug logging is enabled for the given chat.
|
||||
func (s *Service) IsEnabled(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
ownerID uuid.UUID,
|
||||
) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.alwaysEnable {
|
||||
return true
|
||||
}
|
||||
if s.db == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
authCtx := chatdContext(ctx)
|
||||
|
||||
allowUsers, err := s.db.GetChatDebugLoggingEnabled(authCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false
|
||||
}
|
||||
s.log.Warn(ctx, "failed to load runtime admin chat debug logging setting",
|
||||
slog.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
if !allowUsers {
|
||||
return false
|
||||
}
|
||||
|
||||
if ownerID == uuid.Nil {
|
||||
s.log.Warn(ctx, "missing chat owner for debug logging enablement check",
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
enabled, err := s.db.GetUserChatDebugLoggingEnabled(authCtx, ownerID)
|
||||
if err == nil {
|
||||
return enabled
|
||||
}
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false
|
||||
}
|
||||
|
||||
s.log.Warn(ctx, "failed to load user chat debug logging setting",
|
||||
slog.Error(err),
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("owner_id", ownerID),
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateRun inserts a new debug run and emits a run update event.
|
||||
func (s *Service) CreateRun(
|
||||
ctx context.Context,
|
||||
params CreateRunParams,
|
||||
) (database.ChatDebugRun, error) {
|
||||
run, err := s.db.InsertChatDebugRun(chatdContext(ctx),
|
||||
database.InsertChatDebugRunParams{
|
||||
ChatID: params.ChatID,
|
||||
RootChatID: nullUUID(params.RootChatID),
|
||||
ParentChatID: nullUUID(params.ParentChatID),
|
||||
ModelConfigID: nullUUID(params.ModelConfigID),
|
||||
TriggerMessageID: nullInt64(params.TriggerMessageID),
|
||||
HistoryTipMessageID: nullInt64(params.HistoryTipMessageID),
|
||||
Kind: string(params.Kind),
|
||||
Status: string(params.Status),
|
||||
Provider: nullString(params.Provider),
|
||||
Model: nullString(params.Model),
|
||||
Summary: s.nullJSON(ctx, params.Summary),
|
||||
StartedAt: sql.NullTime{},
|
||||
UpdatedAt: sql.NullTime{},
|
||||
FinishedAt: sql.NullTime{},
|
||||
})
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil)
|
||||
return run, nil
|
||||
}
|
||||
|
||||
// UpdateRun updates an existing debug run and emits a run update event.
|
||||
func (s *Service) UpdateRun(
|
||||
ctx context.Context,
|
||||
params UpdateRunParams,
|
||||
) (database.ChatDebugRun, error) {
|
||||
run, err := s.db.UpdateChatDebugRun(chatdContext(ctx),
|
||||
database.UpdateChatDebugRunParams{
|
||||
RootChatID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
ModelConfigID: uuid.NullUUID{},
|
||||
TriggerMessageID: sql.NullInt64{},
|
||||
HistoryTipMessageID: sql.NullInt64{},
|
||||
Status: nullString(string(params.Status)),
|
||||
Provider: sql.NullString{},
|
||||
Model: sql.NullString{},
|
||||
Summary: s.nullJSON(ctx, params.Summary),
|
||||
FinishedAt: nullTime(params.FinishedAt),
|
||||
ID: params.ID,
|
||||
ChatID: params.ChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil)
|
||||
return run, nil
|
||||
}
|
||||
|
||||
// CreateStep inserts a new debug step and emits a step update event.
|
||||
func (s *Service) CreateStep(
|
||||
ctx context.Context,
|
||||
params CreateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
insert := database.InsertChatDebugStepParams{
|
||||
RunID: params.RunID,
|
||||
StepNumber: params.StepNumber,
|
||||
Operation: string(params.Operation),
|
||||
Status: string(params.Status),
|
||||
HistoryTipMessageID: nullInt64(params.HistoryTipMessageID),
|
||||
AssistantMessageID: sql.NullInt64{},
|
||||
NormalizedRequest: s.nullJSON(ctx, params.NormalizedRequest),
|
||||
NormalizedResponse: pqtype.NullRawMessage{},
|
||||
Usage: pqtype.NullRawMessage{},
|
||||
Attempts: pqtype.NullRawMessage{},
|
||||
Error: pqtype.NullRawMessage{},
|
||||
Metadata: pqtype.NullRawMessage{},
|
||||
StartedAt: sql.NullTime{},
|
||||
UpdatedAt: sql.NullTime{},
|
||||
FinishedAt: sql.NullTime{},
|
||||
ChatID: params.ChatID,
|
||||
}
|
||||
|
||||
// Cap retry attempts to prevent infinite loops under
|
||||
// pathological concurrency. Each iteration performs two DB
|
||||
// round-trips (insert + list), so 10 retries is generous.
|
||||
const maxCreateStepRetries = 10
|
||||
|
||||
for attempt := 0; attempt < maxCreateStepRetries; attempt++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
|
||||
step, err := s.db.InsertChatDebugStep(chatdContext(ctx), insert)
|
||||
if err == nil {
|
||||
// Touch the parent run's updated_at so the stale-
|
||||
// finalization sweep does not prematurely interrupt
|
||||
// long-running runs that are still producing steps.
|
||||
if _, touchErr := s.db.UpdateChatDebugRun(chatdContext(ctx), database.UpdateChatDebugRunParams{
|
||||
RootChatID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
ModelConfigID: uuid.NullUUID{},
|
||||
TriggerMessageID: sql.NullInt64{},
|
||||
HistoryTipMessageID: sql.NullInt64{},
|
||||
Status: sql.NullString{},
|
||||
Provider: sql.NullString{},
|
||||
Model: sql.NullString{},
|
||||
Summary: pqtype.NullRawMessage{},
|
||||
FinishedAt: sql.NullTime{},
|
||||
ID: params.RunID,
|
||||
ChatID: params.ChatID,
|
||||
}); touchErr != nil {
|
||||
s.log.Warn(ctx, "failed to touch parent run updated_at",
|
||||
slog.F("run_id", params.RunID),
|
||||
slog.Error(touchErr),
|
||||
)
|
||||
}
|
||||
s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID)
|
||||
return step, nil
|
||||
}
|
||||
if !database.IsUniqueViolation(err, database.UniqueIndexChatDebugStepsRunStep) {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
|
||||
steps, listErr := s.db.GetChatDebugStepsByRunID(chatdContext(ctx), params.RunID)
|
||||
if listErr != nil {
|
||||
return database.ChatDebugStep{}, listErr
|
||||
}
|
||||
nextStepNumber := insert.StepNumber + 1
|
||||
for _, existing := range steps {
|
||||
if existing.StepNumber >= nextStepNumber {
|
||||
nextStepNumber = existing.StepNumber + 1
|
||||
}
|
||||
}
|
||||
insert.StepNumber = nextStepNumber
|
||||
}
|
||||
|
||||
return database.ChatDebugStep{}, xerrors.Errorf(
|
||||
"failed to create debug step after %d attempts (run_id=%s)",
|
||||
maxCreateStepRetries, params.RunID,
|
||||
)
|
||||
}
|
||||
|
||||
// UpdateStep updates an existing debug step and emits a step update event.
|
||||
func (s *Service) UpdateStep(
|
||||
ctx context.Context,
|
||||
params UpdateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
step, err := s.db.UpdateChatDebugStep(chatdContext(ctx),
|
||||
database.UpdateChatDebugStepParams{
|
||||
Status: nullString(string(params.Status)),
|
||||
HistoryTipMessageID: sql.NullInt64{},
|
||||
AssistantMessageID: nullInt64(params.AssistantMessageID),
|
||||
NormalizedRequest: pqtype.NullRawMessage{},
|
||||
NormalizedResponse: s.nullJSON(ctx, params.NormalizedResponse),
|
||||
Usage: s.nullJSON(ctx, params.Usage),
|
||||
Attempts: s.nullJSON(ctx, params.Attempts),
|
||||
Error: s.nullJSON(ctx, params.Error),
|
||||
Metadata: s.nullJSON(ctx, params.Metadata),
|
||||
FinishedAt: nullTime(params.FinishedAt),
|
||||
ID: params.ID,
|
||||
ChatID: params.ChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID)
|
||||
return step, nil
|
||||
}
|
||||
|
||||
// DeleteByChatID deletes all debug data for a chat and emits a delete event.
|
||||
func (s *Service) DeleteByChatID(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
) (int64, error) {
|
||||
deleted, err := s.db.DeleteChatDebugDataByChatID(chatdContext(ctx), chatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil)
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// DeleteAfterMessageID deletes debug data newer than the given message.
|
||||
func (s *Service) DeleteAfterMessageID(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
messageID int64,
|
||||
) (int64, error) {
|
||||
deleted, err := s.db.DeleteChatDebugDataAfterMessageID(
|
||||
chatdContext(ctx),
|
||||
database.DeleteChatDebugDataAfterMessageIDParams{
|
||||
ChatID: chatID,
|
||||
MessageID: messageID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil)
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// FinalizeStale finalizes stale in-flight debug rows and emits a broadcast.
|
||||
func (s *Service) FinalizeStale(
|
||||
ctx context.Context,
|
||||
) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
ns := s.staleAfterNanos.Load()
|
||||
staleAfter := time.Duration(ns)
|
||||
if staleAfter <= 0 {
|
||||
staleAfter = DefaultStaleThreshold
|
||||
}
|
||||
|
||||
result, err := s.db.FinalizeStaleChatDebugRows(
|
||||
chatdContext(ctx),
|
||||
time.Now().Add(-staleAfter),
|
||||
)
|
||||
if err != nil {
|
||||
return database.FinalizeStaleChatDebugRowsRow{}, err
|
||||
}
|
||||
|
||||
if result.RunsFinalized > 0 || result.StepsFinalized > 0 {
|
||||
s.publishEvent(ctx, uuid.Nil, EventKindFinalize, uuid.Nil, uuid.Nil)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func nullUUID(id uuid.UUID) uuid.NullUUID {
|
||||
return uuid.NullUUID{UUID: id, Valid: id != uuid.Nil}
|
||||
}
|
||||
|
||||
func nullInt64(v int64) sql.NullInt64 {
|
||||
return sql.NullInt64{Int64: v, Valid: v != 0}
|
||||
}
|
||||
|
||||
func nullString(value string) sql.NullString {
|
||||
return sql.NullString{String: value, Valid: value != ""}
|
||||
}
|
||||
|
||||
func nullTime(value time.Time) sql.NullTime {
|
||||
return sql.NullTime{Time: value, Valid: !value.IsZero()}
|
||||
}
|
||||
|
||||
// nullJSON marshals value to a NullRawMessage. When value is nil or
|
||||
// marshals to JSON "null", the result is {Valid: false}. Combined with
|
||||
// the COALESCE-based UPDATE queries, this means a caller cannot clear a
|
||||
// previously-set JSON column back to NULL — passing nil preserves the
|
||||
// existing value. This is acceptable for debug logs because fields
|
||||
// accumulate monotonically (request → response → usage → error) and
|
||||
// never need to be cleared during normal operation.
|
||||
// jsonClear is a sentinel value that tells nullJSON to emit a valid
|
||||
// JSON null (JSONB 'null') instead of SQL NULL. COALESCE treats SQL
|
||||
// NULL as "keep existing" but replaces with a non-NULL JSONB value,
|
||||
// so passing jsonClear explicitly overwrites a previously set field.
|
||||
type jsonClear struct{}
|
||||
|
||||
func (s *Service) nullJSON(ctx context.Context, value any) pqtype.NullRawMessage {
|
||||
if value == nil {
|
||||
return pqtype.NullRawMessage{}
|
||||
}
|
||||
// Sentinel: emit a valid JSONB null so COALESCE replaces
|
||||
// any previously stored value.
|
||||
if _, ok := value.(jsonClear); ok {
|
||||
return pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage("null"),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
s.log.Warn(ctx, "failed to marshal chat debug JSON",
|
||||
slog.Error(err),
|
||||
slog.F("value_type", fmt.Sprintf("%T", value)),
|
||||
)
|
||||
return pqtype.NullRawMessage{}
|
||||
}
|
||||
if bytes.Equal(data, []byte("null")) {
|
||||
return pqtype.NullRawMessage{}
|
||||
}
|
||||
|
||||
return pqtype.NullRawMessage{RawMessage: data, Valid: true}
|
||||
}
|
||||
|
||||
func (s *Service) publishEvent(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
kind EventKind,
|
||||
runID uuid.UUID,
|
||||
stepID uuid.UUID,
|
||||
) {
|
||||
if s.pubsub == nil {
|
||||
s.log.Debug(ctx,
|
||||
"chat debug pubsub unavailable; skipping event",
|
||||
slog.F("kind", kind),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
event := DebugEvent{
|
||||
Kind: kind,
|
||||
ChatID: chatID,
|
||||
RunID: runID,
|
||||
StepID: stepID,
|
||||
}
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.log.Warn(ctx, "failed to marshal chat debug event",
|
||||
slog.Error(err),
|
||||
slog.F("kind", kind),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
channel := PubsubChannel(chatID)
|
||||
if err := s.pubsub.Publish(channel, data); err != nil {
|
||||
s.log.Warn(ctx, "failed to publish chat debug event",
|
||||
slog.Error(err),
|
||||
slog.F("channel", channel),
|
||||
slog.F("kind", kind),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,853 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
type testFixture struct {
|
||||
ctx context.Context
|
||||
db database.Store
|
||||
svc *chatdebug.Service
|
||||
owner database.User
|
||||
chat database.Chat
|
||||
model database.ChatModelConfig
|
||||
}
|
||||
|
||||
func TestService_IsEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, _ := dbtestutil.NewDBWithSQLDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, model.ID)
|
||||
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
|
||||
// Default is off until an admin allows user opt-in.
|
||||
require.False(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
|
||||
err := db.UpsertChatDebugLoggingEnabled(ctx, true)
|
||||
require.NoError(t, err)
|
||||
// Allowing user opt-in is not enough on its own; the user must opt in.
|
||||
require.False(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
require.False(t, svc.IsEnabled(ctx, chat.ID, uuid.Nil))
|
||||
|
||||
err = db.UpsertUserChatDebugLoggingEnabled(ctx,
|
||||
database.UpsertUserChatDebugLoggingEnabledParams{
|
||||
UserID: owner.ID,
|
||||
DebugLoggingEnabled: true,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
|
||||
err = db.UpsertUserChatDebugLoggingEnabled(ctx,
|
||||
database.UpsertUserChatDebugLoggingEnabledParams{
|
||||
UserID: owner.ID,
|
||||
DebugLoggingEnabled: false,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.False(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
}
|
||||
|
||||
func TestService_IsEnabled_AlwaysEnable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, _ := dbtestutil.NewDBWithSQLDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, model.ID)
|
||||
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil, chatdebug.WithAlwaysEnable(true))
|
||||
require.True(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
require.True(t, svc.IsEnabled(ctx, chat.ID, uuid.Nil))
|
||||
}
|
||||
|
||||
func TestService_IsEnabled_ZeroValueService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var svc *chatdebug.Service
|
||||
require.False(t, svc.IsEnabled(context.Background(), uuid.Nil, uuid.Nil))
|
||||
|
||||
require.False(t, (&chatdebug.Service{}).IsEnabled(context.Background(), uuid.Nil, uuid.Nil))
|
||||
}
|
||||
|
||||
func TestService_CreateRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
rootChat := insertChat(fixture.ctx, t, fixture.db, fixture.owner.ID, fixture.model.ID)
|
||||
parentChat := insertChat(fixture.ctx, t, fixture.db, fixture.owner.ID, fixture.model.ID)
|
||||
triggerMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleUser, "trigger")
|
||||
historyTipMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant,
|
||||
"history-tip")
|
||||
|
||||
run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
RootChatID: rootChat.ID,
|
||||
ParentChatID: parentChat.ID,
|
||||
ModelConfigID: fixture.model.ID,
|
||||
TriggerMessageID: triggerMsg.ID,
|
||||
HistoryTipMessageID: historyTipMsg.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: fixture.model.Provider,
|
||||
Model: fixture.model.Model,
|
||||
Summary: map[string]any{
|
||||
"phase": "create",
|
||||
"count": 1,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assertRunMatches(t, run, fixture.chat.ID, rootChat.ID, parentChat.ID,
|
||||
fixture.model.ID, triggerMsg.ID, historyTipMsg.ID,
|
||||
chatdebug.KindChatTurn, chatdebug.StatusInProgress,
|
||||
fixture.model.Provider, fixture.model.Model,
|
||||
`{"count":1,"phase":"create"}`)
|
||||
|
||||
stored, err := fixture.db.GetChatDebugRunByID(fixture.ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, run.ID, stored.ID)
|
||||
require.JSONEq(t, string(run.Summary), string(stored.Summary))
|
||||
}
|
||||
|
||||
func TestService_CreateRun_TypedNilSummaryUsesDefaultObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
var summary map[string]any
|
||||
|
||||
run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Summary: summary,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{}`, string(run.Summary))
|
||||
}
|
||||
|
||||
func TestService_UpdateRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Summary: map[string]any{
|
||||
"before": true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
finishedAt := time.Now().UTC().Round(time.Microsecond)
|
||||
updated, err := fixture.svc.UpdateRun(fixture.ctx, chatdebug.UpdateRunParams{
|
||||
ID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
Summary: map[string]any{"after": "done"},
|
||||
FinishedAt: finishedAt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(chatdebug.StatusCompleted), updated.Status)
|
||||
require.True(t, updated.FinishedAt.Valid)
|
||||
require.WithinDuration(t, finishedAt, updated.FinishedAt.Time, time.Second)
|
||||
require.JSONEq(t, `{"after":"done"}`, string(updated.Summary))
|
||||
|
||||
stored, err := fixture.db.GetChatDebugRunByID(fixture.ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(chatdebug.StatusCompleted), stored.Status)
|
||||
require.JSONEq(t, `{"after":"done"}`, string(stored.Summary))
|
||||
require.True(t, stored.FinishedAt.Valid)
|
||||
}
|
||||
|
||||
func TestService_CreateStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
historyTipMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant,
|
||||
"history-tip")
|
||||
|
||||
step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
HistoryTipMessageID: historyTipMsg.ID,
|
||||
NormalizedRequest: map[string]any{
|
||||
"messages": []string{"hello"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fixture.chat.ID, step.ChatID)
|
||||
require.Equal(t, run.ID, step.RunID)
|
||||
require.EqualValues(t, 1, step.StepNumber)
|
||||
require.Equal(t, string(chatdebug.OperationStream), step.Operation)
|
||||
require.Equal(t, string(chatdebug.StatusInProgress), step.Status)
|
||||
require.True(t, step.HistoryTipMessageID.Valid)
|
||||
require.Equal(t, historyTipMsg.ID, step.HistoryTipMessageID.Int64)
|
||||
require.JSONEq(t, `{"messages":["hello"]}`, string(step.NormalizedRequest))
|
||||
|
||||
steps, err := fixture.db.GetChatDebugStepsByRunID(fixture.ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, steps, 1)
|
||||
require.Equal(t, step.ID, steps[0].ID)
|
||||
}
|
||||
|
||||
func TestService_CreateStep_RetriesDuplicateStepNumbers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
first, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
second, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, first.StepNumber)
|
||||
require.EqualValues(t, 2, second.StepNumber)
|
||||
}
|
||||
|
||||
func TestService_CreateStep_ListRetryErrorWins(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
runID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
listErr := xerrors.New("list chat debug steps")
|
||||
|
||||
db.EXPECT().InsertChatDebugStep(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.InsertChatDebugStepParams{}),
|
||||
).Return(database.ChatDebugStep{}, &pq.Error{
|
||||
Code: pq.ErrorCode("23505"),
|
||||
Constraint: string(database.UniqueIndexChatDebugStepsRunStep),
|
||||
})
|
||||
db.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), runID).Return(nil, listErr)
|
||||
|
||||
_, err := svc.CreateStep(context.Background(), chatdebug.CreateStepParams{
|
||||
RunID: runID,
|
||||
ChatID: chatID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.ErrorIs(t, err, listErr)
|
||||
}
|
||||
|
||||
func TestService_UpdateStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assistantMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant,
|
||||
"assistant")
|
||||
finishedAt := time.Now().UTC().Round(time.Microsecond)
|
||||
updated, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
AssistantMessageID: assistantMsg.ID,
|
||||
NormalizedResponse: map[string]any{"text": "done"},
|
||||
Usage: map[string]any{"input_tokens": 10, "output_tokens": 5},
|
||||
Attempts: []chatdebug.Attempt{{
|
||||
Number: 1,
|
||||
ResponseStatus: 200,
|
||||
DurationMs: 25,
|
||||
}},
|
||||
Metadata: map[string]any{"provider": fixture.model.Provider},
|
||||
FinishedAt: finishedAt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(chatdebug.StatusCompleted), updated.Status)
|
||||
require.True(t, updated.AssistantMessageID.Valid)
|
||||
require.Equal(t, assistantMsg.ID, updated.AssistantMessageID.Int64)
|
||||
require.True(t, updated.NormalizedResponse.Valid)
|
||||
require.JSONEq(t, `{"text":"done"}`,
|
||||
string(updated.NormalizedResponse.RawMessage))
|
||||
require.True(t, updated.Usage.Valid)
|
||||
require.JSONEq(t, `{"input_tokens":10,"output_tokens":5}`,
|
||||
string(updated.Usage.RawMessage))
|
||||
require.JSONEq(t,
|
||||
`[{"number":1,"response_status":200,"duration_ms":25}]`,
|
||||
string(updated.Attempts),
|
||||
)
|
||||
require.JSONEq(t, `{"provider":"`+fixture.model.Provider+`"}`,
|
||||
string(updated.Metadata))
|
||||
require.True(t, updated.FinishedAt.Valid)
|
||||
storedSteps, err := fixture.db.GetChatDebugStepsByRunID(fixture.ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, storedSteps, 1)
|
||||
require.Equal(t, updated.ID, storedSteps[0].ID)
|
||||
}
|
||||
|
||||
func TestService_UpdateStep_TypedNilAttemptsPreserveExistingValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
Attempts: []chatdebug.Attempt{{
|
||||
Number: 1,
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var typedNilAttempts []chatdebug.Attempt
|
||||
updated, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Attempts: typedNilAttempts,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var attempts []map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated.Attempts, &attempts))
|
||||
require.Len(t, attempts, 1)
|
||||
require.EqualValues(t, 1, attempts[0]["number"])
|
||||
}
|
||||
|
||||
func TestService_DeleteByChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
_, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := fixture.svc.DeleteByChatID(fixture.ctx, fixture.chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, deleted)
|
||||
|
||||
runs, err := fixture.db.GetChatDebugRunsByChatID(fixture.ctx, database.GetChatDebugRunsByChatIDParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
LimitVal: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, runs)
|
||||
}
|
||||
|
||||
func TestService_DeleteAfterMessageID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
low := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, fixture.owner.ID,
|
||||
fixture.model.ID, database.ChatMessageRoleAssistant, "low")
|
||||
threshold := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant,
|
||||
"threshold")
|
||||
high := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, fixture.owner.ID,
|
||||
fixture.model.ID, database.ChatMessageRoleAssistant, "high")
|
||||
require.Less(t, low.ID, threshold.ID)
|
||||
require.Less(t, threshold.ID, high.ID)
|
||||
|
||||
runKeep := createRun(t, fixture)
|
||||
stepKeep, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: runKeep.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: stepKeep.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
AssistantMessageID: low.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
runDelete := createRun(t, fixture)
|
||||
stepDelete, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: runDelete.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: stepDelete.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
AssistantMessageID: high.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := fixture.svc.DeleteAfterMessageID(fixture.ctx, fixture.chat.ID,
|
||||
threshold.ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, deleted)
|
||||
|
||||
runs, err := fixture.db.GetChatDebugRunsByChatID(fixture.ctx, database.GetChatDebugRunsByChatIDParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
LimitVal: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, runs, 1)
|
||||
require.Equal(t, runKeep.ID, runs[0].ID)
|
||||
|
||||
steps, err := fixture.db.GetChatDebugStepsByRunID(fixture.ctx, runKeep.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, steps, 1)
|
||||
require.Equal(t, stepKeep.ID, steps[0].ID)
|
||||
}
|
||||
|
||||
func TestService_FinalizeStale_UsesConfiguredThreshold(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
svc.SetStaleAfter(42 * time.Second)
|
||||
|
||||
db.EXPECT().FinalizeStaleChatDebugRows(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, staleBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
require.WithinDuration(t, time.Now().Add(-42*time.Second), staleBefore, 2*time.Second)
|
||||
return database.FinalizeStaleChatDebugRowsRow{}, nil
|
||||
},
|
||||
)
|
||||
|
||||
result, err := svc.FinalizeStale(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, result.RunsFinalized)
|
||||
require.Zero(t, result.StepsFinalized)
|
||||
}
|
||||
|
||||
func TestService_FinalizeStale(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, owner.ID)
|
||||
|
||||
staleTime := time.Now().Add(-10 * time.Minute).UTC().Round(time.Microsecond)
|
||||
run, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Kind: string(chatdebug.KindChatTurn),
|
||||
Status: string(chatdebug.StatusInProgress),
|
||||
StartedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
step, err := db.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: run.ID,
|
||||
StepNumber: 1,
|
||||
Operation: string(chatdebug.OperationStream),
|
||||
Status: string(chatdebug.StatusInProgress),
|
||||
StartedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
ChatID: chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
result, err := svc.FinalizeStale(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, result.RunsFinalized)
|
||||
require.EqualValues(t, 1, result.StepsFinalized)
|
||||
|
||||
storedRun, err := db.GetChatDebugRunByID(ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(chatdebug.StatusInterrupted), storedRun.Status)
|
||||
require.True(t, storedRun.FinishedAt.Valid)
|
||||
|
||||
storedSteps, err := db.GetChatDebugStepsByRunID(ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, storedSteps, 1)
|
||||
require.Equal(t, step.ID, storedSteps[0].ID)
|
||||
require.Equal(t, string(chatdebug.StatusInterrupted), storedSteps[0].Status)
|
||||
require.True(t, storedSteps[0].FinishedAt.Valid)
|
||||
}
|
||||
|
||||
func TestService_FinalizeStale_BroadcastsFinalizeEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, owner.ID)
|
||||
|
||||
staleTime := time.Now().Add(-10 * time.Minute).UTC().Round(time.Microsecond)
|
||||
run, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Kind: string(chatdebug.KindChatTurn),
|
||||
Status: string(chatdebug.StatusInProgress),
|
||||
StartedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: run.ID,
|
||||
StepNumber: 1,
|
||||
Operation: string(chatdebug.OperationStream),
|
||||
Status: string(chatdebug.StatusInProgress),
|
||||
StartedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
ChatID: chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoryPubsub := dbpubsub.NewInMemory()
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), memoryPubsub)
|
||||
type eventResult struct {
|
||||
event chatdebug.DebugEvent
|
||||
err error
|
||||
}
|
||||
events := make(chan eventResult, 1)
|
||||
cancel, err := memoryPubsub.Subscribe(chatdebug.PubsubChannel(uuid.Nil),
|
||||
func(_ context.Context, message []byte) {
|
||||
var event chatdebug.DebugEvent
|
||||
unmarshalErr := json.Unmarshal(message, &event)
|
||||
events <- eventResult{event: event, err: unmarshalErr}
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
result, err := svc.FinalizeStale(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, result.RunsFinalized)
|
||||
require.EqualValues(t, 1, result.StepsFinalized)
|
||||
|
||||
select {
|
||||
case received := <-events:
|
||||
require.NoError(t, received.err)
|
||||
require.Equal(t, chatdebug.EventKindFinalize, received.event.Kind)
|
||||
require.Equal(t, uuid.Nil, received.event.ChatID)
|
||||
require.Equal(t, uuid.Nil, received.event.RunID)
|
||||
require.Equal(t, uuid.Nil, received.event.StepID)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for finalize event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_FinalizeStale_NoChangesDoesNotBroadcast(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, _ := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, owner.ID)
|
||||
|
||||
memoryPubsub := dbpubsub.NewInMemory()
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), memoryPubsub)
|
||||
events := make(chan chatdebug.DebugEvent, 1)
|
||||
cancel, err := memoryPubsub.Subscribe(chatdebug.PubsubChannel(uuid.Nil),
|
||||
func(_ context.Context, message []byte) {
|
||||
var event chatdebug.DebugEvent
|
||||
if err := json.Unmarshal(message, &event); err == nil {
|
||||
events <- event
|
||||
}
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
result, err := svc.FinalizeStale(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, result.RunsFinalized)
|
||||
require.EqualValues(t, 0, result.StepsFinalized)
|
||||
|
||||
select {
|
||||
case event := <-events:
|
||||
t.Fatalf("unexpected finalize event: %+v", event)
|
||||
default:
|
||||
}
|
||||
|
||||
_ = chat // keep seeded chat usage explicit for test readability.
|
||||
}
|
||||
|
||||
func TestService_PublishesEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, owner.ID)
|
||||
|
||||
memoryPubsub := dbpubsub.NewInMemory()
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), memoryPubsub)
|
||||
type eventResult struct {
|
||||
event chatdebug.DebugEvent
|
||||
err error
|
||||
}
|
||||
events := make(chan eventResult, 1)
|
||||
cancel, err := memoryPubsub.Subscribe(chatdebug.PubsubChannel(chat.ID),
|
||||
func(_ context.Context, message []byte) {
|
||||
var event chatdebug.DebugEvent
|
||||
unmarshalErr := json.Unmarshal(message, &event)
|
||||
events <- eventResult{event: event, err: unmarshalErr}
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
run, err := svc.CreateRun(ctx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: model.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case received := <-events:
|
||||
require.NoError(t, received.err)
|
||||
require.Equal(t, chatdebug.EventKindRunUpdate, received.event.Kind)
|
||||
require.Equal(t, chat.ID, received.event.ChatID)
|
||||
require.Equal(t, run.ID, received.event.RunID)
|
||||
require.Equal(t, uuid.Nil, received.event.StepID)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for debug event")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-events:
|
||||
t.Fatalf("unexpected extra event: %+v", received.event)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func newFixture(t *testing.T) testFixture {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
return testFixture{
|
||||
ctx: ctx,
|
||||
db: db,
|
||||
svc: chatdebug.NewService(db, testutil.Logger(t), nil),
|
||||
owner: owner,
|
||||
chat: chat,
|
||||
model: model,
|
||||
}
|
||||
}
|
||||
|
||||
func seedChat(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
) (database.User, database.Chat, database.ChatModelConfig) {
|
||||
t.Helper()
|
||||
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
providerName := "openai"
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: providerName,
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(ctx,
|
||||
database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: "model-" + uuid.NewString(),
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat := insertChat(ctx, t, db, owner.ID, model.ID)
|
||||
return owner, chat, model
|
||||
}
|
||||
|
||||
func insertChat(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ownerID uuid.UUID,
|
||||
modelID uuid.UUID,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelID,
|
||||
Title: "chat-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
func insertMessage(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
createdBy uuid.UUID,
|
||||
modelID uuid.UUID,
|
||||
role database.ChatMessageRole,
|
||||
text string,
|
||||
) database.ChatMessage {
|
||||
t.Helper()
|
||||
|
||||
parts, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(text),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messages, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: chatID,
|
||||
CreatedBy: []uuid.UUID{createdBy},
|
||||
ModelConfigID: []uuid.UUID{modelID},
|
||||
Role: []database.ChatMessageRole{role},
|
||||
Content: []string{string(parts.RawMessage)},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
return messages[0]
|
||||
}
|
||||
|
||||
func createRun(t *testing.T, fixture testFixture) database.ChatDebugRun {
|
||||
t.Helper()
|
||||
|
||||
run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
ModelConfigID: fixture.model.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: fixture.model.Provider,
|
||||
Model: fixture.model.Model,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return run
|
||||
}
|
||||
|
||||
func assertRunMatches(
|
||||
t *testing.T,
|
||||
run database.ChatDebugRun,
|
||||
chatID uuid.UUID,
|
||||
rootChatID uuid.UUID,
|
||||
parentChatID uuid.UUID,
|
||||
modelID uuid.UUID,
|
||||
triggerMessageID int64,
|
||||
historyTipMessageID int64,
|
||||
kind chatdebug.RunKind,
|
||||
status chatdebug.Status,
|
||||
provider string,
|
||||
model string,
|
||||
summary string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
require.Equal(t, chatID, run.ChatID)
|
||||
require.True(t, run.RootChatID.Valid)
|
||||
require.Equal(t, rootChatID, run.RootChatID.UUID)
|
||||
require.True(t, run.ParentChatID.Valid)
|
||||
require.Equal(t, parentChatID, run.ParentChatID.UUID)
|
||||
require.True(t, run.ModelConfigID.Valid)
|
||||
require.Equal(t, modelID, run.ModelConfigID.UUID)
|
||||
require.True(t, run.TriggerMessageID.Valid)
|
||||
require.Equal(t, triggerMessageID, run.TriggerMessageID.Int64)
|
||||
require.True(t, run.HistoryTipMessageID.Valid)
|
||||
require.Equal(t, historyTipMessageID, run.HistoryTipMessageID.Int64)
|
||||
require.Equal(t, string(kind), run.Kind)
|
||||
require.Equal(t, string(status), run.Status)
|
||||
require.True(t, run.Provider.Valid)
|
||||
require.Equal(t, provider, run.Provider.String)
|
||||
require.True(t, run.Model.Valid)
|
||||
require.Equal(t, model, run.Model.String)
|
||||
require.JSONEq(t, summary, string(run.Summary))
|
||||
require.False(t, run.StartedAt.IsZero())
|
||||
require.False(t, run.UpdatedAt.IsZero())
|
||||
require.False(t, run.FinishedAt.Valid)
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
// This compatibility shim forward-declares service and summary symbols
|
||||
// that land in later stacked branches. Delete this file once service.go
|
||||
// and summary.go are available here.
|
||||
|
||||
// Service is a placeholder for the later chat debug persistence service.
|
||||
type Service struct {
|
||||
log slog.Logger
|
||||
}
|
||||
|
||||
// CreateStepParams identifies the data recorded when a debug step starts.
|
||||
type CreateStepParams struct {
|
||||
RunID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
StepNumber int32
|
||||
Operation Operation
|
||||
Status Status
|
||||
HistoryTipMessageID int64
|
||||
NormalizedRequest any
|
||||
}
|
||||
|
||||
// UpdateStepParams identifies the data recorded when a debug step finishes.
|
||||
type UpdateStepParams struct {
|
||||
ID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
Status Status
|
||||
NormalizedResponse any
|
||||
Usage any
|
||||
Attempts []Attempt
|
||||
Error any
|
||||
Metadata any
|
||||
FinishedAt time.Time
|
||||
}
|
||||
|
||||
// NewService constructs the placeholder chat debug service.
|
||||
func NewService(_ database.Store, log slog.Logger, _ pubsub.Pubsub) *Service {
|
||||
return &Service{log: log}
|
||||
}
|
||||
|
||||
// IsEnabled reports whether debug recording is enabled for a chat owner.
|
||||
func (*Service) IsEnabled(context.Context, uuid.UUID, uuid.UUID) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateStep synthesizes a debug step so recorder tests can exercise the
|
||||
// wrapper without requiring the later persistence service implementation.
|
||||
func (*Service) CreateStep(
|
||||
_ context.Context,
|
||||
params CreateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
return database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: params.RunID,
|
||||
ChatID: params.ChatID,
|
||||
StepNumber: params.StepNumber,
|
||||
Operation: string(params.Operation),
|
||||
Status: string(params.Status),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateStep accepts final step state once recording completes.
|
||||
func (*Service) UpdateStep(
|
||||
_ context.Context,
|
||||
params UpdateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
return database.ChatDebugStep{
|
||||
ID: params.ID,
|
||||
ChatID: params.ChatID,
|
||||
Status: string(params.Status),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runRefCounts tracks how many live RunContext instances reference each
|
||||
// RunID. Cleanup of shared state (step counters) is deferred until the
|
||||
// last RunContext for a given RunID is garbage collected.
|
||||
var runRefCounts sync.Map // map[uuid.UUID]*atomic.Int32
|
||||
|
||||
func trackRunRef(runID uuid.UUID) {
|
||||
val, _ := runRefCounts.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter := val.(*atomic.Int32)
|
||||
counter.Add(1)
|
||||
}
|
||||
|
||||
// releaseRunRef decrements the reference count for runID and cleans up
|
||||
// shared state when the last reference is released.
|
||||
func releaseRunRef(runID uuid.UUID) {
|
||||
val, ok := runRefCounts.Load(runID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
counter := val.(*atomic.Int32)
|
||||
if counter.Add(-1) <= 0 {
|
||||
runRefCounts.Delete(runID)
|
||||
stepCounters.Delete(runID)
|
||||
}
|
||||
}
|
||||
|
||||
// whitespaceRun matches one or more consecutive whitespace characters.
|
||||
var whitespaceRun = regexp.MustCompile(`\s+`)
|
||||
|
||||
// TruncateLabel whitespace-normalizes and truncates text to maxLen runes.
|
||||
// Returns "" if input is empty or whitespace-only.
|
||||
func TruncateLabel(text string, maxLen int) string {
|
||||
if maxLen < 0 {
|
||||
maxLen = 0
|
||||
}
|
||||
|
||||
normalized := strings.TrimSpace(whitespaceRun.ReplaceAllString(text, " "))
|
||||
if normalized == "" || maxLen == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(normalized) <= maxLen {
|
||||
return normalized
|
||||
}
|
||||
if maxLen == 1 {
|
||||
return "…"
|
||||
}
|
||||
|
||||
// Truncate to leave room for the trailing ellipsis within maxLen.
|
||||
runes := []rune(normalized)
|
||||
return string(runes[:maxLen-1]) + "…"
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
// MaxLabelLength is the default rune limit for truncated labels.
|
||||
const MaxLabelLength = 100
|
||||
|
||||
// whitespaceRun matches one or more consecutive whitespace characters.
|
||||
var whitespaceRun = regexp.MustCompile(`\s+`)
|
||||
|
||||
// TruncateLabel whitespace-normalizes and truncates text to maxLen runes.
|
||||
// Returns "" if input is empty or whitespace-only.
|
||||
func TruncateLabel(text string, maxLen int) string {
|
||||
if maxLen < 0 {
|
||||
maxLen = 0
|
||||
}
|
||||
|
||||
normalized := strings.TrimSpace(whitespaceRun.ReplaceAllString(text, " "))
|
||||
if normalized == "" || maxLen == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(normalized) <= maxLen {
|
||||
return normalized
|
||||
}
|
||||
if maxLen == 1 {
|
||||
return "…"
|
||||
}
|
||||
|
||||
// Truncate to leave room for the trailing ellipsis within maxLen.
|
||||
runes := []rune(normalized)
|
||||
return string(runes[:maxLen-1]) + "…"
|
||||
}
|
||||
|
||||
// SeedSummary builds a base summary map with a first_message label.
|
||||
// Returns nil if label is empty.
|
||||
func SeedSummary(label string) map[string]any {
|
||||
if label == "" {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{"first_message": label}
|
||||
}
|
||||
|
||||
// ExtractFirstUserText extracts the plain text content from a
|
||||
// fantasy.Prompt for the first user message. Used to derive
|
||||
// first_message labels at run creation time.
|
||||
func ExtractFirstUserText(prompt fantasy.Prompt) string {
|
||||
for _, msg := range prompt {
|
||||
if msg.Role != fantasy.MessageRoleUser {
|
||||
continue
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for _, part := range msg.Content {
|
||||
tp, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
_, _ = sb.WriteString(tp.Text)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// AggregateRunSummary reads all steps for the given run, computes token
|
||||
// totals, and merges them with the run's existing summary (preserving any
|
||||
// seeded first_message label). The baseSummary parameter should be the
|
||||
// current run summary (may be nil).
|
||||
func (s *Service) AggregateRunSummary(
|
||||
ctx context.Context,
|
||||
runID uuid.UUID,
|
||||
baseSummary map[string]any,
|
||||
) (map[string]any, error) {
|
||||
if runID == uuid.Nil {
|
||||
return baseSummary, nil
|
||||
}
|
||||
|
||||
steps, err := s.db.GetChatDebugStepsByRunID(chatdContext(ctx), runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start from a shallow copy of baseSummary to avoid mutating the
|
||||
// caller's map.
|
||||
// Capacity hint: baseSummary entries plus 8 derived keys
|
||||
// (step_count, total_input_tokens, total_output_tokens,
|
||||
// total_reasoning_tokens, total_cache_creation_tokens,
|
||||
// total_cache_read_tokens, has_error, endpoint_label).
|
||||
result := make(map[string]any, len(baseSummary)+8)
|
||||
for k, v := range baseSummary {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
// Clear derived fields before recomputing them so stale values from a
|
||||
// previous aggregation do not survive when the new totals are zero or
|
||||
// the endpoint label is unavailable.
|
||||
for _, key := range []string{
|
||||
"step_count",
|
||||
"total_input_tokens",
|
||||
"total_output_tokens",
|
||||
"total_reasoning_tokens",
|
||||
"total_cache_creation_tokens",
|
||||
"total_cache_read_tokens",
|
||||
"endpoint_label",
|
||||
"has_error",
|
||||
} {
|
||||
delete(result, key)
|
||||
}
|
||||
var (
|
||||
totalInput int64
|
||||
totalOutput int64
|
||||
totalReasoning int64
|
||||
totalCacheCreation int64
|
||||
totalCacheRead int64
|
||||
hasError bool
|
||||
)
|
||||
|
||||
for _, step := range steps {
|
||||
// Flag runs that hit a real error. Interrupted steps represent
|
||||
// user-initiated cancellation (e.g. clicking Stop) and should
|
||||
// not trigger the error indicator in the debug panel.
|
||||
// A JSONB null (used by jsonClear to erase a prior error) is
|
||||
// Valid but carries no meaningful content, so exclude it.
|
||||
errorIsReal := step.Error.Valid &&
|
||||
len(step.Error.RawMessage) > 0 &&
|
||||
!bytes.Equal(step.Error.RawMessage, []byte("null"))
|
||||
if step.Status == string(StatusError) ||
|
||||
(errorIsReal && step.Status != string(StatusInterrupted)) {
|
||||
hasError = true
|
||||
}
|
||||
if !step.Usage.Valid || len(step.Usage.RawMessage) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var usage fantasy.Usage
|
||||
if err := json.Unmarshal(step.Usage.RawMessage, &usage); err != nil {
|
||||
s.log.Warn(ctx, "skipping malformed step usage JSON",
|
||||
slog.Error(err),
|
||||
slog.F("run_id", runID),
|
||||
slog.F("step_id", step.ID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
totalInput += usage.InputTokens
|
||||
totalOutput += usage.OutputTokens
|
||||
totalReasoning += usage.ReasoningTokens
|
||||
totalCacheCreation += usage.CacheCreationTokens
|
||||
totalCacheRead += usage.CacheReadTokens
|
||||
}
|
||||
|
||||
result["step_count"] = len(steps)
|
||||
result["total_input_tokens"] = totalInput
|
||||
result["total_output_tokens"] = totalOutput
|
||||
|
||||
// Only include reasoning/cache fields when non-zero to keep the
|
||||
// summary compact for the common case.
|
||||
if totalReasoning > 0 {
|
||||
result["total_reasoning_tokens"] = totalReasoning
|
||||
}
|
||||
if totalCacheCreation > 0 {
|
||||
result["total_cache_creation_tokens"] = totalCacheCreation
|
||||
}
|
||||
if totalCacheRead > 0 {
|
||||
result["total_cache_read_tokens"] = totalCacheRead
|
||||
}
|
||||
|
||||
if hasError {
|
||||
result["has_error"] = true
|
||||
}
|
||||
|
||||
// Derive endpoint_label from the first completed attempt's path
|
||||
// across all steps. This gives the debug panel a meaningful
|
||||
// identifier like "POST /v1/messages" for the run row.
|
||||
if label := extractEndpointLabel(steps); label != "" {
|
||||
result["endpoint_label"] = label
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// extractEndpointLabel scans steps for the first completed attempt with a
|
||||
// non-empty path and returns "METHOD /path" (or just "/path").
|
||||
func extractEndpointLabel(steps []database.ChatDebugStep) string {
|
||||
for _, step := range steps {
|
||||
if len(step.Attempts) == 0 {
|
||||
continue
|
||||
}
|
||||
var attempts []Attempt
|
||||
if err := json.Unmarshal(step.Attempts, &attempts); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, a := range attempts {
|
||||
if a.Status != attemptStatusCompleted || a.Path == "" {
|
||||
continue
|
||||
}
|
||||
if a.Method != "" {
|
||||
return a.Method + " " + a.Path
|
||||
}
|
||||
return a.Path
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,416 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
)
|
||||
|
||||
func TestTruncateLabel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{name: "Empty", input: "", maxLen: 10, want: ""},
|
||||
{name: "WhitespaceOnly", input: " \t\n ", maxLen: 10, want: ""},
|
||||
{name: "ShortText", input: "hello world", maxLen: 20, want: "hello world"},
|
||||
{name: "ExactLength", input: "abcde", maxLen: 5, want: "abcde"},
|
||||
{name: "LongTextTruncated", input: "abcdefghij", maxLen: 5, want: "abcd…"},
|
||||
{name: "NegativeMaxLen", input: "hello", maxLen: -1, want: ""},
|
||||
{name: "ZeroMaxLen", input: "hello", maxLen: 0, want: ""},
|
||||
{name: "SingleRuneLimit", input: "hello", maxLen: 1, want: "…"},
|
||||
{name: "MultipleWhitespaceRuns", input: " hello world \t again ", maxLen: 100, want: "hello world again"},
|
||||
{name: "UnicodeRunes", input: "こんにちは世界", maxLen: 3, want: "こん…"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatdebug.TruncateLabel(tc.input, tc.maxLen)
|
||||
require.Equal(t, tc.want, got)
|
||||
require.LessOrEqual(t, utf8.RuneCountInString(got), maxInt(tc.maxLen, 0))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func TestSeedSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NonEmptyLabel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatdebug.SeedSummary("hello world")
|
||||
require.Equal(t, map[string]any{"first_message": "hello world"}, got)
|
||||
})
|
||||
|
||||
t.Run("EmptyLabel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatdebug.SeedSummary("")
|
||||
require.Nil(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractFirstUserText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyPrompt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatdebug.ExtractFirstUserText(fantasy.Prompt{})
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("NoUserMessages", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
prompt := fantasy.Prompt{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "system"}},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "assistant"}},
|
||||
},
|
||||
}
|
||||
got := chatdebug.ExtractFirstUserText(prompt)
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("FirstUserMessageMixedParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
prompt := fantasy.Prompt{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "hello "},
|
||||
fantasy.FilePart{Filename: "test.png"},
|
||||
fantasy.TextPart{Text: "world"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := chatdebug.ExtractFirstUserText(prompt)
|
||||
require.Equal(t, "hello world", got)
|
||||
})
|
||||
|
||||
t.Run("MultipleUserMessagesReturnsFirst", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
prompt := fantasy.Prompt{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "system"}},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "first"}},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "second"}},
|
||||
},
|
||||
}
|
||||
got := chatdebug.ExtractFirstUserText(prompt)
|
||||
require.Equal(t, "first", got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_AggregateRunSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NilRunID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, uuid.Nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, got)
|
||||
})
|
||||
|
||||
t.Run("NilBaseSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
// Create a step with usage.
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step.ID, 10, 5, 0, 0)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.EqualValues(t, 1, got["step_count"])
|
||||
require.EqualValues(t, int64(10), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(5), got["total_output_tokens"])
|
||||
})
|
||||
|
||||
t.Run("PreservesFirstMessage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step.ID, 20, 10, 0, 0)
|
||||
|
||||
base := map[string]any{"first_message": "hello world"}
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, base)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello world", got["first_message"])
|
||||
require.EqualValues(t, 1, got["step_count"])
|
||||
require.EqualValues(t, int64(20), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(10), got["total_output_tokens"])
|
||||
})
|
||||
|
||||
t.Run("ClearsStaleDerivedFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step.ID, 10, 5, 0, 0)
|
||||
|
||||
base := map[string]any{
|
||||
"first_message": "hello world",
|
||||
"step_count": 9,
|
||||
"total_input_tokens": 999,
|
||||
"total_output_tokens": 888,
|
||||
"total_reasoning_tokens": 777,
|
||||
"total_cache_creation_tokens": 100,
|
||||
"total_cache_read_tokens": 200,
|
||||
"has_error": true,
|
||||
"endpoint_label": "POST /stale",
|
||||
}
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, base)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello world", got["first_message"])
|
||||
require.EqualValues(t, 1, got["step_count"])
|
||||
require.EqualValues(t, int64(10), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(5), got["total_output_tokens"])
|
||||
// Stale reasoning tokens must be cleared because the step
|
||||
// has zero reasoning tokens.
|
||||
require.NotContains(t, got, "total_reasoning_tokens")
|
||||
require.NotContains(t, got, "total_cache_creation_tokens")
|
||||
require.NotContains(t, got, "total_cache_read_tokens")
|
||||
// has_error must be cleared because the step is not in error
|
||||
// status and has no error payload.
|
||||
require.NotContains(t, got, "has_error")
|
||||
require.NotContains(t, got, "endpoint_label")
|
||||
})
|
||||
|
||||
t.Run("RecomputesHasErrorAndCompletedEndpointLabel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step1 := createTestStep(t, fixture, run.ID)
|
||||
_, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step1.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusError,
|
||||
Attempts: []chatdebug.Attempt{{
|
||||
Number: 1,
|
||||
Status: "failed",
|
||||
Method: "POST",
|
||||
Path: "/failed",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
step2 := createTestStepN(t, fixture, run.ID, 2)
|
||||
_, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step2.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
Attempts: []chatdebug.Attempt{{
|
||||
Number: 1,
|
||||
Status: "completed",
|
||||
Method: "POST",
|
||||
Path: "/v1/messages",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, got["has_error"])
|
||||
require.Equal(t, "POST /v1/messages", got["endpoint_label"])
|
||||
})
|
||||
|
||||
t.Run("MultipleStepsSumTokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step1 := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step1.ID, 10, 5, 2, 3)
|
||||
|
||||
step2 := createTestStepN(t, fixture, run.ID, 2)
|
||||
updateTestStepWithUsage(t, fixture, step2.ID, 15, 7, 1, 4)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, got["step_count"])
|
||||
require.EqualValues(t, int64(25), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(12), got["total_output_tokens"])
|
||||
require.EqualValues(t, int64(3), got["total_cache_creation_tokens"])
|
||||
require.EqualValues(t, int64(7), got["total_cache_read_tokens"])
|
||||
})
|
||||
|
||||
t.Run("StepWithNilUsageContributesZeroTokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
// Step with usage.
|
||||
step1 := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step1.ID, 10, 5, 0, 0)
|
||||
|
||||
// Step without usage (just complete it, no usage).
|
||||
step2 := createTestStepN(t, fixture, run.ID, 2)
|
||||
_, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step2.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
// Both steps are counted even though one has no usage.
|
||||
require.EqualValues(t, 2, got["step_count"])
|
||||
require.EqualValues(t, int64(10), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(5), got["total_output_tokens"])
|
||||
})
|
||||
|
||||
t.Run("ZeroCacheTotalsOmitCacheFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step.ID, 10, 5, 0, 0)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
_, hasCacheCreation := got["total_cache_creation_tokens"]
|
||||
_, hasCacheRead := got["total_cache_read_tokens"]
|
||||
require.False(t, hasCacheCreation,
|
||||
"cache creation tokens should be omitted when zero")
|
||||
require.False(t, hasCacheRead,
|
||||
"cache read tokens should be omitted when zero")
|
||||
})
|
||||
|
||||
t.Run("ReasoningTokensSummedAcrossSteps", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step1 := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithFullUsage(t, fixture, step1.ID, 10, 5, 20, 0, 0)
|
||||
|
||||
step2 := createTestStepN(t, fixture, run.ID, 2)
|
||||
updateTestStepWithFullUsage(t, fixture, step2.ID, 15, 7, 30, 0, 0)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, got["step_count"])
|
||||
require.EqualValues(t, int64(25), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(12), got["total_output_tokens"])
|
||||
require.EqualValues(t, int64(50), got["total_reasoning_tokens"],
|
||||
"reasoning tokens should be summed across steps")
|
||||
})
|
||||
|
||||
t.Run("ZeroReasoningTokensOmitsField", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithFullUsage(t, fixture, step.ID, 10, 5, 0, 0, 0)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
_, hasReasoning := got["total_reasoning_tokens"]
|
||||
require.False(t, hasReasoning,
|
||||
"reasoning tokens should be omitted when zero")
|
||||
})
|
||||
}
|
||||
|
||||
// createTestStep is a thin helper that creates a debug step with
|
||||
// step number 1 for the given run.
|
||||
func createTestStep(
|
||||
t *testing.T,
|
||||
fixture testFixture,
|
||||
runID uuid.UUID,
|
||||
) database.ChatDebugStep {
|
||||
t.Helper()
|
||||
return createTestStepN(t, fixture, runID, 1)
|
||||
}
|
||||
|
||||
// createTestStepN creates a debug step with the given step number.
|
||||
func createTestStepN(
|
||||
t *testing.T,
|
||||
fixture testFixture,
|
||||
runID uuid.UUID,
|
||||
stepNumber int32,
|
||||
) database.ChatDebugStep {
|
||||
t.Helper()
|
||||
step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: runID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: stepNumber,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return step
|
||||
}
|
||||
|
||||
// updateTestStepWithUsage completes a step and sets token usage fields.
|
||||
func updateTestStepWithUsage(
|
||||
t *testing.T,
|
||||
fixture testFixture,
|
||||
stepID uuid.UUID,
|
||||
input, output, cacheCreation, cacheRead int64,
|
||||
) {
|
||||
t.Helper()
|
||||
updateTestStepWithFullUsage(t, fixture, stepID, input, output, 0, cacheCreation, cacheRead)
|
||||
}
|
||||
|
||||
// updateTestStepWithFullUsage completes a step with all token usage
|
||||
// fields, including reasoning tokens.
|
||||
func updateTestStepWithFullUsage(
|
||||
t *testing.T,
|
||||
fixture testFixture,
|
||||
stepID uuid.UUID,
|
||||
input, output, reasoning, cacheCreation, cacheRead int64,
|
||||
) {
|
||||
t.Helper()
|
||||
_, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: stepID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
Usage: map[string]any{
|
||||
"input_tokens": input,
|
||||
"output_tokens": output,
|
||||
"reasoning_tokens": reasoning,
|
||||
"cache_creation_tokens": cacheCreation,
|
||||
"cache_read_tokens": cacheRead,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -131,7 +131,16 @@ type DebugEvent struct {
|
||||
StepID uuid.UUID `json:"step_id"`
|
||||
}
|
||||
|
||||
// BroadcastPubsubChannel is the shared pubsub channel for chat-debug events
|
||||
// that are not scoped to a single chat, such as stale finalization sweeps.
|
||||
const BroadcastPubsubChannel = "chat_debug:broadcast"
|
||||
|
||||
// PubsubChannel returns the chat-scoped pubsub channel for debug events.
|
||||
// Nil chat IDs use the shared broadcast channel so publishers and subscribers
|
||||
// can coordinate through one discoverable helper.
|
||||
func PubsubChannel(chatID uuid.UUID) string {
|
||||
if chatID == uuid.Nil {
|
||||
return BroadcastPubsubChannel
|
||||
}
|
||||
return "chat_debug:" + chatID.String()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user