Compare commits

..

1 Commits

Author SHA1 Message Date
Yevhenii Shcherbina 7bfa79f75d feat: add allow-byok option for ai-gateway 2026-04-12 14:12:01 +00:00
50 changed files with 223 additions and 6984 deletions
-7
View File
@@ -211,13 +211,6 @@ AI BRIDGE PROXY OPTIONS:
certificates not trusted by the system. If not provided, the system
certificate pool is used.
CHAT OPTIONS:
Configure the background chat processing daemon.
--chat-debug-logging-enabled bool, $CODER_CHAT_DEBUG_LOGGING_ENABLED (default: false)
Force chat debug logging on for every chat, bypassing the runtime
admin and user opt-in settings.
CLIENT OPTIONS:
These options change the behavior of how clients interact with the Coder.
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
-4
View File
@@ -757,10 +757,6 @@ chat:
# How many pending chats a worker should acquire per polling cycle.
# (default: 10, type: int)
acquireBatchSize: 10
# Force chat debug logging on for every chat, bypassing the runtime admin and user
# opt-in settings.
# (default: false, type: bool)
debugLoggingEnabled: false
aibridge:
# Whether to start an in-memory aibridged instance.
# (default: false, type: bool)
-3
View File
@@ -14691,9 +14691,6 @@ const docTemplate = `{
"properties": {
"acquire_batch_size": {
"type": "integer"
},
"debug_logging_enabled": {
"type": "boolean"
}
}
},
-3
View File
@@ -13204,9 +13204,6 @@
"properties": {
"acquire_batch_size": {
"type": "integer"
},
"debug_logging_enabled": {
"type": "boolean"
}
}
},
-98
View File
@@ -1533,22 +1533,6 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
return &value
}
func nullStringPtr(v sql.NullString) *string {
if !v.Valid {
return nil
}
value := v.String
return &value
}
func nullTimePtr(v sql.NullTime) *time.Time {
if !v.Valid {
return nil
}
value := v.Time
return &value
}
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
// nil slices and maps to empty values for JSON serialization and
// derives RootChatID from the parent chain when not explicitly set.
@@ -1635,88 +1619,6 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database
return chat
}
func chatDebugAttempts(raw json.RawMessage) []map[string]any {
if len(raw) == 0 {
return nil
}
var attempts []map[string]any
if err := json.Unmarshal(raw, &attempts); err != nil {
return []map[string]any{{
"error": "malformed attempts payload",
"raw": string(raw),
}}
}
return attempts
}
// rawJSONObject deserializes a JSON object payload for debug display.
// If the payload is malformed, it returns a map with "error" and "raw"
// keys preserving the original content for diagnostics. Callers that
// consume the result programmatically should check for the "error" key.
func rawJSONObject(raw json.RawMessage) map[string]any {
if len(raw) == 0 {
return nil
}
var object map[string]any
if err := json.Unmarshal(raw, &object); err != nil {
return map[string]any{
"error": "malformed debug payload",
"raw": string(raw),
}
}
return object
}
func nullRawJSONObject(raw pqtype.NullRawMessage) map[string]any {
if !raw.Valid {
return nil
}
return rawJSONObject(raw.RawMessage)
}
// ChatDebugRunSummary converts a database.ChatDebugRun to a
// codersdk.ChatDebugRunSummary.
func ChatDebugRunSummary(r database.ChatDebugRun) codersdk.ChatDebugRunSummary {
return codersdk.ChatDebugRunSummary{
ID: r.ID,
ChatID: r.ChatID,
Kind: codersdk.ChatDebugRunKind(r.Kind),
Status: codersdk.ChatDebugStatus(r.Status),
Provider: nullStringPtr(r.Provider),
Model: nullStringPtr(r.Model),
Summary: rawJSONObject(r.Summary),
StartedAt: r.StartedAt,
UpdatedAt: r.UpdatedAt,
FinishedAt: nullTimePtr(r.FinishedAt),
}
}
// ChatDebugStep converts a database.ChatDebugStep to a
// codersdk.ChatDebugStep.
func ChatDebugStep(s database.ChatDebugStep) codersdk.ChatDebugStep {
return codersdk.ChatDebugStep{
ID: s.ID,
RunID: s.RunID,
ChatID: s.ChatID,
StepNumber: s.StepNumber,
Operation: codersdk.ChatDebugStepOperation(s.Operation),
Status: codersdk.ChatDebugStatus(s.Status),
HistoryTipMessageID: nullInt64Ptr(s.HistoryTipMessageID),
AssistantMessageID: nullInt64Ptr(s.AssistantMessageID),
NormalizedRequest: rawJSONObject(s.NormalizedRequest),
NormalizedResponse: nullRawJSONObject(s.NormalizedResponse),
Usage: nullRawJSONObject(s.Usage),
Attempts: chatDebugAttempts(s.Attempts),
Error: nullRawJSONObject(s.Error),
Metadata: rawJSONObject(s.Metadata),
StartedAt: s.StartedAt,
UpdatedAt: s.UpdatedAt,
FinishedAt: nullTimePtr(s.FinishedAt),
}
}
// ChatRows converts a slice of database.GetChatsRow (which embeds
// Chat plus HasUnread) to codersdk.Chat, looking up diff statuses
// from the provided map. When diffStatusesByChatID is non-nil,
-225
View File
@@ -210,231 +210,6 @@ func TestTemplateVersionParameter_BadDescription(t *testing.T) {
req.NotEmpty(sdk.DescriptionPlaintext, "broke the markdown parser with %v", desc)
}
func TestChatDebugRunSummary(t *testing.T) {
t.Parallel()
startedAt := time.Now().UTC().Round(time.Second)
finishedAt := startedAt.Add(5 * time.Second)
run := database.ChatDebugRun{
ID: uuid.New(),
ChatID: uuid.New(),
Kind: "chat_turn",
Status: "completed",
Provider: sql.NullString{String: "openai", Valid: true},
Model: sql.NullString{String: "gpt-4o", Valid: true},
Summary: json.RawMessage(`{"step_count":3,"has_error":false}`),
StartedAt: startedAt,
UpdatedAt: finishedAt,
FinishedAt: sql.NullTime{Time: finishedAt, Valid: true},
}
sdk := db2sdk.ChatDebugRunSummary(run)
require.Equal(t, run.ID, sdk.ID)
require.Equal(t, run.ChatID, sdk.ChatID)
require.Equal(t, codersdk.ChatDebugRunKindChatTurn, sdk.Kind)
require.Equal(t, codersdk.ChatDebugStatusCompleted, sdk.Status)
require.NotNil(t, sdk.Provider)
require.Equal(t, "openai", *sdk.Provider)
require.NotNil(t, sdk.Model)
require.Equal(t, "gpt-4o", *sdk.Model)
require.Equal(t, map[string]any{"step_count": float64(3), "has_error": false}, sdk.Summary)
require.Equal(t, startedAt, sdk.StartedAt)
require.Equal(t, finishedAt, sdk.UpdatedAt)
require.NotNil(t, sdk.FinishedAt)
require.Equal(t, finishedAt, *sdk.FinishedAt)
}
func TestChatDebugRunSummary_NullableFieldsNil(t *testing.T) {
t.Parallel()
run := database.ChatDebugRun{
ID: uuid.New(),
ChatID: uuid.New(),
Kind: "title_generation",
Status: "in_progress",
Summary: json.RawMessage(`{}`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugRunSummary(run)
require.Nil(t, sdk.Provider, "NULL Provider should map to nil")
require.Nil(t, sdk.Model, "NULL Model should map to nil")
require.Nil(t, sdk.FinishedAt, "NULL FinishedAt should map to nil")
}
func TestChatDebugStep(t *testing.T) {
t.Parallel()
startedAt := time.Now().UTC().Round(time.Second)
finishedAt := startedAt.Add(2 * time.Second)
attempts := json.RawMessage(`[
{
"attempt_number": 1,
"status": "completed",
"raw_request": {"url": "https://example.com"},
"raw_response": {"status": "200"},
"duration_ms": 123,
"started_at": "2026-03-01T10:00:01Z",
"finished_at": "2026-03-01T10:00:02Z"
}
]`)
step := database.ChatDebugStep{
ID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 1,
Operation: "stream",
Status: "completed",
NormalizedRequest: json.RawMessage(`{"messages":[]}`),
Attempts: attempts,
Metadata: json.RawMessage(`{"provider":"openai"}`),
StartedAt: startedAt,
UpdatedAt: finishedAt,
FinishedAt: sql.NullTime{Time: finishedAt, Valid: true},
}
sdk := db2sdk.ChatDebugStep(step)
// Verify all scalar fields are mapped correctly.
require.Equal(t, step.ID, sdk.ID)
require.Equal(t, step.RunID, sdk.RunID)
require.Equal(t, step.ChatID, sdk.ChatID)
require.Equal(t, step.StepNumber, sdk.StepNumber)
require.Equal(t, codersdk.ChatDebugStepOperationStream, sdk.Operation)
require.Equal(t, codersdk.ChatDebugStatusCompleted, sdk.Status)
require.Equal(t, startedAt, sdk.StartedAt)
require.Equal(t, finishedAt, sdk.UpdatedAt)
require.Equal(t, &finishedAt, sdk.FinishedAt)
// Verify JSON object fields are deserialized.
require.NotNil(t, sdk.NormalizedRequest)
require.Equal(t, map[string]any{"messages": []any{}}, sdk.NormalizedRequest)
require.NotNil(t, sdk.Metadata)
require.Equal(t, map[string]any{"provider": "openai"}, sdk.Metadata)
// Verify nullable fields are nil when the DB row has NULL values.
require.Nil(t, sdk.HistoryTipMessageID, "NULL HistoryTipMessageID should map to nil")
require.Nil(t, sdk.AssistantMessageID, "NULL AssistantMessageID should map to nil")
require.Nil(t, sdk.NormalizedResponse, "NULL NormalizedResponse should map to nil")
require.Nil(t, sdk.Usage, "NULL Usage should map to nil")
require.Nil(t, sdk.Error, "NULL Error should map to nil")
// Verify attempts are preserved with all fields.
require.Len(t, sdk.Attempts, 1)
require.Equal(t, float64(1), sdk.Attempts[0]["attempt_number"])
require.Equal(t, "completed", sdk.Attempts[0]["status"])
require.Equal(t, float64(123), sdk.Attempts[0]["duration_ms"])
require.Equal(t, map[string]any{"url": "https://example.com"}, sdk.Attempts[0]["raw_request"])
require.Equal(t, map[string]any{"status": "200"}, sdk.Attempts[0]["raw_response"])
}
func TestChatDebugStep_NullableFieldsPopulated(t *testing.T) {
t.Parallel()
tipID := int64(42)
asstID := int64(99)
step := database.ChatDebugStep{
ID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 2,
Operation: "generate",
Status: "completed",
HistoryTipMessageID: sql.NullInt64{Int64: tipID, Valid: true},
AssistantMessageID: sql.NullInt64{Int64: asstID, Valid: true},
NormalizedRequest: json.RawMessage(`{}`),
NormalizedResponse: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"text":"hi"}`), Valid: true},
Usage: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"tokens":10}`), Valid: true},
Error: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"code":"rate_limit"}`), Valid: true},
Attempts: json.RawMessage(`[]`),
Metadata: json.RawMessage(`{}`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugStep(step)
require.NotNil(t, sdk.HistoryTipMessageID)
require.Equal(t, tipID, *sdk.HistoryTipMessageID)
require.NotNil(t, sdk.AssistantMessageID)
require.Equal(t, asstID, *sdk.AssistantMessageID)
require.NotNil(t, sdk.NormalizedResponse)
require.Equal(t, map[string]any{"text": "hi"}, sdk.NormalizedResponse)
require.NotNil(t, sdk.Usage)
require.Equal(t, map[string]any{"tokens": float64(10)}, sdk.Usage)
require.NotNil(t, sdk.Error)
require.Equal(t, map[string]any{"code": "rate_limit"}, sdk.Error)
}
func TestChatDebugStep_PreservesMalformedAttempts(t *testing.T) {
t.Parallel()
step := database.ChatDebugStep{
ID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 1,
Operation: "stream",
Status: "completed",
NormalizedRequest: json.RawMessage(`{"messages":[]}`),
Attempts: json.RawMessage(`{"bad":true}`),
Metadata: json.RawMessage(`{"provider":"openai"}`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugStep(step)
require.Len(t, sdk.Attempts, 1)
require.Equal(t, "malformed attempts payload", sdk.Attempts[0]["error"])
require.Equal(t, `{"bad":true}`, sdk.Attempts[0]["raw"])
}
func TestChatDebugRunSummary_PreservesMalformedSummary(t *testing.T) {
t.Parallel()
run := database.ChatDebugRun{
ID: uuid.New(),
ChatID: uuid.New(),
Kind: "chat_turn",
Status: "completed",
Summary: json.RawMessage(`not-an-object`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugRunSummary(run)
require.Equal(t, "malformed debug payload", sdk.Summary["error"])
require.Equal(t, "not-an-object", sdk.Summary["raw"])
}
func TestChatDebugStep_PreservesMalformedRequest(t *testing.T) {
t.Parallel()
step := database.ChatDebugStep{
ID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 1,
Operation: "stream",
Status: "completed",
NormalizedRequest: json.RawMessage(`[1,2,3]`),
Attempts: json.RawMessage(`[]`),
Metadata: json.RawMessage(`"just-a-string"`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugStep(step)
require.Equal(t, "malformed debug payload", sdk.NormalizedRequest["error"])
require.Equal(t, "[1,2,3]", sdk.NormalizedRequest["raw"])
require.Equal(t, "malformed debug payload", sdk.Metadata["error"])
require.Equal(t, `"just-a-string"`, sdk.Metadata["raw"])
}
func TestAIBridgeInterception(t *testing.T) {
t.Parallel()
-161
View File
@@ -1860,28 +1860,6 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.DeleteChatDebugDataAfterMessageID(ctx, arg)
}
func (q *querier) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.DeleteChatDebugDataByChatID(ctx, chatID)
}
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -2369,14 +2347,6 @@ func (q *querier) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context,
return q.db.FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt)
}
func (q *querier) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
// Background sweep operates across all chats.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return database.FinalizeStaleChatDebugRowsRow{}, err
}
return q.db.FinalizeStaleChatDebugRows(ctx, updatedBefore)
}
func (q *querier) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
_, err := q.GetTemplateVersionByID(ctx, arg.TemplateVersionID)
if err != nil {
@@ -2585,59 +2555,6 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
return q.db.GetChatCostSummary(ctx, arg)
}
func (q *querier) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
// The allow-users flag is a deployment-wide setting read by any
// authenticated chat user. We only require that an explicit actor
// is present in the context so unauthenticated calls fail closed.
if _, ok := ActorFromContext(ctx); !ok {
return false, ErrNoActor
}
return q.db.GetChatDebugLoggingAllowUsers(ctx)
}
func (q *querier) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
run, err := q.db.GetChatDebugRunByID(ctx, id)
if err != nil {
return database.ChatDebugRun{}, err
}
// Authorize via the owning chat.
chat, err := q.db.GetChatByID(ctx, run.ChatID)
if err != nil {
return database.ChatDebugRun{}, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
return database.ChatDebugRun{}, err
}
return run, nil
}
func (q *querier) GetChatDebugRunsByChatID(ctx context.Context, arg database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
return nil, err
}
return q.db.GetChatDebugRunsByChatID(ctx, arg)
}
func (q *querier) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
run, err := q.db.GetChatDebugRunByID(ctx, runID)
if err != nil {
return nil, err
}
// Authorize via the owning chat.
chat, err := q.db.GetChatByID(ctx, run.ChatID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
return nil, err
}
return q.db.GetChatDebugStepsByRunID(ctx, runID)
}
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
// The desktop-enabled flag is a deployment-wide setting read by any
// authenticated chat user and by chatd when deciding whether to expose
@@ -4186,17 +4103,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
return q.db.GetUserChatCustomPrompt(ctx, userID)
}
func (q *querier) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
return false, err
}
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
return false, err
}
return q.db.GetUserChatDebugLoggingEnabled(ctx, userID)
}
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
@@ -4943,33 +4849,6 @@ func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams)
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
}
func (q *querier) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDebugRun{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDebugRun{}, err
}
return q.db.InsertChatDebugRun(ctx, arg)
}
// InsertChatDebugStep creates a new step in a debug run. The underlying
// SQL uses INSERT ... SELECT ... FROM chat_debug_runs to enforce that the
// run exists and belongs to the specified chat. If the run_id is invalid
// or the chat_id doesn't match, the INSERT produces 0 rows and SQLC
// returns sql.ErrNoRows.
func (q *querier) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDebugStep{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDebugStep{}, err
}
return q.db.InsertChatDebugStep(ctx, arg)
}
func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
// Authorize create on chat resource scoped to the owner and org.
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
@@ -5968,28 +5847,6 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
return q.db.UpdateChatByID(ctx, arg)
}
func (q *querier) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDebugRun{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDebugRun{}, err
}
return q.db.UpdateChatDebugRun(ctx, arg)
}
func (q *querier) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDebugStep{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDebugStep{}, err
}
return q.db.UpdateChatDebugStep(ctx, arg)
}
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
// The batch heartbeat is a system-level operation filtered by
// worker_id. Authorization is enforced by the AsChatd context
@@ -7222,13 +7079,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
return q.db.UpsertBoundaryUsageStats(ctx, arg)
}
func (q *querier) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatDebugLoggingAllowUsers(ctx, allowUsers)
}
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -7459,17 +7309,6 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
return q.db.UpsertTemplateUsageStats(ctx)
}
func (q *querier) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return err
}
return q.db.UpsertUserChatDebugLoggingEnabled(ctx, arg)
}
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
-96
View File
@@ -461,89 +461,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().DeleteChatQueuedMessage(gomock.Any(), args).Return(nil).AnyTimes()
check.Args(args).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("DeleteChatDebugDataAfterMessageID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.DeleteChatDebugDataAfterMessageIDParams{ChatID: chat.ID, MessageID: 123}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatDebugDataAfterMessageID(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
}))
s.Run("DeleteChatDebugDataByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatDebugDataByChatID(gomock.Any(), chat.ID).Return(int64(1), nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
}))
s.Run("FinalizeStaleChatDebugRows", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
updatedBefore := dbtime.Now()
row := database.FinalizeStaleChatDebugRowsRow{RunsFinalized: 1, StepsFinalized: 2}
dbm.EXPECT().FinalizeStaleChatDebugRows(gomock.Any(), updatedBefore).Return(row, nil).AnyTimes()
check.Args(updatedBefore).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(row)
}))
s.Run("GetChatDebugLoggingAllowUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatDebugLoggingAllowUsers(gomock.Any()).Return(true, nil).AnyTimes()
check.Args().Asserts().Returns(true)
}))
s.Run("GetChatDebugRunByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
dbm.EXPECT().GetChatDebugRunByID(gomock.Any(), run.ID).Return(run, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(run.ID).Asserts(chat, policy.ActionRead).Returns(run)
}))
s.Run("GetChatDebugRunsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
runs := []database.ChatDebugRun{{ID: uuid.New(), ChatID: chat.ID}}
arg := database.GetChatDebugRunsByChatIDParams{ChatID: chat.ID, LimitVal: 100}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatDebugRunsByChatID(gomock.Any(), arg).Return(runs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(runs)
}))
s.Run("GetChatDebugStepsByRunID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
steps := []database.ChatDebugStep{{ID: uuid.New(), RunID: run.ID, ChatID: chat.ID}}
dbm.EXPECT().GetChatDebugRunByID(gomock.Any(), run.ID).Return(run, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), run.ID).Return(steps, nil).AnyTimes()
check.Args(run.ID).Asserts(chat, policy.ActionRead).Returns(steps)
}))
s.Run("InsertChatDebugRun", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.InsertChatDebugRunParams{ChatID: chat.ID, Kind: "chat_turn", Status: "in_progress"}
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatDebugRun(gomock.Any(), arg).Return(run, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(run)
}))
s.Run("InsertChatDebugStep", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.InsertChatDebugStepParams{RunID: uuid.New(), ChatID: chat.ID, StepNumber: 1, Operation: "stream", Status: "in_progress"}
step := database.ChatDebugStep{ID: uuid.New(), RunID: arg.RunID, ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatDebugStep(gomock.Any(), arg).Return(step, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(step)
}))
s.Run("UpdateChatDebugRun", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatDebugRunParams{ID: uuid.New(), ChatID: chat.ID}
run := database.ChatDebugRun{ID: arg.ID, ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatDebugRun(gomock.Any(), arg).Return(run, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(run)
}))
s.Run("UpdateChatDebugStep", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatDebugStepParams{ID: uuid.New(), ChatID: chat.ID}
step := database.ChatDebugStep{ID: arg.ID, ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatDebugStep(gomock.Any(), arg).Return(step, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(step)
}))
s.Run("UpsertChatDebugLoggingAllowUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatDebugLoggingAllowUsers(gomock.Any(), true).Return(nil).AnyTimes()
check.Args(true).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
@@ -2577,19 +2494,6 @@ func (s *MethodTestSuite) TestUser() {
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
}))
s.Run("GetUserChatDebugLoggingEnabled", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), u.ID).Return(true, nil).AnyTimes()
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns(true)
}))
s.Run("UpsertUserChatDebugLoggingEnabled", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.UpsertUserChatDebugLoggingEnabledParams{UserID: u.ID, DebugLoggingEnabled: true}
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().UpsertUserChatDebugLoggingEnabled(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal)
}))
s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"}
-112
View File
@@ -416,22 +416,6 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
return r0
}
func (m queryMetricsStore) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteChatDebugDataAfterMessageID(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteChatDebugDataAfterMessageID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatDebugDataAfterMessageID").Inc()
return r0, r1
}
func (m queryMetricsStore) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteChatDebugDataByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("DeleteChatDebugDataByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatDebugDataByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
@@ -888,14 +872,6 @@ func (m queryMetricsStore) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.
return r0, r1
}
func (m queryMetricsStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
start := time.Now()
r0, r1 := m.s.FinalizeStaleChatDebugRows(ctx, updatedBefore)
m.queryLatencies.WithLabelValues("FinalizeStaleChatDebugRows").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "FinalizeStaleChatDebugRows").Inc()
return r0, r1
}
func (m queryMetricsStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.FindMatchingPresetID(ctx, arg)
@@ -1152,38 +1128,6 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
return r0, r1
}
func (m queryMetricsStore) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetChatDebugLoggingAllowUsers(ctx)
m.queryLatencies.WithLabelValues("GetChatDebugLoggingAllowUsers").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugLoggingAllowUsers").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
start := time.Now()
r0, r1 := m.s.GetChatDebugRunByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatDebugRunByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugRunByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDebugRunsByChatID(ctx context.Context, chatID database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
start := time.Now()
r0, r1 := m.s.GetChatDebugRunsByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatDebugRunsByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugRunsByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
start := time.Now()
r0, r1 := m.s.GetChatDebugStepsByRunID(ctx, runID)
m.queryLatencies.WithLabelValues("GetChatDebugStepsByRunID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugStepsByRunID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
@@ -2672,14 +2616,6 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
return r0, r1
}
func (m queryMetricsStore) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatDebugLoggingEnabled(ctx, userID)
m.queryLatencies.WithLabelValues("GetUserChatDebugLoggingEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatDebugLoggingEnabled").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
@@ -3376,22 +3312,6 @@ func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertCh
return r0, r1
}
func (m queryMetricsStore) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
start := time.Now()
r0, r1 := m.s.InsertChatDebugRun(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatDebugRun").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatDebugRun").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
start := time.Now()
r0, r1 := m.s.InsertChatDebugStep(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatDebugStep").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatDebugStep").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
start := time.Now()
r0, r1 := m.s.InsertChatFile(ctx, arg)
@@ -4288,22 +4208,6 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
return r0, r1
}
func (m queryMetricsStore) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatDebugRun(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatDebugRun").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatDebugRun").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatDebugStep(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatDebugStep").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatDebugStep").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
@@ -5144,14 +5048,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
return r0, r1
}
func (m queryMetricsStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
start := time.Now()
r0 := m.s.UpsertChatDebugLoggingAllowUsers(ctx, allowUsers)
m.queryLatencies.WithLabelValues("UpsertChatDebugLoggingAllowUsers").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDebugLoggingAllowUsers").Inc()
return r0
}
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
start := time.Now()
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
@@ -5384,14 +5280,6 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error {
return r0
}
func (m queryMetricsStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
start := time.Now()
r0 := m.s.UpsertUserChatDebugLoggingEnabled(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertUserChatDebugLoggingEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatDebugLoggingEnabled").Inc()
return r0
}
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
-208
View File
@@ -671,36 +671,6 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
}
// DeleteChatDebugDataAfterMessageID mocks base method.
func (m *MockStore) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatDebugDataAfterMessageID", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteChatDebugDataAfterMessageID indicates an expected call of DeleteChatDebugDataAfterMessageID.
func (mr *MockStoreMockRecorder) DeleteChatDebugDataAfterMessageID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatDebugDataAfterMessageID", reflect.TypeOf((*MockStore)(nil).DeleteChatDebugDataAfterMessageID), ctx, arg)
}
// DeleteChatDebugDataByChatID mocks base method.
func (m *MockStore) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatDebugDataByChatID", ctx, chatID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteChatDebugDataByChatID indicates an expected call of DeleteChatDebugDataByChatID.
func (mr *MockStoreMockRecorder) DeleteChatDebugDataByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatDebugDataByChatID", reflect.TypeOf((*MockStore)(nil).DeleteChatDebugDataByChatID), ctx, chatID)
}
// DeleteChatModelConfigByID mocks base method.
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -1517,21 +1487,6 @@ func (mr *MockStoreMockRecorder) FetchVolumesResourceMonitorsUpdatedAfter(ctx, u
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchVolumesResourceMonitorsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).FetchVolumesResourceMonitorsUpdatedAfter), ctx, updatedAt)
}
// FinalizeStaleChatDebugRows mocks base method.
func (m *MockStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FinalizeStaleChatDebugRows", ctx, updatedBefore)
ret0, _ := ret[0].(database.FinalizeStaleChatDebugRowsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FinalizeStaleChatDebugRows indicates an expected call of FinalizeStaleChatDebugRows.
func (mr *MockStoreMockRecorder) FinalizeStaleChatDebugRows(ctx, updatedBefore any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FinalizeStaleChatDebugRows", reflect.TypeOf((*MockStore)(nil).FinalizeStaleChatDebugRows), ctx, updatedBefore)
}
// FindMatchingPresetID mocks base method.
func (m *MockStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
m.ctrl.T.Helper()
@@ -2117,66 +2072,6 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
}
// GetChatDebugLoggingAllowUsers mocks base method.
func (m *MockStore) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDebugLoggingAllowUsers", ctx)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDebugLoggingAllowUsers indicates an expected call of GetChatDebugLoggingAllowUsers.
func (mr *MockStoreMockRecorder) GetChatDebugLoggingAllowUsers(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugLoggingAllowUsers", reflect.TypeOf((*MockStore)(nil).GetChatDebugLoggingAllowUsers), ctx)
}
// GetChatDebugRunByID mocks base method.
func (m *MockStore) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDebugRunByID", ctx, id)
ret0, _ := ret[0].(database.ChatDebugRun)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDebugRunByID indicates an expected call of GetChatDebugRunByID.
func (mr *MockStoreMockRecorder) GetChatDebugRunByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugRunByID", reflect.TypeOf((*MockStore)(nil).GetChatDebugRunByID), ctx, id)
}
// GetChatDebugRunsByChatID mocks base method.
func (m *MockStore) GetChatDebugRunsByChatID(ctx context.Context, arg database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDebugRunsByChatID", ctx, arg)
ret0, _ := ret[0].([]database.ChatDebugRun)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDebugRunsByChatID indicates an expected call of GetChatDebugRunsByChatID.
func (mr *MockStoreMockRecorder) GetChatDebugRunsByChatID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugRunsByChatID", reflect.TypeOf((*MockStore)(nil).GetChatDebugRunsByChatID), ctx, arg)
}
// GetChatDebugStepsByRunID mocks base method.
func (m *MockStore) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDebugStepsByRunID", ctx, runID)
ret0, _ := ret[0].([]database.ChatDebugStep)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDebugStepsByRunID indicates an expected call of GetChatDebugStepsByRunID.
func (mr *MockStoreMockRecorder) GetChatDebugStepsByRunID(ctx, runID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugStepsByRunID", reflect.TypeOf((*MockStore)(nil).GetChatDebugStepsByRunID), ctx, runID)
}
// GetChatDesktopEnabled mocks base method.
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
m.ctrl.T.Helper()
@@ -4997,21 +4892,6 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
}
// GetUserChatDebugLoggingEnabled mocks base method.
func (m *MockStore) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserChatDebugLoggingEnabled", ctx, userID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserChatDebugLoggingEnabled indicates an expected call of GetUserChatDebugLoggingEnabled.
func (mr *MockStoreMockRecorder) GetUserChatDebugLoggingEnabled(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatDebugLoggingEnabled", reflect.TypeOf((*MockStore)(nil).GetUserChatDebugLoggingEnabled), ctx, userID)
}
// GetUserChatProviderKeys mocks base method.
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
@@ -6331,36 +6211,6 @@ func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
}
// InsertChatDebugRun mocks base method.
func (m *MockStore) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatDebugRun", ctx, arg)
ret0, _ := ret[0].(database.ChatDebugRun)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatDebugRun indicates an expected call of InsertChatDebugRun.
func (mr *MockStoreMockRecorder) InsertChatDebugRun(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatDebugRun", reflect.TypeOf((*MockStore)(nil).InsertChatDebugRun), ctx, arg)
}
// InsertChatDebugStep mocks base method.
func (m *MockStore) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatDebugStep", ctx, arg)
ret0, _ := ret[0].(database.ChatDebugStep)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatDebugStep indicates an expected call of InsertChatDebugStep.
func (mr *MockStoreMockRecorder) InsertChatDebugStep(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatDebugStep", reflect.TypeOf((*MockStore)(nil).InsertChatDebugStep), ctx, arg)
}
// InsertChatFile mocks base method.
func (m *MockStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
m.ctrl.T.Helper()
@@ -8119,36 +7969,6 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
}
// UpdateChatDebugRun mocks base method.
func (m *MockStore) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatDebugRun", ctx, arg)
ret0, _ := ret[0].(database.ChatDebugRun)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatDebugRun indicates an expected call of UpdateChatDebugRun.
func (mr *MockStoreMockRecorder) UpdateChatDebugRun(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatDebugRun", reflect.TypeOf((*MockStore)(nil).UpdateChatDebugRun), ctx, arg)
}
// UpdateChatDebugStep mocks base method.
func (m *MockStore) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatDebugStep", ctx, arg)
ret0, _ := ret[0].(database.ChatDebugStep)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatDebugStep indicates an expected call of UpdateChatDebugStep.
func (mr *MockStoreMockRecorder) UpdateChatDebugStep(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatDebugStep", reflect.TypeOf((*MockStore)(nil).UpdateChatDebugStep), ctx, arg)
}
// UpdateChatHeartbeats mocks base method.
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
@@ -9669,20 +9489,6 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
}
// UpsertChatDebugLoggingAllowUsers mocks base method.
func (m *MockStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatDebugLoggingAllowUsers", ctx, allowUsers)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertChatDebugLoggingAllowUsers indicates an expected call of UpsertChatDebugLoggingAllowUsers.
func (mr *MockStoreMockRecorder) UpsertChatDebugLoggingAllowUsers(ctx, allowUsers any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDebugLoggingAllowUsers", reflect.TypeOf((*MockStore)(nil).UpsertChatDebugLoggingAllowUsers), ctx, allowUsers)
}
// UpsertChatDesktopEnabled mocks base method.
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
m.ctrl.T.Helper()
@@ -10100,20 +9906,6 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx)
}
// UpsertUserChatDebugLoggingEnabled mocks base method.
func (m *MockStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertUserChatDebugLoggingEnabled", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertUserChatDebugLoggingEnabled indicates an expected call of UpsertUserChatDebugLoggingEnabled.
func (mr *MockStoreMockRecorder) UpsertUserChatDebugLoggingEnabled(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatDebugLoggingEnabled", reflect.TypeOf((*MockStore)(nil).UpsertUserChatDebugLoggingEnabled), ctx, arg)
}
// UpsertUserChatProviderKey mocks base method.
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
-67
View File
@@ -1255,44 +1255,6 @@ COMMENT ON COLUMN boundary_usage_stats.window_start IS 'Start of the time window
COMMENT ON COLUMN boundary_usage_stats.updated_at IS 'Timestamp of the last update to this row.';
CREATE TABLE chat_debug_runs (
id uuid DEFAULT gen_random_uuid() NOT NULL,
chat_id uuid NOT NULL,
root_chat_id uuid,
parent_chat_id uuid,
model_config_id uuid,
trigger_message_id bigint,
history_tip_message_id bigint,
kind text NOT NULL,
status text NOT NULL,
provider text,
model text,
summary jsonb DEFAULT '{}'::jsonb NOT NULL,
started_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
finished_at timestamp with time zone
);
CREATE TABLE chat_debug_steps (
id uuid DEFAULT gen_random_uuid() NOT NULL,
run_id uuid NOT NULL,
chat_id uuid NOT NULL,
step_number integer NOT NULL,
operation text NOT NULL,
status text NOT NULL,
history_tip_message_id bigint,
assistant_message_id bigint,
normalized_request jsonb NOT NULL,
normalized_response jsonb,
usage jsonb,
attempts jsonb DEFAULT '[]'::jsonb NOT NULL,
error jsonb,
metadata jsonb DEFAULT '{}'::jsonb NOT NULL,
started_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
finished_at timestamp with time zone
);
CREATE TABLE chat_diff_statuses (
chat_id uuid NOT NULL,
url text,
@@ -3397,12 +3359,6 @@ ALTER TABLE ONLY audit_logs
ALTER TABLE ONLY boundary_usage_stats
ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
ALTER TABLE ONLY chat_debug_runs
ADD CONSTRAINT chat_debug_runs_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_debug_steps
ADD CONSTRAINT chat_debug_steps_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
@@ -3797,20 +3753,6 @@ CREATE INDEX idx_audit_log_user_id ON audit_logs USING btree (user_id);
CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
CREATE INDEX idx_chat_debug_runs_chat_started ON chat_debug_runs USING btree (chat_id, started_at DESC);
CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs USING btree (id, chat_id);
CREATE INDEX idx_chat_debug_runs_stale ON chat_debug_runs USING btree (updated_at) WHERE (finished_at IS NULL);
CREATE INDEX idx_chat_debug_steps_chat_assistant_msg ON chat_debug_steps USING btree (chat_id, assistant_message_id) WHERE (assistant_message_id IS NOT NULL);
CREATE INDEX idx_chat_debug_steps_chat_tip ON chat_debug_steps USING btree (chat_id, history_tip_message_id);
CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps USING btree (run_id, step_number);
CREATE INDEX idx_chat_debug_steps_stale ON chat_debug_steps USING btree (updated_at) WHERE (finished_at IS NULL);
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links USING btree (chat_id);
@@ -4114,12 +4056,6 @@ ALTER TABLE ONLY aibridge_interceptions
ALTER TABLE ONLY api_keys
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_debug_runs
ADD CONSTRAINT chat_debug_runs_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_debug_steps
ADD CONSTRAINT chat_debug_steps_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
@@ -4192,9 +4128,6 @@ ALTER TABLE ONLY connection_logs
ALTER TABLE ONLY crypto_keys
ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY chat_debug_steps
ADD CONSTRAINT fk_chat_debug_steps_run_chat FOREIGN KEY (run_id, chat_id) REFERENCES chat_debug_runs(id, chat_id) ON DELETE CASCADE;
ALTER TABLE ONLY oauth2_provider_app_tokens
ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
@@ -9,8 +9,6 @@ const (
ForeignKeyAiSeatStateUserID ForeignKeyConstraint = "ai_seat_state_user_id_fkey" // ALTER TABLE ONLY ai_seat_state ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatDebugRunsChatID ForeignKeyConstraint = "chat_debug_runs_chat_id_fkey" // ALTER TABLE ONLY chat_debug_runs ADD CONSTRAINT chat_debug_runs_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatDebugStepsChatID ForeignKeyConstraint = "chat_debug_steps_chat_id_fkey" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT chat_debug_steps_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatFileLinksChatID ForeignKeyConstraint = "chat_file_links_chat_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
@@ -35,7 +33,6 @@ const (
ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyFkChatDebugStepsRunChat ForeignKeyConstraint = "fk_chat_debug_steps_run_chat" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT fk_chat_debug_steps_run_chat FOREIGN KEY (run_id, chat_id) REFERENCES chat_debug_runs(id, chat_id) ON DELETE CASCADE;
ForeignKeyFkOauth2ProviderAppTokensUserID ForeignKeyConstraint = "fk_oauth2_provider_app_tokens_user_id" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyGitAuthLinksOauthAccessTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyGitAuthLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
@@ -1,2 +0,0 @@
DROP TABLE IF EXISTS chat_debug_steps;
DROP TABLE IF EXISTS chat_debug_runs;
@@ -1,59 +0,0 @@
CREATE TABLE chat_debug_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
-- root_chat_id and parent_chat_id are intentionally NOT
-- foreign-keyed to chats(id). They are snapshot values that
-- record the subchat hierarchy at run time. The referenced
-- chat may be archived or deleted independently, and we want
-- to preserve the historical lineage in debug rows rather
-- than cascade-delete them.
root_chat_id UUID,
parent_chat_id UUID,
model_config_id UUID,
trigger_message_id BIGINT,
history_tip_message_id BIGINT,
kind TEXT NOT NULL,
status TEXT NOT NULL,
provider TEXT,
model TEXT,
summary JSONB NOT NULL DEFAULT '{}'::jsonb,
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
finished_at TIMESTAMPTZ
);
CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs(id, chat_id);
CREATE INDEX idx_chat_debug_runs_chat_started ON chat_debug_runs(chat_id, started_at DESC);
CREATE TABLE chat_debug_steps (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
run_id UUID NOT NULL,
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
step_number INT NOT NULL,
operation TEXT NOT NULL,
status TEXT NOT NULL,
history_tip_message_id BIGINT,
assistant_message_id BIGINT,
normalized_request JSONB NOT NULL,
normalized_response JSONB,
usage JSONB,
attempts JSONB NOT NULL DEFAULT '[]'::jsonb,
error JSONB,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
finished_at TIMESTAMPTZ,
CONSTRAINT fk_chat_debug_steps_run_chat
FOREIGN KEY (run_id, chat_id)
REFERENCES chat_debug_runs(id, chat_id)
ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps(run_id, step_number);
CREATE INDEX idx_chat_debug_steps_chat_tip ON chat_debug_steps(chat_id, history_tip_message_id);
-- Supports DeleteChatDebugDataAfterMessageID assistant_message_id branch.
CREATE INDEX idx_chat_debug_steps_chat_assistant_msg ON chat_debug_steps(chat_id, assistant_message_id) WHERE assistant_message_id IS NOT NULL;
-- Supports FinalizeStaleChatDebugRows worker query.
CREATE INDEX idx_chat_debug_runs_stale ON chat_debug_runs(updated_at) WHERE finished_at IS NULL;
CREATE INDEX idx_chat_debug_steps_stale ON chat_debug_steps(updated_at) WHERE finished_at IS NULL;
@@ -1,65 +0,0 @@
INSERT INTO chat_debug_runs (
id,
chat_id,
model_config_id,
history_tip_message_id,
kind,
status,
provider,
model,
summary,
started_at,
updated_at,
finished_at
) VALUES (
'c98518f8-9fb3-458b-a642-57552af1db63',
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
'9af5f8d5-6a57-4505-8a69-3d6c787b95fd',
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
'chat_turn',
'completed',
'openai',
'gpt-5.2',
'{"step_count":1,"has_error":false}'::jsonb,
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:01+00',
'2024-01-01 00:00:01+00'
);
INSERT INTO chat_debug_steps (
id,
run_id,
chat_id,
step_number,
operation,
status,
history_tip_message_id,
assistant_message_id,
normalized_request,
normalized_response,
usage,
attempts,
error,
metadata,
started_at,
updated_at,
finished_at
) VALUES (
'59471c60-7851-4fa6-bf05-e21dd939721f',
'c98518f8-9fb3-458b-a642-57552af1db63',
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
1,
'stream',
'completed',
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
'{"messages":[]}'::jsonb,
'{"finish_reason":"stop"}'::jsonb,
'{"input_tokens":1,"output_tokens":1}'::jsonb,
'[]'::jsonb,
NULL,
'{"provider":"openai"}'::jsonb,
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:01+00',
'2024-01-01 00:00:01+00'
);
-38
View File
@@ -4248,44 +4248,6 @@ type Chat struct {
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
}
type ChatDebugRun struct {
ID uuid.UUID `db:"id" json:"id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
TriggerMessageID sql.NullInt64 `db:"trigger_message_id" json:"trigger_message_id"`
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
Kind string `db:"kind" json:"kind"`
Status string `db:"status" json:"status"`
Provider sql.NullString `db:"provider" json:"provider"`
Model sql.NullString `db:"model" json:"model"`
Summary json.RawMessage `db:"summary" json:"summary"`
StartedAt time.Time `db:"started_at" json:"started_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
}
type ChatDebugStep struct {
ID uuid.UUID `db:"id" json:"id"`
RunID uuid.UUID `db:"run_id" json:"run_id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
StepNumber int32 `db:"step_number" json:"step_number"`
Operation string `db:"operation" json:"operation"`
Status string `db:"status" json:"status"`
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
AssistantMessageID sql.NullInt64 `db:"assistant_message_id" json:"assistant_message_id"`
NormalizedRequest json.RawMessage `db:"normalized_request" json:"normalized_request"`
NormalizedResponse pqtype.NullRawMessage `db:"normalized_response" json:"normalized_response"`
Usage pqtype.NullRawMessage `db:"usage" json:"usage"`
Attempts json.RawMessage `db:"attempts" json:"attempts"`
Error pqtype.NullRawMessage `db:"error" json:"error"`
Metadata json.RawMessage `db:"metadata" json:"metadata"`
StartedAt time.Time `db:"started_at" json:"started_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
}
type ChatDiffStatus struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
Url sql.NullString `db:"url" json:"url"`
-38
View File
@@ -102,8 +102,6 @@ type sqlcQuerier interface {
// be recreated.
DeleteAllWebpushSubscriptions(ctx context.Context) error
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
DeleteChatDebugDataAfterMessageID(ctx context.Context, arg DeleteChatDebugDataAfterMessageIDParams) (int64, error)
DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error)
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
@@ -196,16 +194,6 @@ type sqlcQuerier interface {
FetchNewMessageMetadata(ctx context.Context, arg FetchNewMessageMetadataParams) (FetchNewMessageMetadataRow, error)
FetchVolumesResourceMonitorsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceAgentVolumeResourceMonitor, error)
FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentVolumeResourceMonitor, error)
// Marks orphaned in-progress rows as interrupted so they do not stay
// in a non-terminal state forever. The NOT IN list must match the
// terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
//
// The steps CTE also catches steps whose parent run was just finalized
// (via run_id IN), because PostgreSQL data-modifying CTEs share the
// same snapshot and cannot see each other's row updates. Without this,
// a step with a recent updated_at would survive its run's finalization
// and remain in 'in_progress' state permanently.
FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (FinalizeStaleChatDebugRowsRow, error)
// FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters.
// It returns the preset ID if a match is found, or NULL if no match is found.
// The query finds presets where all preset parameters are present in the provided parameters,
@@ -270,15 +258,6 @@ type sqlcQuerier interface {
// Aggregate cost summary for a single user within a date range.
// Only counts assistant-role messages.
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
// GetChatDebugLoggingAllowUsers returns the runtime admin setting that
// allows users to opt into chat debug logging when the deployment does
// not already force debug logging on globally.
GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error)
GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (ChatDebugRun, error)
// Returns the most recent debug runs for a chat, ordered newest-first.
// Callers must supply an explicit limit to avoid unbounded result sets.
GetChatDebugRunsByChatID(ctx context.Context, arg GetChatDebugRunsByChatIDParams) ([]ChatDebugRun, error)
GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]ChatDebugStep, error)
GetChatDesktopEnabled(ctx context.Context) (bool, error)
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
@@ -640,7 +619,6 @@ type sqlcQuerier interface {
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error)
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
@@ -760,8 +738,6 @@ type sqlcQuerier interface {
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
InsertChatDebugRun(ctx context.Context, arg InsertChatDebugRunParams) (ChatDebugRun, error)
InsertChatDebugStep(ctx context.Context, arg InsertChatDebugStepParams) (ChatDebugStep, error)
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error)
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
@@ -940,16 +916,6 @@ type sqlcQuerier interface {
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
// Uses COALESCE so that passing NULL from Go means "keep the
// existing value." This is intentional: debug rows follow a
// write-once-finalize pattern where fields are set at creation
// or finalization and never cleared back to NULL.
UpdateChatDebugRun(ctx context.Context, arg UpdateChatDebugRunParams) (ChatDebugRun, error)
// Uses COALESCE so that passing NULL from Go means "keep the
// existing value." This is intentional: debug rows follow a
// write-once-finalize pattern where fields are set at creation
// or finalization and never cleared back to NULL.
UpdateChatDebugStep(ctx context.Context, arg UpdateChatDebugStepParams) (ChatDebugStep, error)
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
@@ -1078,9 +1044,6 @@ type sqlcQuerier interface {
// cumulative values for unique counts (accurate period totals). Request counts
// are always deltas, accumulated in DB. Returns true if insert, false if update.
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
// UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
// allows users to opt into chat debug logging.
UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error
UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
@@ -1118,7 +1081,6 @@ type sqlcQuerier interface {
// used to store the data, and the minutes are summed for each user and template
// combination. The result is stored in the template_usage_stats table.
UpsertTemplateUsageStats(ctx context.Context) error
UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
-945
View File
@@ -11218,951 +11218,6 @@ func TestChatLabels(t *testing.T) {
})
}
func TestDeleteChatDebugDataAfterMessageIDIncludesTriggeredRuns(t *testing.T) {
t.Parallel()
store, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
dbgen.Organization(t, store, database.Organization{})
user := dbgen.User(t, store, database.User{})
providerName := "openai"
modelName := "debug-model-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
chat, err := store.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "chat-debug-rollback-" + uuid.NewString(),
})
require.NoError(t, err)
const cutoff int64 = 50
affectedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: cutoff + 10, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 5, Valid: true},
Kind: "chat_turn",
Status: "in_progress",
Provider: sql.NullString{String: providerName, Valid: true},
Model: sql.NullString{String: modelName, Valid: true},
})
require.NoError(t, err)
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: affectedRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "in_progress",
})
require.NoError(t, err)
affectedByStepHistoryTipRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: cutoff - 1, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 1, Valid: true},
Kind: "chat_turn",
Status: "in_progress",
Provider: sql.NullString{String: providerName, Valid: true},
Model: sql.NullString{String: modelName, Valid: true},
})
require.NoError(t, err)
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: affectedByStepHistoryTipRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "interrupted",
HistoryTipMessageID: sql.NullInt64{Int64: cutoff + 7, Valid: true},
})
require.NoError(t, err)
// affectedByStepAssistantMsgRun: run-level fields are at/below
// the cutoff, but its step has assistant_message_id above the
// cutoff. This exercises the step.assistant_message_id > cutoff
// branch of the UNION independently of history_tip_message_id.
affectedByStepAssistantMsgRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: cutoff - 2, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 2, Valid: true},
Kind: "chat_turn",
Status: "in_progress",
Provider: sql.NullString{String: providerName, Valid: true},
Model: sql.NullString{String: modelName, Valid: true},
})
require.NoError(t, err)
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: affectedByStepAssistantMsgRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "completed",
AssistantMessageID: sql.NullInt64{Int64: cutoff + 3, Valid: true},
})
require.NoError(t, err)
unaffectedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
Kind: "chat_turn",
Status: "in_progress",
Provider: sql.NullString{String: providerName, Valid: true},
Model: sql.NullString{String: modelName, Valid: true},
})
require.NoError(t, err)
unaffectedStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: unaffectedRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "in_progress",
AssistantMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
})
require.NoError(t, err)
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
ChatID: chat.ID,
MessageID: cutoff,
})
require.NoError(t, err)
require.EqualValues(t, 3, deletedRows)
_, err = store.GetChatDebugRunByID(ctx, affectedRun.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
affectedSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedRun.ID)
require.NoError(t, err)
require.Empty(t, affectedSteps)
_, err = store.GetChatDebugRunByID(ctx, affectedByStepHistoryTipRun.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
affectedByStepHistoryTipSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedByStepHistoryTipRun.ID)
require.NoError(t, err)
require.Empty(t, affectedByStepHistoryTipSteps)
// Verify the run caught by step-level assistant_message_id is
// also deleted. This would survive if the
// step.assistant_message_id > @message_id clause were removed.
_, err = store.GetChatDebugRunByID(ctx, affectedByStepAssistantMsgRun.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
affectedByStepAssistantMsgSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedByStepAssistantMsgRun.ID)
require.NoError(t, err)
require.Empty(t, affectedByStepAssistantMsgSteps)
remainingRuns, err := store.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
ChatID: chat.ID,
LimitVal: 100,
})
require.NoError(t, err)
require.Len(t, remainingRuns, 1)
require.Equal(t, unaffectedRun.ID, remainingRuns[0].ID)
remainingRun, err := store.GetChatDebugRunByID(ctx, unaffectedRun.ID)
require.NoError(t, err)
require.Equal(t, unaffectedRun.ID, remainingRun.ID)
remainingSteps, err := store.GetChatDebugStepsByRunID(ctx, unaffectedRun.ID)
require.NoError(t, err)
require.Len(t, remainingSteps, 1)
require.Equal(t, unaffectedStep.ID, remainingSteps[0].ID)
}
func TestFinalizeStaleChatDebugRows(t *testing.T) {
t.Parallel()
store, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
dbgen.Organization(t, store, database.Organization{})
user := dbgen.User(t, store, database.User{})
providerName := "openai"
modelName := "debug-model-finalize-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
chat, err := store.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "chat-finalize-" + uuid.NewString(),
})
require.NoError(t, err)
// staleTime is well before the threshold so rows stamped with it
// are considered stale. The threshold sits between staleTime and
// NOW(), letting us create rows that are stale-by-age and rows
// that are fresh-by-age in the same test.
staleTime := time.Now().Add(-2 * time.Hour)
staleThreshold := time.Now().Add(-1 * time.Hour)
// --- staleRun: in_progress run with no finished_at --- should be
// finalized.
staleRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: 1, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: 1, Valid: true},
Kind: "chat_turn",
Status: "in_progress",
Provider: sql.NullString{String: providerName, Valid: true},
Model: sql.NullString{String: modelName, Valid: true},
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
})
require.NoError(t, err)
// staleStep: in_progress step attached to staleRun.
staleStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: staleRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "in_progress",
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
})
require.NoError(t, err)
// --- orphanStep: in_progress step whose run is already completed ---
// its own updated_at is old, so it should be finalized directly.
completedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: 2, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: 2, Valid: true},
Kind: "chat_turn",
Status: "completed",
})
require.NoError(t, err)
// Mark the run as completed with a finished_at timestamp.
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
ID: completedRun.ID,
ChatID: completedRun.ChatID,
Status: sql.NullString{String: "completed", Valid: true},
FinishedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
})
require.NoError(t, err)
orphanStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: completedRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "in_progress",
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
})
require.NoError(t, err)
// --- cascadeRun: stale in_progress run with a FRESH step ---
// The run's updated_at is old so the run itself is finalized by
// age. The step's updated_at is recent (default NOW()), so it is
// NOT caught by the age predicate. It must be finalized solely
// via the cascade CTE clause: run_id IN (SELECT id FROM
// finalized_runs). Removing that clause would leave this step
// stuck in 'in_progress'.
cascadeRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: 10, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: 10, Valid: true},
Kind: "chat_turn",
Status: "in_progress",
Provider: sql.NullString{String: providerName, Valid: true},
Model: sql.NullString{String: modelName, Valid: true},
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
})
require.NoError(t, err)
// cascadeStep: recent updated_at (default NOW()), so only the
// cascade path can finalize it.
cascadeStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: cascadeRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "in_progress",
})
require.NoError(t, err)
// --- alreadyDone: completed run/step --- should NOT be touched.
doneRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: 3, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: 3, Valid: true},
Kind: "chat_turn",
Status: "completed",
})
require.NoError(t, err)
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
ID: doneRun.ID,
ChatID: doneRun.ChatID,
Status: sql.NullString{String: "completed", Valid: true},
FinishedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
})
require.NoError(t, err)
doneStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: doneRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "completed",
})
require.NoError(t, err)
_, err = store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
ID: doneStep.ID,
ChatID: chat.ID,
Status: sql.NullString{String: "completed", Valid: true},
FinishedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
})
require.NoError(t, err)
// --- errorRun: error run/step --- should NOT be touched either,
// exercising the 'error' branch of the NOT IN clause.
errorRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: 4, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: 4, Valid: true},
Kind: "chat_turn",
Status: "error",
})
require.NoError(t, err)
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
ID: errorRun.ID,
ChatID: errorRun.ChatID,
Status: sql.NullString{String: "error", Valid: true},
FinishedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
})
require.NoError(t, err)
errorStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: errorRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "error",
})
require.NoError(t, err)
_, err = store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
ID: errorStep.ID,
ChatID: chat.ID,
Status: sql.NullString{String: "error", Valid: true},
FinishedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
})
require.NoError(t, err)
// --- Execute the finalization sweep. ---
result, err := store.FinalizeStaleChatDebugRows(ctx, staleThreshold)
require.NoError(t, err)
// staleRun + cascadeRun were finalized; completedRun and doneRun
// were already terminal so only 2 runs are expected.
assert.EqualValues(t, 2, result.RunsFinalized,
"stale + cascade in_progress runs should be finalized")
// staleStep (age), orphanStep (age), cascadeStep (cascade only)
// should all be finalized.
assert.EqualValues(t, 3, result.StepsFinalized,
"stale step + orphan step + cascade step should all be finalized")
// Verify the stale run was set to interrupted.
updatedStaleRun, err := store.GetChatDebugRunByID(ctx, staleRun.ID)
require.NoError(t, err)
assert.Equal(t, "interrupted", updatedStaleRun.Status)
assert.True(t, updatedStaleRun.FinishedAt.Valid,
"finalized run should have a finished_at timestamp")
// Verify the stale step was set to interrupted.
staleSteps, err := store.GetChatDebugStepsByRunID(ctx, staleRun.ID)
require.NoError(t, err)
require.Len(t, staleSteps, 1)
assert.Equal(t, staleStep.ID, staleSteps[0].ID)
assert.Equal(t, "interrupted", staleSteps[0].Status)
assert.True(t, staleSteps[0].FinishedAt.Valid,
"finalized step should have a finished_at timestamp")
// Verify the orphan step was also finalized.
orphanSteps, err := store.GetChatDebugStepsByRunID(ctx, completedRun.ID)
require.NoError(t, err)
require.Len(t, orphanSteps, 1)
assert.Equal(t, orphanStep.ID, orphanSteps[0].ID)
assert.Equal(t, "interrupted", orphanSteps[0].Status)
// Verify the cascade run was finalized.
updatedCascadeRun, err := store.GetChatDebugRunByID(ctx, cascadeRun.ID)
require.NoError(t, err)
assert.Equal(t, "interrupted", updatedCascadeRun.Status)
assert.True(t, updatedCascadeRun.FinishedAt.Valid,
"cascade run should have a finished_at timestamp")
// Verify the cascade step was finalized despite its recent
// updated_at, proving the cascade CTE clause is required.
cascadeSteps, err := store.GetChatDebugStepsByRunID(ctx, cascadeRun.ID)
require.NoError(t, err)
require.Len(t, cascadeSteps, 1)
assert.Equal(t, cascadeStep.ID, cascadeSteps[0].ID)
assert.Equal(t, "interrupted", cascadeSteps[0].Status,
"fresh step should be finalized via cascade, not age")
assert.True(t, cascadeSteps[0].FinishedAt.Valid,
"cascade step should have a finished_at timestamp")
// Verify the completed run/step are untouched.
unchangedRun, err := store.GetChatDebugRunByID(ctx, doneRun.ID)
require.NoError(t, err)
assert.Equal(t, "completed", unchangedRun.Status)
doneSteps, err := store.GetChatDebugStepsByRunID(ctx, doneRun.ID)
require.NoError(t, err)
require.Len(t, doneSteps, 1)
assert.Equal(t, "completed", doneSteps[0].Status)
// Verify the error run/step are untouched.
unchangedErrorRun, err := store.GetChatDebugRunByID(ctx, errorRun.ID)
require.NoError(t, err)
assert.Equal(t, "error", unchangedErrorRun.Status)
errorSteps, err := store.GetChatDebugStepsByRunID(ctx, errorRun.ID)
require.NoError(t, err)
require.Len(t, errorSteps, 1)
assert.Equal(t, "error", errorSteps[0].Status)
// A second sweep should be a no-op.
result2, err := store.FinalizeStaleChatDebugRows(ctx, staleThreshold)
require.NoError(t, err)
assert.EqualValues(t, 0, result2.RunsFinalized,
"second sweep should find nothing to finalize")
assert.EqualValues(t, 0, result2.StepsFinalized,
"second sweep should find nothing to finalize")
}
func TestChatDebugSQLGuards(t *testing.T) {
t.Parallel()
store, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
dbgen.Organization(t, store, database.Organization{})
user := dbgen.User(t, store, database.User{})
providerName := "openai"
modelName := "debug-model-guards-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
chatA, err := store.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "chat-guard-A-" + uuid.NewString(),
})
require.NoError(t, err)
chatB, err := store.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "chat-guard-B-" + uuid.NewString(),
})
require.NoError(t, err)
runA, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chatA.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: 1, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: 1, Valid: true},
Kind: "chat_turn",
Status: "in_progress",
Provider: sql.NullString{String: providerName, Valid: true},
Model: sql.NullString{String: modelName, Valid: true},
})
require.NoError(t, err)
stepA, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: runA.ID,
ChatID: chatA.ID,
StepNumber: 1,
Operation: "stream",
Status: "in_progress",
})
require.NoError(t, err)
// InsertChatDebugStep: valid run_id but chat_id belongs to a
// different chat. The INSERT...SELECT guard should produce zero
// rows, surfacing as sql.ErrNoRows.
t.Run("InsertChatDebugStep_MismatchedChatID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: runA.ID,
ChatID: chatB.ID, // wrong chat
StepNumber: 2,
Operation: "stream",
Status: "in_progress",
})
require.ErrorIs(t, err, sql.ErrNoRows,
"InsertChatDebugStep should fail when chat_id does not match the run's chat_id")
})
// UpdateChatDebugRun: valid run ID but wrong chat_id.
t.Run("UpdateChatDebugRun_MismatchedChatID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
ID: runA.ID,
ChatID: chatB.ID, // wrong chat
Status: sql.NullString{String: "completed", Valid: true},
FinishedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
})
require.ErrorIs(t, err, sql.ErrNoRows,
"UpdateChatDebugRun should fail when chat_id does not match")
})
// UpdateChatDebugStep: valid step ID but wrong chat_id.
t.Run("UpdateChatDebugStep_MismatchedChatID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
ID: stepA.ID,
ChatID: chatB.ID, // wrong chat
Status: sql.NullString{String: "completed", Valid: true},
FinishedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
})
require.ErrorIs(t, err, sql.ErrNoRows,
"UpdateChatDebugStep should fail when chat_id does not match")
})
}
// TestChatDebugRunCOALESCEPreservation verifies that the COALESCE
// pattern in UpdateChatDebugRun preserves every field that was not
// explicitly supplied in the update. If COALESCE were removed from
// any column, the corresponding field would silently null out.
func TestChatDebugRunCOALESCEPreservation(t *testing.T) {
t.Parallel()
store, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
dbgen.Organization(t, store, database.Organization{})
user := dbgen.User(t, store, database.User{})
providerName := "openai"
modelName := "debug-model-coalesce-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
chat, err := store.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "chat-debug-coalesce-" + uuid.NewString(),
})
require.NoError(t, err)
rootChatID := uuid.New()
parentChatID := uuid.New()
// Insert a fully-populated run so every nullable field has a value.
original, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true},
ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
TriggerMessageID: sql.NullInt64{Int64: 42, Valid: true},
HistoryTipMessageID: sql.NullInt64{Int64: 41, Valid: true},
Kind: "chat_turn",
Status: "in_progress",
Provider: sql.NullString{String: providerName, Valid: true},
Model: sql.NullString{String: modelName, Valid: true},
Summary: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"key":"val"}`), Valid: true},
})
require.NoError(t, err)
// Update only Status and FinishedAt. Every other nullable param
// is left as its Go zero value (Valid: false → SQL NULL), which
// the COALESCE pattern should interpret as "keep existing."
now := time.Now()
updated, err := store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
ID: original.ID,
ChatID: chat.ID,
Status: sql.NullString{String: "completed", Valid: true},
FinishedAt: sql.NullTime{
Time: now,
Valid: true,
},
})
require.NoError(t, err)
// Status and FinishedAt should be updated.
require.Equal(t, "completed", updated.Status)
require.True(t, updated.FinishedAt.Valid)
// UpdatedAt should advance (set to NOW() unconditionally).
require.True(t, updated.UpdatedAt.After(original.UpdatedAt) ||
updated.UpdatedAt.Equal(original.UpdatedAt))
// Every field not in the update call must be preserved exactly.
require.Equal(t, original.RootChatID, updated.RootChatID,
"RootChatID should survive a partial update")
require.Equal(t, original.ParentChatID, updated.ParentChatID,
"ParentChatID should survive a partial update")
require.Equal(t, original.ModelConfigID, updated.ModelConfigID,
"ModelConfigID should survive a partial update")
require.Equal(t, original.TriggerMessageID, updated.TriggerMessageID,
"TriggerMessageID should survive a partial update")
require.Equal(t, original.HistoryTipMessageID, updated.HistoryTipMessageID,
"HistoryTipMessageID should survive a partial update")
require.Equal(t, original.Provider, updated.Provider,
"Provider should survive a partial update")
require.Equal(t, original.Model, updated.Model,
"Model should survive a partial update")
require.JSONEq(t, string(original.Summary), string(updated.Summary),
"Summary should survive a partial update")
require.Equal(t, original.Kind, updated.Kind,
"Kind should survive a partial update")
require.Equal(t, original.StartedAt.UTC(), updated.StartedAt.UTC(),
"StartedAt should survive a partial update")
}
// TestChatDebugStepCOALESCEPreservation verifies that the COALESCE
// pattern in UpdateChatDebugStep preserves every field that was not
// explicitly supplied in the update. If COALESCE were removed from
// any column, the corresponding field would silently null out.
func TestChatDebugStepCOALESCEPreservation(t *testing.T) {
t.Parallel()
store, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
dbgen.Organization(t, store, database.Organization{})
user := dbgen.User(t, store, database.User{})
providerName := "openai"
modelName := "debug-step-coalesce-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
chat, err := store.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "chat-step-coalesce-" + uuid.NewString(),
})
require.NoError(t, err)
run, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
Kind: "chat_turn",
Status: "in_progress",
})
require.NoError(t, err)
// Insert a fully-populated step so every nullable field has a value.
original, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: run.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "llm_call",
Status: "in_progress",
HistoryTipMessageID: sql.NullInt64{Int64: 10, Valid: true},
AssistantMessageID: sql.NullInt64{Int64: 11, Valid: true},
NormalizedRequest: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"prompt":"hello"}`), Valid: true},
NormalizedResponse: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"text":"world"}`), Valid: true},
Usage: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"tokens":42}`), Valid: true},
Attempts: pqtype.NullRawMessage{RawMessage: json.RawMessage(`[{"n":1}]`), Valid: true},
Error: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"code":"transient"}`), Valid: true},
Metadata: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"trace_id":"abc"}`), Valid: true},
})
require.NoError(t, err)
// Update only Status and FinishedAt. Every other nullable param
// is left as its Go zero value (Valid: false -> SQL NULL), which
// the COALESCE pattern should interpret as "keep existing."
now := time.Now()
updated, err := store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
ID: original.ID,
ChatID: chat.ID,
Status: sql.NullString{String: "completed", Valid: true},
FinishedAt: sql.NullTime{
Time: now,
Valid: true,
},
})
require.NoError(t, err)
// Status and FinishedAt should be updated.
require.Equal(t, "completed", updated.Status)
require.True(t, updated.FinishedAt.Valid)
// UpdatedAt should advance (set to NOW() unconditionally).
require.True(t, updated.UpdatedAt.After(original.UpdatedAt) ||
updated.UpdatedAt.Equal(original.UpdatedAt))
// Every field not in the update call must be preserved exactly.
require.Equal(t, original.HistoryTipMessageID, updated.HistoryTipMessageID,
"HistoryTipMessageID should survive a partial update")
require.Equal(t, original.AssistantMessageID, updated.AssistantMessageID,
"AssistantMessageID should survive a partial update")
require.JSONEq(t, string(original.NormalizedRequest), string(updated.NormalizedRequest),
"NormalizedRequest should survive a partial update")
require.JSONEq(t, string(original.NormalizedResponse.RawMessage), string(updated.NormalizedResponse.RawMessage),
"NormalizedResponse should survive a partial update")
require.JSONEq(t, string(original.Usage.RawMessage), string(updated.Usage.RawMessage),
"Usage should survive a partial update")
require.JSONEq(t, string(original.Attempts), string(updated.Attempts),
"Attempts should survive a partial update")
require.JSONEq(t, string(original.Error.RawMessage), string(updated.Error.RawMessage),
"Error should survive a partial update")
require.JSONEq(t, string(original.Metadata), string(updated.Metadata),
"Metadata should survive a partial update")
require.Equal(t, original.Operation, updated.Operation,
"Operation should survive a partial update")
require.Equal(t, original.StepNumber, updated.StepNumber,
"StepNumber should survive a partial update")
require.Equal(t, original.StartedAt.UTC(), updated.StartedAt.UTC(),
"StartedAt should survive a partial update")
}
// TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive verifies
// that runs whose message ID columns are all NULL are never matched
// by DeleteChatDebugDataAfterMessageID. SQL's three-valued logic
// means NULL > N evaluates to NULL (not TRUE), so these rows must
// survive. Without this test a future change could break the
// invariant with no test failure.
func TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive(t *testing.T) {
t.Parallel()
store, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
dbgen.Organization(t, store, database.Organization{})
user := dbgen.User(t, store, database.User{})
providerName := "openai"
modelName := "debug-model-null-msg-" + uuid.NewString()
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: providerName,
DisplayName: "Debug Provider",
APIKey: "test-key",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: providerName,
Model: modelName,
DisplayName: "Debug Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 80,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
chat, err := store.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "chat-debug-null-msg-" + uuid.NewString(),
})
require.NoError(t, err)
// Insert a run with all message ID columns left as NULL (Valid: false).
nullMsgRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
Kind: "chat_turn",
Status: "in_progress",
Provider: sql.NullString{String: providerName, Valid: true},
Model: sql.NullString{String: modelName, Valid: true},
// TriggerMessageID and HistoryTipMessageID intentionally
// omitted (zero-value → SQL NULL).
})
require.NoError(t, err)
// Attach a step with NULL message IDs too.
nullMsgStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
RunID: nullMsgRun.ID,
ChatID: chat.ID,
StepNumber: 1,
Operation: "stream",
Status: "in_progress",
// HistoryTipMessageID and AssistantMessageID intentionally
// omitted (zero-value → SQL NULL).
})
require.NoError(t, err)
// Delete with an arbitrary cutoff. The run and its step should
// survive because NULL > cutoff evaluates to NULL, not TRUE.
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
ChatID: chat.ID,
MessageID: 1,
})
require.NoError(t, err)
require.EqualValues(t, 0, deletedRows, "rows with NULL message IDs must not be deleted")
// Verify run still exists.
remaining, err := store.GetChatDebugRunByID(ctx, nullMsgRun.ID)
require.NoError(t, err)
require.Equal(t, nullMsgRun.ID, remaining.ID)
// Verify step still exists.
remainingSteps, err := store.GetChatDebugStepsByRunID(ctx, nullMsgRun.ID)
require.NoError(t, err)
require.Len(t, remainingSteps, 1)
require.Equal(t, nullMsgStep.ID, remainingSteps[0].ID)
}
func TestChatHasUnread(t *testing.T) {
t.Parallel()
-662
View File
@@ -2900,583 +2900,6 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou
return new_period, err
}
const deleteChatDebugDataAfterMessageID = `-- name: DeleteChatDebugDataAfterMessageID :execrows
WITH affected_runs AS (
SELECT DISTINCT run.id
FROM chat_debug_runs run
WHERE run.chat_id = $1::uuid
AND (
run.history_tip_message_id > $2::bigint
OR run.trigger_message_id > $2::bigint
)
UNION
SELECT DISTINCT step.run_id AS id
FROM chat_debug_steps step
WHERE step.chat_id = $1::uuid
AND (
step.assistant_message_id > $2::bigint
OR step.history_tip_message_id > $2::bigint
)
)
DELETE FROM chat_debug_runs
WHERE chat_id = $1::uuid
AND id IN (SELECT id FROM affected_runs)
`
type DeleteChatDebugDataAfterMessageIDParams struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
MessageID int64 `db:"message_id" json:"message_id"`
}
func (q *sqlQuerier) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
result, err := q.db.ExecContext(ctx, deleteChatDebugDataAfterMessageID, arg.ChatID, arg.MessageID)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
const deleteChatDebugDataByChatID = `-- name: DeleteChatDebugDataByChatID :execrows
DELETE FROM chat_debug_runs
WHERE chat_id = $1::uuid
`
func (q *sqlQuerier) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
result, err := q.db.ExecContext(ctx, deleteChatDebugDataByChatID, chatID)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
const finalizeStaleChatDebugRows = `-- name: FinalizeStaleChatDebugRows :one
WITH finalized_runs AS (
UPDATE chat_debug_runs
SET
status = 'interrupted',
updated_at = NOW(),
finished_at = NOW()
WHERE updated_at < $1::timestamptz
AND finished_at IS NULL
AND status NOT IN ('completed', 'error', 'interrupted')
RETURNING id
), finalized_steps AS (
UPDATE chat_debug_steps
SET
status = 'interrupted',
updated_at = NOW(),
finished_at = NOW()
WHERE (
updated_at < $1::timestamptz
OR run_id IN (SELECT id FROM finalized_runs)
)
AND finished_at IS NULL
AND status NOT IN ('completed', 'error', 'interrupted')
RETURNING 1
)
SELECT
(SELECT COUNT(*) FROM finalized_runs)::bigint AS runs_finalized,
(SELECT COUNT(*) FROM finalized_steps)::bigint AS steps_finalized
`
type FinalizeStaleChatDebugRowsRow struct {
RunsFinalized int64 `db:"runs_finalized" json:"runs_finalized"`
StepsFinalized int64 `db:"steps_finalized" json:"steps_finalized"`
}
// Marks orphaned in-progress rows as interrupted so they do not stay
// in a non-terminal state forever. The NOT IN list must match the
// terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
//
// The steps CTE also catches steps whose parent run was just finalized
// (via run_id IN), because PostgreSQL data-modifying CTEs share the
// same snapshot and cannot see each other's row updates. Without this,
// a step with a recent updated_at would survive its run's finalization
// and remain in 'in_progress' state permanently.
func (q *sqlQuerier) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (FinalizeStaleChatDebugRowsRow, error) {
row := q.db.QueryRowContext(ctx, finalizeStaleChatDebugRows, updatedBefore)
var i FinalizeStaleChatDebugRowsRow
err := row.Scan(&i.RunsFinalized, &i.StepsFinalized)
return i, err
}
const getChatDebugRunByID = `-- name: GetChatDebugRunByID :one
SELECT id, chat_id, root_chat_id, parent_chat_id, model_config_id, trigger_message_id, history_tip_message_id, kind, status, provider, model, summary, started_at, updated_at, finished_at
FROM chat_debug_runs
WHERE id = $1::uuid
`
func (q *sqlQuerier) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (ChatDebugRun, error) {
row := q.db.QueryRowContext(ctx, getChatDebugRunByID, id)
var i ChatDebugRun
err := row.Scan(
&i.ID,
&i.ChatID,
&i.RootChatID,
&i.ParentChatID,
&i.ModelConfigID,
&i.TriggerMessageID,
&i.HistoryTipMessageID,
&i.Kind,
&i.Status,
&i.Provider,
&i.Model,
&i.Summary,
&i.StartedAt,
&i.UpdatedAt,
&i.FinishedAt,
)
return i, err
}
const getChatDebugRunsByChatID = `-- name: GetChatDebugRunsByChatID :many
SELECT id, chat_id, root_chat_id, parent_chat_id, model_config_id, trigger_message_id, history_tip_message_id, kind, status, provider, model, summary, started_at, updated_at, finished_at
FROM chat_debug_runs
WHERE chat_id = $1::uuid
ORDER BY started_at DESC, id DESC
LIMIT $2::int
`
type GetChatDebugRunsByChatIDParams struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
LimitVal int32 `db:"limit_val" json:"limit_val"`
}
// Returns the most recent debug runs for a chat, ordered newest-first.
// Callers must supply an explicit limit to avoid unbounded result sets.
func (q *sqlQuerier) GetChatDebugRunsByChatID(ctx context.Context, arg GetChatDebugRunsByChatIDParams) ([]ChatDebugRun, error) {
rows, err := q.db.QueryContext(ctx, getChatDebugRunsByChatID, arg.ChatID, arg.LimitVal)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ChatDebugRun
for rows.Next() {
var i ChatDebugRun
if err := rows.Scan(
&i.ID,
&i.ChatID,
&i.RootChatID,
&i.ParentChatID,
&i.ModelConfigID,
&i.TriggerMessageID,
&i.HistoryTipMessageID,
&i.Kind,
&i.Status,
&i.Provider,
&i.Model,
&i.Summary,
&i.StartedAt,
&i.UpdatedAt,
&i.FinishedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChatDebugStepsByRunID = `-- name: GetChatDebugStepsByRunID :many
SELECT id, run_id, chat_id, step_number, operation, status, history_tip_message_id, assistant_message_id, normalized_request, normalized_response, usage, attempts, error, metadata, started_at, updated_at, finished_at
FROM chat_debug_steps
WHERE run_id = $1::uuid
ORDER BY step_number ASC, started_at ASC
`
func (q *sqlQuerier) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]ChatDebugStep, error) {
rows, err := q.db.QueryContext(ctx, getChatDebugStepsByRunID, runID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ChatDebugStep
for rows.Next() {
var i ChatDebugStep
if err := rows.Scan(
&i.ID,
&i.RunID,
&i.ChatID,
&i.StepNumber,
&i.Operation,
&i.Status,
&i.HistoryTipMessageID,
&i.AssistantMessageID,
&i.NormalizedRequest,
&i.NormalizedResponse,
&i.Usage,
&i.Attempts,
&i.Error,
&i.Metadata,
&i.StartedAt,
&i.UpdatedAt,
&i.FinishedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertChatDebugRun = `-- name: InsertChatDebugRun :one
INSERT INTO chat_debug_runs (
chat_id,
root_chat_id,
parent_chat_id,
model_config_id,
trigger_message_id,
history_tip_message_id,
kind,
status,
provider,
model,
summary,
started_at,
updated_at,
finished_at
)
VALUES (
$1::uuid,
$2::uuid,
$3::uuid,
$4::uuid,
$5::bigint,
$6::bigint,
$7::text,
$8::text,
$9::text,
$10::text,
COALESCE($11::jsonb, '{}'::jsonb),
COALESCE($12::timestamptz, NOW()),
COALESCE($13::timestamptz, NOW()),
$14::timestamptz
)
RETURNING id, chat_id, root_chat_id, parent_chat_id, model_config_id, trigger_message_id, history_tip_message_id, kind, status, provider, model, summary, started_at, updated_at, finished_at
`
type InsertChatDebugRunParams struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
TriggerMessageID sql.NullInt64 `db:"trigger_message_id" json:"trigger_message_id"`
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
Kind string `db:"kind" json:"kind"`
Status string `db:"status" json:"status"`
Provider sql.NullString `db:"provider" json:"provider"`
Model sql.NullString `db:"model" json:"model"`
Summary pqtype.NullRawMessage `db:"summary" json:"summary"`
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"`
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
}
func (q *sqlQuerier) InsertChatDebugRun(ctx context.Context, arg InsertChatDebugRunParams) (ChatDebugRun, error) {
row := q.db.QueryRowContext(ctx, insertChatDebugRun,
arg.ChatID,
arg.RootChatID,
arg.ParentChatID,
arg.ModelConfigID,
arg.TriggerMessageID,
arg.HistoryTipMessageID,
arg.Kind,
arg.Status,
arg.Provider,
arg.Model,
arg.Summary,
arg.StartedAt,
arg.UpdatedAt,
arg.FinishedAt,
)
var i ChatDebugRun
err := row.Scan(
&i.ID,
&i.ChatID,
&i.RootChatID,
&i.ParentChatID,
&i.ModelConfigID,
&i.TriggerMessageID,
&i.HistoryTipMessageID,
&i.Kind,
&i.Status,
&i.Provider,
&i.Model,
&i.Summary,
&i.StartedAt,
&i.UpdatedAt,
&i.FinishedAt,
)
return i, err
}
const insertChatDebugStep = `-- name: InsertChatDebugStep :one
INSERT INTO chat_debug_steps (
run_id,
chat_id,
step_number,
operation,
status,
history_tip_message_id,
assistant_message_id,
normalized_request,
normalized_response,
usage,
attempts,
error,
metadata,
started_at,
updated_at,
finished_at
)
SELECT
$1::uuid,
run.chat_id,
$2::int,
$3::text,
$4::text,
$5::bigint,
$6::bigint,
COALESCE($7::jsonb, '{}'::jsonb),
$8::jsonb,
$9::jsonb,
COALESCE($10::jsonb, '[]'::jsonb),
$11::jsonb,
COALESCE($12::jsonb, '{}'::jsonb),
COALESCE($13::timestamptz, NOW()),
COALESCE($14::timestamptz, NOW()),
$15::timestamptz
FROM chat_debug_runs run
WHERE run.id = $1::uuid
AND run.chat_id = $16::uuid
RETURNING id, run_id, chat_id, step_number, operation, status, history_tip_message_id, assistant_message_id, normalized_request, normalized_response, usage, attempts, error, metadata, started_at, updated_at, finished_at
`
type InsertChatDebugStepParams struct {
RunID uuid.UUID `db:"run_id" json:"run_id"`
StepNumber int32 `db:"step_number" json:"step_number"`
Operation string `db:"operation" json:"operation"`
Status string `db:"status" json:"status"`
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
AssistantMessageID sql.NullInt64 `db:"assistant_message_id" json:"assistant_message_id"`
NormalizedRequest pqtype.NullRawMessage `db:"normalized_request" json:"normalized_request"`
NormalizedResponse pqtype.NullRawMessage `db:"normalized_response" json:"normalized_response"`
Usage pqtype.NullRawMessage `db:"usage" json:"usage"`
Attempts pqtype.NullRawMessage `db:"attempts" json:"attempts"`
Error pqtype.NullRawMessage `db:"error" json:"error"`
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"`
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
}
func (q *sqlQuerier) InsertChatDebugStep(ctx context.Context, arg InsertChatDebugStepParams) (ChatDebugStep, error) {
row := q.db.QueryRowContext(ctx, insertChatDebugStep,
arg.RunID,
arg.StepNumber,
arg.Operation,
arg.Status,
arg.HistoryTipMessageID,
arg.AssistantMessageID,
arg.NormalizedRequest,
arg.NormalizedResponse,
arg.Usage,
arg.Attempts,
arg.Error,
arg.Metadata,
arg.StartedAt,
arg.UpdatedAt,
arg.FinishedAt,
arg.ChatID,
)
var i ChatDebugStep
err := row.Scan(
&i.ID,
&i.RunID,
&i.ChatID,
&i.StepNumber,
&i.Operation,
&i.Status,
&i.HistoryTipMessageID,
&i.AssistantMessageID,
&i.NormalizedRequest,
&i.NormalizedResponse,
&i.Usage,
&i.Attempts,
&i.Error,
&i.Metadata,
&i.StartedAt,
&i.UpdatedAt,
&i.FinishedAt,
)
return i, err
}
const updateChatDebugRun = `-- name: UpdateChatDebugRun :one
UPDATE chat_debug_runs
SET
root_chat_id = COALESCE($1::uuid, root_chat_id),
parent_chat_id = COALESCE($2::uuid, parent_chat_id),
model_config_id = COALESCE($3::uuid, model_config_id),
trigger_message_id = COALESCE($4::bigint, trigger_message_id),
history_tip_message_id = COALESCE($5::bigint, history_tip_message_id),
status = COALESCE($6::text, status),
provider = COALESCE($7::text, provider),
model = COALESCE($8::text, model),
summary = COALESCE($9::jsonb, summary),
finished_at = COALESCE($10::timestamptz, finished_at),
updated_at = NOW()
WHERE id = $11::uuid
AND chat_id = $12::uuid
RETURNING id, chat_id, root_chat_id, parent_chat_id, model_config_id, trigger_message_id, history_tip_message_id, kind, status, provider, model, summary, started_at, updated_at, finished_at
`
type UpdateChatDebugRunParams struct {
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
TriggerMessageID sql.NullInt64 `db:"trigger_message_id" json:"trigger_message_id"`
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
Status sql.NullString `db:"status" json:"status"`
Provider sql.NullString `db:"provider" json:"provider"`
Model sql.NullString `db:"model" json:"model"`
Summary pqtype.NullRawMessage `db:"summary" json:"summary"`
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
ID uuid.UUID `db:"id" json:"id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
}
// Uses COALESCE so that passing NULL from Go means "keep the
// existing value." This is intentional: debug rows follow a
// write-once-finalize pattern where fields are set at creation
// or finalization and never cleared back to NULL.
func (q *sqlQuerier) UpdateChatDebugRun(ctx context.Context, arg UpdateChatDebugRunParams) (ChatDebugRun, error) {
row := q.db.QueryRowContext(ctx, updateChatDebugRun,
arg.RootChatID,
arg.ParentChatID,
arg.ModelConfigID,
arg.TriggerMessageID,
arg.HistoryTipMessageID,
arg.Status,
arg.Provider,
arg.Model,
arg.Summary,
arg.FinishedAt,
arg.ID,
arg.ChatID,
)
var i ChatDebugRun
err := row.Scan(
&i.ID,
&i.ChatID,
&i.RootChatID,
&i.ParentChatID,
&i.ModelConfigID,
&i.TriggerMessageID,
&i.HistoryTipMessageID,
&i.Kind,
&i.Status,
&i.Provider,
&i.Model,
&i.Summary,
&i.StartedAt,
&i.UpdatedAt,
&i.FinishedAt,
)
return i, err
}
const updateChatDebugStep = `-- name: UpdateChatDebugStep :one
UPDATE chat_debug_steps
SET
status = COALESCE($1::text, status),
history_tip_message_id = COALESCE($2::bigint, history_tip_message_id),
assistant_message_id = COALESCE($3::bigint, assistant_message_id),
normalized_request = COALESCE($4::jsonb, normalized_request),
normalized_response = COALESCE($5::jsonb, normalized_response),
usage = COALESCE($6::jsonb, usage),
attempts = COALESCE($7::jsonb, attempts),
error = COALESCE($8::jsonb, error),
metadata = COALESCE($9::jsonb, metadata),
finished_at = COALESCE($10::timestamptz, finished_at),
updated_at = NOW()
WHERE id = $11::uuid
AND chat_id = $12::uuid
RETURNING id, run_id, chat_id, step_number, operation, status, history_tip_message_id, assistant_message_id, normalized_request, normalized_response, usage, attempts, error, metadata, started_at, updated_at, finished_at
`
type UpdateChatDebugStepParams struct {
Status sql.NullString `db:"status" json:"status"`
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
AssistantMessageID sql.NullInt64 `db:"assistant_message_id" json:"assistant_message_id"`
NormalizedRequest pqtype.NullRawMessage `db:"normalized_request" json:"normalized_request"`
NormalizedResponse pqtype.NullRawMessage `db:"normalized_response" json:"normalized_response"`
Usage pqtype.NullRawMessage `db:"usage" json:"usage"`
Attempts pqtype.NullRawMessage `db:"attempts" json:"attempts"`
Error pqtype.NullRawMessage `db:"error" json:"error"`
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
ID uuid.UUID `db:"id" json:"id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
}
// Uses COALESCE so that passing NULL from Go means "keep the
// existing value." This is intentional: debug rows follow a
// write-once-finalize pattern where fields are set at creation
// or finalization and never cleared back to NULL.
func (q *sqlQuerier) UpdateChatDebugStep(ctx context.Context, arg UpdateChatDebugStepParams) (ChatDebugStep, error) {
row := q.db.QueryRowContext(ctx, updateChatDebugStep,
arg.Status,
arg.HistoryTipMessageID,
arg.AssistantMessageID,
arg.NormalizedRequest,
arg.NormalizedResponse,
arg.Usage,
arg.Attempts,
arg.Error,
arg.Metadata,
arg.FinishedAt,
arg.ID,
arg.ChatID,
)
var i ChatDebugStep
err := row.Scan(
&i.ID,
&i.RunID,
&i.ChatID,
&i.StepNumber,
&i.Operation,
&i.Status,
&i.HistoryTipMessageID,
&i.AssistantMessageID,
&i.NormalizedRequest,
&i.NormalizedResponse,
&i.Usage,
&i.Attempts,
&i.Error,
&i.Metadata,
&i.StartedAt,
&i.UpdatedAt,
&i.FinishedAt,
)
return i, err
}
const deleteOldChatFiles = `-- name: DeleteOldChatFiles :execrows
WITH kept_file_ids AS (
-- NOTE: This uses updated_at as a proxy for archive time
@@ -19722,21 +19145,6 @@ func (q *sqlQuerier) GetApplicationName(ctx context.Context) (string, error) {
return value, err
}
const getChatDebugLoggingAllowUsers = `-- name: GetChatDebugLoggingAllowUsers :one
SELECT
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_chat_debug_logging_allow_users'), false) :: boolean AS allow_users
`
// GetChatDebugLoggingAllowUsers returns the runtime admin setting that
// allows users to opt into chat debug logging when the deployment does
// not already force debug logging on globally.
func (q *sqlQuerier) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
row := q.db.QueryRowContext(ctx, getChatDebugLoggingAllowUsers)
var allow_users bool
err := row.Scan(&allow_users)
return allow_users, err
}
const getChatDesktopEnabled = `-- name: GetChatDesktopEnabled :one
SELECT
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop
@@ -20048,30 +19456,6 @@ func (q *sqlQuerier) UpsertApplicationName(ctx context.Context, value string) er
return err
}
const upsertChatDebugLoggingAllowUsers = `-- name: UpsertChatDebugLoggingAllowUsers :exec
INSERT INTO site_configs (key, value)
VALUES (
'agents_chat_debug_logging_allow_users',
CASE
WHEN $1::bool THEN 'true'
ELSE 'false'
END
)
ON CONFLICT (key) DO UPDATE
SET value = CASE
WHEN $1::bool THEN 'true'
ELSE 'false'
END
WHERE site_configs.key = 'agents_chat_debug_logging_allow_users'
`
// UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
// allows users to opt into chat debug logging.
func (q *sqlQuerier) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
_, err := q.db.ExecContext(ctx, upsertChatDebugLoggingAllowUsers, allowUsers)
return err
}
const upsertChatDesktopEnabled = `-- name: UpsertChatDesktopEnabled :exec
INSERT INTO site_configs (key, value)
VALUES (
@@ -24326,23 +23710,6 @@ func (q *sqlQuerier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UU
return chat_custom_prompt, err
}
const getUserChatDebugLoggingEnabled = `-- name: GetUserChatDebugLoggingEnabled :one
SELECT
COALESCE((
SELECT value = 'true'
FROM user_configs
WHERE user_id = $1
AND key = 'chat_debug_logging_enabled'
), false) :: boolean AS debug_logging_enabled
`
func (q *sqlQuerier) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
row := q.db.QueryRowContext(ctx, getUserChatDebugLoggingEnabled, userID)
var debug_logging_enabled bool
err := row.Scan(&debug_logging_enabled)
return debug_logging_enabled, err
}
const getUserCount = `-- name: GetUserCount :one
SELECT
COUNT(*)
@@ -25337,35 +24704,6 @@ func (q *sqlQuerier) UpdateUserThemePreference(ctx context.Context, arg UpdateUs
return i, err
}
const upsertUserChatDebugLoggingEnabled = `-- name: UpsertUserChatDebugLoggingEnabled :exec
INSERT INTO user_configs (user_id, key, value)
VALUES (
$1,
'chat_debug_logging_enabled',
CASE
WHEN $2::bool THEN 'true'
ELSE 'false'
END
)
ON CONFLICT ON CONSTRAINT user_configs_pkey
DO UPDATE SET value = CASE
WHEN $2::bool THEN 'true'
ELSE 'false'
END
WHERE user_configs.user_id = $1
AND user_configs.key = 'chat_debug_logging_enabled'
`
type UpsertUserChatDebugLoggingEnabledParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
DebugLoggingEnabled bool `db:"debug_logging_enabled" json:"debug_logging_enabled"`
}
func (q *sqlQuerier) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error {
_, err := q.db.ExecContext(ctx, upsertUserChatDebugLoggingEnabled, arg.UserID, arg.DebugLoggingEnabled)
return err
}
const validateUserIDs = `-- name: ValidateUserIDs :one
WITH input AS (
SELECT
-205
View File
@@ -1,205 +0,0 @@
-- name: InsertChatDebugRun :one
INSERT INTO chat_debug_runs (
chat_id,
root_chat_id,
parent_chat_id,
model_config_id,
trigger_message_id,
history_tip_message_id,
kind,
status,
provider,
model,
summary,
started_at,
updated_at,
finished_at
)
VALUES (
@chat_id::uuid,
sqlc.narg('root_chat_id')::uuid,
sqlc.narg('parent_chat_id')::uuid,
sqlc.narg('model_config_id')::uuid,
sqlc.narg('trigger_message_id')::bigint,
sqlc.narg('history_tip_message_id')::bigint,
@kind::text,
@status::text,
sqlc.narg('provider')::text,
sqlc.narg('model')::text,
COALESCE(sqlc.narg('summary')::jsonb, '{}'::jsonb),
COALESCE(sqlc.narg('started_at')::timestamptz, NOW()),
COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()),
sqlc.narg('finished_at')::timestamptz
)
RETURNING *;
-- name: UpdateChatDebugRun :one
-- Uses COALESCE so that passing NULL from Go means "keep the
-- existing value." This is intentional: debug rows follow a
-- write-once-finalize pattern where fields are set at creation
-- or finalization and never cleared back to NULL.
UPDATE chat_debug_runs
SET
root_chat_id = COALESCE(sqlc.narg('root_chat_id')::uuid, root_chat_id),
parent_chat_id = COALESCE(sqlc.narg('parent_chat_id')::uuid, parent_chat_id),
model_config_id = COALESCE(sqlc.narg('model_config_id')::uuid, model_config_id),
trigger_message_id = COALESCE(sqlc.narg('trigger_message_id')::bigint, trigger_message_id),
history_tip_message_id = COALESCE(sqlc.narg('history_tip_message_id')::bigint, history_tip_message_id),
status = COALESCE(sqlc.narg('status')::text, status),
provider = COALESCE(sqlc.narg('provider')::text, provider),
model = COALESCE(sqlc.narg('model')::text, model),
summary = COALESCE(sqlc.narg('summary')::jsonb, summary),
finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at),
updated_at = NOW()
WHERE id = @id::uuid
AND chat_id = @chat_id::uuid
RETURNING *;
-- name: InsertChatDebugStep :one
INSERT INTO chat_debug_steps (
run_id,
chat_id,
step_number,
operation,
status,
history_tip_message_id,
assistant_message_id,
normalized_request,
normalized_response,
usage,
attempts,
error,
metadata,
started_at,
updated_at,
finished_at
)
SELECT
@run_id::uuid,
run.chat_id,
@step_number::int,
@operation::text,
@status::text,
sqlc.narg('history_tip_message_id')::bigint,
sqlc.narg('assistant_message_id')::bigint,
COALESCE(sqlc.narg('normalized_request')::jsonb, '{}'::jsonb),
sqlc.narg('normalized_response')::jsonb,
sqlc.narg('usage')::jsonb,
COALESCE(sqlc.narg('attempts')::jsonb, '[]'::jsonb),
sqlc.narg('error')::jsonb,
COALESCE(sqlc.narg('metadata')::jsonb, '{}'::jsonb),
COALESCE(sqlc.narg('started_at')::timestamptz, NOW()),
COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()),
sqlc.narg('finished_at')::timestamptz
FROM chat_debug_runs run
WHERE run.id = @run_id::uuid
AND run.chat_id = @chat_id::uuid
RETURNING *;
-- name: UpdateChatDebugStep :one
-- Uses COALESCE so that passing NULL from Go means "keep the
-- existing value." This is intentional: debug rows follow a
-- write-once-finalize pattern where fields are set at creation
-- or finalization and never cleared back to NULL.
UPDATE chat_debug_steps
SET
status = COALESCE(sqlc.narg('status')::text, status),
history_tip_message_id = COALESCE(sqlc.narg('history_tip_message_id')::bigint, history_tip_message_id),
assistant_message_id = COALESCE(sqlc.narg('assistant_message_id')::bigint, assistant_message_id),
normalized_request = COALESCE(sqlc.narg('normalized_request')::jsonb, normalized_request),
normalized_response = COALESCE(sqlc.narg('normalized_response')::jsonb, normalized_response),
usage = COALESCE(sqlc.narg('usage')::jsonb, usage),
attempts = COALESCE(sqlc.narg('attempts')::jsonb, attempts),
error = COALESCE(sqlc.narg('error')::jsonb, error),
metadata = COALESCE(sqlc.narg('metadata')::jsonb, metadata),
finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at),
updated_at = NOW()
WHERE id = @id::uuid
AND chat_id = @chat_id::uuid
RETURNING *;
-- name: GetChatDebugRunsByChatID :many
-- Returns the most recent debug runs for a chat, ordered newest-first.
-- Callers must supply an explicit limit to avoid unbounded result sets.
SELECT *
FROM chat_debug_runs
WHERE chat_id = @chat_id::uuid
ORDER BY started_at DESC, id DESC
LIMIT @limit_val::int;
-- name: GetChatDebugRunByID :one
SELECT *
FROM chat_debug_runs
WHERE id = @id::uuid;
-- name: GetChatDebugStepsByRunID :many
SELECT *
FROM chat_debug_steps
WHERE run_id = @run_id::uuid
ORDER BY step_number ASC, started_at ASC;
-- name: DeleteChatDebugDataByChatID :execrows
DELETE FROM chat_debug_runs
WHERE chat_id = @chat_id::uuid;
-- name: DeleteChatDebugDataAfterMessageID :execrows
WITH affected_runs AS (
SELECT DISTINCT run.id
FROM chat_debug_runs run
WHERE run.chat_id = @chat_id::uuid
AND (
run.history_tip_message_id > @message_id::bigint
OR run.trigger_message_id > @message_id::bigint
)
UNION
SELECT DISTINCT step.run_id AS id
FROM chat_debug_steps step
WHERE step.chat_id = @chat_id::uuid
AND (
step.assistant_message_id > @message_id::bigint
OR step.history_tip_message_id > @message_id::bigint
)
)
DELETE FROM chat_debug_runs
WHERE chat_id = @chat_id::uuid
AND id IN (SELECT id FROM affected_runs);
-- name: FinalizeStaleChatDebugRows :one
-- Marks orphaned in-progress rows as interrupted so they do not stay
-- in a non-terminal state forever. The NOT IN list must match the
-- terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
--
-- The steps CTE also catches steps whose parent run was just finalized
-- (via run_id IN), because PostgreSQL data-modifying CTEs share the
-- same snapshot and cannot see each other's row updates. Without this,
-- a step with a recent updated_at would survive its run's finalization
-- and remain in 'in_progress' state permanently.
WITH finalized_runs AS (
UPDATE chat_debug_runs
SET
status = 'interrupted',
updated_at = NOW(),
finished_at = NOW()
WHERE updated_at < @updated_before::timestamptz
AND finished_at IS NULL
AND status NOT IN ('completed', 'error', 'interrupted')
RETURNING id
), finalized_steps AS (
UPDATE chat_debug_steps
SET
status = 'interrupted',
updated_at = NOW(),
finished_at = NOW()
WHERE (
updated_at < @updated_before::timestamptz
OR run_id IN (SELECT id FROM finalized_runs)
)
AND finished_at IS NULL
AND status NOT IN ('completed', 'error', 'interrupted')
RETURNING 1
)
SELECT
(SELECT COUNT(*) FROM finalized_runs)::bigint AS runs_finalized,
(SELECT COUNT(*) FROM finalized_steps)::bigint AS steps_finalized;
-25
View File
@@ -179,31 +179,6 @@ SET value = CASE
END
WHERE site_configs.key = 'agents_desktop_enabled';
-- GetChatDebugLoggingAllowUsers returns the runtime admin setting that
-- allows users to opt into chat debug logging when the deployment does
-- not already force debug logging on globally.
-- name: GetChatDebugLoggingAllowUsers :one
SELECT
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_chat_debug_logging_allow_users'), false) :: boolean AS allow_users;
-- UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
-- allows users to opt into chat debug logging.
-- name: UpsertChatDebugLoggingAllowUsers :exec
INSERT INTO site_configs (key, value)
VALUES (
'agents_chat_debug_logging_allow_users',
CASE
WHEN sqlc.arg(allow_users)::bool THEN 'true'
ELSE 'false'
END
)
ON CONFLICT (key) DO UPDATE
SET value = CASE
WHEN sqlc.arg(allow_users)::bool THEN 'true'
ELSE 'false'
END
WHERE site_configs.key = 'agents_chat_debug_logging_allow_users';
-- GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
-- Returns an empty string when no allowlist has been configured (all templates allowed).
-- name: GetChatTemplateAllowlist :one
-27
View File
@@ -213,33 +213,6 @@ RETURNING *;
-- name: DeleteUserChatCompactionThreshold :exec
DELETE FROM user_configs WHERE user_id = @user_id AND key = @key;
-- name: GetUserChatDebugLoggingEnabled :one
SELECT
COALESCE((
SELECT value = 'true'
FROM user_configs
WHERE user_id = @user_id
AND key = 'chat_debug_logging_enabled'
), false) :: boolean AS debug_logging_enabled;
-- name: UpsertUserChatDebugLoggingEnabled :exec
INSERT INTO user_configs (user_id, key, value)
VALUES (
@user_id,
'chat_debug_logging_enabled',
CASE
WHEN sqlc.arg(debug_logging_enabled)::bool THEN 'true'
ELSE 'false'
END
)
ON CONFLICT ON CONSTRAINT user_configs_pkey
DO UPDATE SET value = CASE
WHEN sqlc.arg(debug_logging_enabled)::bool THEN 'true'
ELSE 'false'
END
WHERE user_configs.user_id = @user_id
AND user_configs.key = 'chat_debug_logging_enabled';
-- name: GetUserTaskNotificationAlertDismissed :one
SELECT
value::boolean as task_notification_alert_dismissed
-4
View File
@@ -15,8 +15,6 @@ const (
UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id);
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
UniqueChatDebugRunsPkey UniqueConstraint = "chat_debug_runs_pkey" // ALTER TABLE ONLY chat_debug_runs ADD CONSTRAINT chat_debug_runs_pkey PRIMARY KEY (id);
UniqueChatDebugStepsPkey UniqueConstraint = "chat_debug_steps_pkey" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT chat_debug_steps_pkey PRIMARY KEY (id);
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
UniqueChatFileLinksChatIDFileIDKey UniqueConstraint = "chat_file_links_chat_id_file_id_key" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
@@ -130,8 +128,6 @@ const (
UniqueWorkspaceResourcesPkey UniqueConstraint = "workspace_resources_pkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id);
UniqueWorkspacesPkey UniqueConstraint = "workspaces_pkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
UniqueIndexChatDebugRunsIDChat UniqueConstraint = "idx_chat_debug_runs_id_chat" // CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs USING btree (id, chat_id);
UniqueIndexChatDebugStepsRunStep UniqueConstraint = "idx_chat_debug_steps_run_step" // CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps USING btree (run_id, step_number);
UniqueIndexChatModelConfigsSingleDefault UniqueConstraint = "idx_chat_model_configs_single_default" // CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
UniqueIndexConnectionLogsConnectionIDWorkspaceIDAgentName UniqueConstraint = "idx_connection_logs_connection_id_workspace_id_agent_name" // CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
UniqueIndexCustomRolesNameLowerOrganizationID UniqueConstraint = "idx_custom_roles_name_lower_organization_id" // CREATE UNIQUE INDEX idx_custom_roles_name_lower_organization_id ON custom_roles USING btree (lower(name), COALESCE(organization_id, '00000000-0000-0000-0000-000000000000'::uuid));
-84
View File
@@ -1,84 +0,0 @@
package chatdebug
import (
"context"
"runtime"
"sync"
"github.com/google/uuid"
)
type (
runContextKey struct{}
stepContextKey struct{}
reuseStepKey struct{}
reuseHolder struct {
mu sync.Mutex
handle *stepHandle
}
)
// ContextWithRun stores rc in ctx.
//
// Step counter cleanup is reference-counted per RunID: each live
// RunContext increments a counter and runtime.AddCleanup decrements
// it when the struct is garbage collected. Shared state (step
// counters) is only deleted when the last RunContext for a given
// RunID becomes unreachable, preventing premature cleanup when
// multiple RunContext instances share the same RunID.
func ContextWithRun(ctx context.Context, rc *RunContext) context.Context {
if rc == nil {
panic("chatdebug: nil RunContext")
}
enriched := context.WithValue(ctx, runContextKey{}, rc)
if rc.RunID != uuid.Nil {
trackRunRef(rc.RunID)
runtime.AddCleanup(rc, func(id uuid.UUID) {
releaseRunRef(id)
}, rc.RunID)
}
return enriched
}
// RunFromContext returns the debug run context stored in ctx.
func RunFromContext(ctx context.Context) (*RunContext, bool) {
rc, ok := ctx.Value(runContextKey{}).(*RunContext)
if !ok {
return nil, false
}
return rc, true
}
// ContextWithStep stores sc in ctx.
func ContextWithStep(ctx context.Context, sc *StepContext) context.Context {
if sc == nil {
panic("chatdebug: nil StepContext")
}
return context.WithValue(ctx, stepContextKey{}, sc)
}
// StepFromContext returns the debug step context stored in ctx.
func StepFromContext(ctx context.Context) (*StepContext, bool) {
sc, ok := ctx.Value(stepContextKey{}).(*StepContext)
if !ok {
return nil, false
}
return sc, true
}
// ReuseStep marks ctx so wrapped model calls under it share one debug step.
func ReuseStep(ctx context.Context) context.Context {
if holder, ok := reuseHolderFromContext(ctx); ok {
return context.WithValue(ctx, reuseStepKey{}, holder)
}
return context.WithValue(ctx, reuseStepKey{}, &reuseHolder{})
}
func reuseHolderFromContext(ctx context.Context) (*reuseHolder, bool) {
holder, ok := ctx.Value(reuseStepKey{}).(*reuseHolder)
if !ok {
return nil, false
}
return holder, true
}
@@ -1,124 +0,0 @@
package chatdebug
import (
"context"
"runtime"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/testutil"
)
func TestReuseStep_PreservesExistingHolder(t *testing.T) {
t.Parallel()
ctx := ReuseStep(context.Background())
first, ok := reuseHolderFromContext(ctx)
require.True(t, ok)
reused := ReuseStep(ctx)
second, ok := reuseHolderFromContext(reused)
require.True(t, ok)
require.Same(t, first, second)
}
func TestContextWithRun_CleansUpStepCounterAfterGC(t *testing.T) {
t.Parallel()
runID := uuid.New()
chatID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
func() {
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
handle, _ := beginStep(ctx, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
require.NotNil(t, handle)
_, ok := stepCounters.Load(runID)
require.True(t, ok)
}()
require.Eventually(t, func() bool {
runtime.GC()
runtime.Gosched()
_, ok := stepCounters.Load(runID)
return !ok
}, testutil.WaitShort, testutil.IntervalFast)
}
func TestContextWithRun_MultipleInstancesSameRunID(t *testing.T) {
t.Parallel()
runID := uuid.New()
chatID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
// rc2 is the surviving instance that should keep the step counter alive.
rc2 := &RunContext{RunID: runID, ChatID: chatID}
ctx2 := ContextWithRun(context.Background(), rc2)
// Create a second RunContext with the same RunID and let it become
// unreachable. Its GC cleanup must NOT delete the step counter
// because rc2 is still alive.
func() {
rc1 := &RunContext{RunID: runID, ChatID: chatID}
ctx1 := ContextWithRun(context.Background(), rc1)
h, _ := beginStep(ctx1, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
require.NotNil(t, h)
require.Equal(t, int32(1), h.stepCtx.StepNumber)
}()
// Force GC to collect rc1.
for range 5 {
runtime.GC()
runtime.Gosched()
}
// The step counter must still be present because rc2 is alive.
_, ok := stepCounters.Load(runID)
require.True(t, ok, "step counter was prematurely cleaned up while another RunContext is still alive")
// Subsequent steps on the surviving context must continue numbering.
h2, _ := beginStep(ctx2, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
require.NotNil(t, h2)
require.Equal(t, int32(2), h2.stepCtx.StepNumber)
}
func TestContextWithRun_CleansUpStepCounterOnGCAfterCancel(t *testing.T) {
t.Parallel()
runID := uuid.New()
chatID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
// Run in a closure so the RunContext becomes unreachable after
// context cancellation, allowing GC to trigger the cleanup.
func() {
ctx, cancel := context.WithCancel(context.Background())
ctx = ContextWithRun(ctx, &RunContext{RunID: runID, ChatID: chatID})
handle, _ := beginStep(ctx, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
require.NotNil(t, handle)
require.Equal(t, int32(1), handle.stepCtx.StepNumber)
_, ok := stepCounters.Load(runID)
require.True(t, ok)
cancel()
}()
// After the closure, the RunContext is unreachable.
// runtime.AddCleanup fires during GC.
require.Eventually(t, func() bool {
runtime.GC()
runtime.Gosched()
_, ok := stepCounters.Load(runID)
return !ok
}, testutil.WaitShort, testutil.IntervalFast)
freshCtx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
freshHandle, _ := beginStep(freshCtx, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
require.NotNil(t, freshHandle)
require.Equal(t, int32(1), freshHandle.stepCtx.StepNumber)
}
-105
View File
@@ -1,105 +0,0 @@
package chatdebug_test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
)
func TestContextWithRunRoundTrip(t *testing.T) {
t.Parallel()
rc := &chatdebug.RunContext{
RunID: uuid.New(),
ChatID: uuid.New(),
RootChatID: uuid.New(),
ParentChatID: uuid.New(),
ModelConfigID: uuid.New(),
TriggerMessageID: 11,
HistoryTipMessageID: 22,
Kind: chatdebug.KindChatTurn,
Provider: "anthropic",
Model: "claude-sonnet",
}
ctx := chatdebug.ContextWithRun(context.Background(), rc)
got, ok := chatdebug.RunFromContext(ctx)
require.True(t, ok)
require.Same(t, rc, got)
require.Equal(t, *rc, *got)
}
func TestRunFromContextAbsent(t *testing.T) {
t.Parallel()
got, ok := chatdebug.RunFromContext(context.Background())
require.False(t, ok)
require.Nil(t, got)
}
func TestContextWithStepRoundTrip(t *testing.T) {
t.Parallel()
sc := &chatdebug.StepContext{
StepID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 7,
Operation: chatdebug.OperationStream,
HistoryTipMessageID: 33,
}
ctx := chatdebug.ContextWithStep(context.Background(), sc)
got, ok := chatdebug.StepFromContext(ctx)
require.True(t, ok)
require.Same(t, sc, got)
require.Equal(t, *sc, *got)
}
func TestStepFromContextAbsent(t *testing.T) {
t.Parallel()
got, ok := chatdebug.StepFromContext(context.Background())
require.False(t, ok)
require.Nil(t, got)
}
func TestContextWithRunAndStep(t *testing.T) {
t.Parallel()
rc := &chatdebug.RunContext{RunID: uuid.New(), ChatID: uuid.New()}
sc := &chatdebug.StepContext{StepID: uuid.New(), RunID: rc.RunID, ChatID: rc.ChatID}
ctx := chatdebug.ContextWithStep(
chatdebug.ContextWithRun(context.Background(), rc),
sc,
)
gotRun, ok := chatdebug.RunFromContext(ctx)
require.True(t, ok)
require.Same(t, rc, gotRun)
gotStep, ok := chatdebug.StepFromContext(ctx)
require.True(t, ok)
require.Same(t, sc, gotStep)
}
func TestContextWithRunPanicsOnNil(t *testing.T) {
t.Parallel()
require.Panics(t, func() {
_ = chatdebug.ContextWithRun(context.Background(), nil)
})
}
func TestContextWithStepPanicsOnNil(t *testing.T) {
t.Parallel()
require.Panics(t, func() {
_ = chatdebug.ContextWithStep(context.Background(), nil)
})
}
File diff suppressed because it is too large Load Diff
@@ -1,331 +0,0 @@
package chatdebug //nolint:testpackage // Checks unexported normalized structs against fantasy source types.
import (
"reflect"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
)
// fieldDisposition documents whether a fantasy struct field is captured
// by the corresponding normalized struct ("normalized") or
// intentionally omitted ("skipped: <reason>"). The test fails when a
// fantasy type gains a field that is not yet classified, forcing the
// developer to decide whether to normalize or skip it.
//
// This mirrors the audit-table exhaustiveness check in
// enterprise/audit/table.go — same idea, different domain.
type fieldDisposition = map[string]string
// TestNormalizationFieldCoverage ensures every exported field on the
// fantasy types that model.go normalizes is explicitly accounted for.
// When the fantasy library adds a field the test fails, surfacing the
// drift at `go test` time rather than silently dropping data.
func TestNormalizationFieldCoverage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
typ reflect.Type
fields fieldDisposition
}{
// ── struct-to-struct mappings ──────────────────────────
{
name: "fantasy.Usage → normalizedUsage",
typ: reflect.TypeFor[fantasy.Usage](),
fields: fieldDisposition{
"InputTokens": "normalized",
"OutputTokens": "normalized",
"TotalTokens": "normalized",
"ReasoningTokens": "normalized",
"CacheCreationTokens": "normalized",
"CacheReadTokens": "normalized",
},
},
{
name: "fantasy.Call → normalizedCallPayload",
typ: reflect.TypeFor[fantasy.Call](),
fields: fieldDisposition{
"Prompt": "normalized",
"MaxOutputTokens": "normalized",
"Temperature": "normalized",
"TopP": "normalized",
"TopK": "normalized",
"PresencePenalty": "normalized",
"FrequencyPenalty": "normalized",
"Tools": "normalized",
"ToolChoice": "normalized",
"UserAgent": "skipped: internal transport header, not useful for debug panel",
"ProviderOptions": "skipped: opaque provider data, only count preserved",
},
},
{
name: "fantasy.ObjectCall → normalizedObjectCallPayload",
typ: reflect.TypeFor[fantasy.ObjectCall](),
fields: fieldDisposition{
"Prompt": "normalized",
"Schema": "skipped: full schema too large; SchemaName+SchemaDescription captured instead",
"SchemaName": "normalized",
"SchemaDescription": "normalized",
"MaxOutputTokens": "normalized",
"Temperature": "normalized",
"TopP": "normalized",
"TopK": "normalized",
"PresencePenalty": "normalized",
"FrequencyPenalty": "normalized",
"UserAgent": "skipped: internal transport header, not useful for debug panel",
"ProviderOptions": "skipped: opaque provider data, only count preserved",
"RepairText": "skipped: function value, not serializable",
},
},
{
name: "fantasy.Response → normalizedResponsePayload",
typ: reflect.TypeFor[fantasy.Response](),
fields: fieldDisposition{
"Content": "normalized",
"FinishReason": "normalized",
"Usage": "normalized",
"Warnings": "normalized",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.ObjectResponse → normalizedObjectResponsePayload",
typ: reflect.TypeFor[fantasy.ObjectResponse](),
fields: fieldDisposition{
"Object": "skipped: arbitrary user type, not serializable generically",
"RawText": "normalized: as RawTextLength (length only, content unbounded)",
"Usage": "normalized",
"FinishReason": "normalized",
"Warnings": "normalized",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.CallWarning → normalizedWarning",
typ: reflect.TypeFor[fantasy.CallWarning](),
fields: fieldDisposition{
"Type": "normalized",
"Setting": "normalized",
"Tool": "skipped: interface value, warning message+type sufficient for debug panel",
"Details": "normalized",
"Message": "normalized",
},
},
{
name: "fantasy.StreamPart → appendNormalizedStreamContent",
typ: reflect.TypeFor[fantasy.StreamPart](),
fields: fieldDisposition{
"Type": "normalized",
"ID": "normalized: as ToolCallID in content parts",
"ToolCallName": "normalized: as ToolName in content parts",
"ToolCallInput": "normalized: as Arguments or Result (bounded)",
"Delta": "normalized: accumulated into text/reasoning content parts",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"Usage": "normalized: captured in stream finalize",
"FinishReason": "normalized: captured in stream finalize",
"Error": "normalized: captured in stream error handling",
"Warnings": "normalized: captured in stream warning accumulation",
"SourceType": "normalized",
"URL": "normalized",
"Title": "normalized",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.ObjectStreamPart → wrapObjectStreamSeq",
typ: reflect.TypeFor[fantasy.ObjectStreamPart](),
fields: fieldDisposition{
"Type": "normalized: drives switch in wrapObjectStreamSeq",
"Object": "skipped: arbitrary user type, only ObjectPartCount tracked",
"Delta": "normalized: accumulated into rawTextLength",
"Error": "normalized: captured in stream error handling",
"Usage": "normalized: captured in stream finalize",
"FinishReason": "normalized: captured in stream finalize",
"Warnings": "normalized: captured in stream warning accumulation",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
// ── message part types (normalizeMessageParts) ────────
{
name: "fantasy.TextPart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.TextPart](),
fields: fieldDisposition{
"Text": "normalized: bounded to MaxMessagePartTextLength",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.ReasoningPart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.ReasoningPart](),
fields: fieldDisposition{
"Text": "normalized: bounded to MaxMessagePartTextLength",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.FilePart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.FilePart](),
fields: fieldDisposition{
"Filename": "normalized",
"Data": "skipped: binary data never stored in debug records",
"MediaType": "normalized",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.ToolCallPart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.ToolCallPart](),
fields: fieldDisposition{
"ToolCallID": "normalized",
"ToolName": "normalized",
"Input": "normalized: as Arguments (bounded)",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.ToolResultPart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.ToolResultPart](),
fields: fieldDisposition{
"ToolCallID": "normalized",
"Output": "normalized: text extracted via normalizeToolResultOutput",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
// ── response content types (normalizeContentParts) ────
{
name: "fantasy.TextContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.TextContent](),
fields: fieldDisposition{
"Text": "normalized: bounded to MaxMessagePartTextLength",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.ReasoningContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.ReasoningContent](),
fields: fieldDisposition{
"Text": "normalized: bounded to MaxMessagePartTextLength",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.FileContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.FileContent](),
fields: fieldDisposition{
"MediaType": "normalized",
"Data": "skipped: binary data never stored in debug records",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.SourceContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.SourceContent](),
fields: fieldDisposition{
"SourceType": "normalized",
"ID": "skipped: provider-internal identifier, not actionable in debug panel",
"URL": "normalized",
"Title": "normalized",
"MediaType": "skipped: only relevant for document sources, rarely useful for debugging",
"Filename": "skipped: only relevant for document sources, rarely useful for debugging",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.ToolCallContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.ToolCallContent](),
fields: fieldDisposition{
"ToolCallID": "normalized",
"ToolName": "normalized",
"Input": "normalized: as Arguments (bounded), InputLength tracks original",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
"Invalid": "skipped: validation state not surfaced in debug panel",
"ValidationError": "skipped: validation state not surfaced in debug panel",
},
},
{
name: "fantasy.ToolResultContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.ToolResultContent](),
fields: fieldDisposition{
"ToolCallID": "normalized",
"ToolName": "normalized",
"Result": "normalized: text extracted via normalizeToolResultOutput",
"ClientMetadata": "skipped: client execution metadata not needed for debug panel",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
// ── tool types (normalizeTools) ───────────────────────
{
name: "fantasy.FunctionTool → normalizedTool",
typ: reflect.TypeFor[fantasy.FunctionTool](),
fields: fieldDisposition{
"Name": "normalized",
"Description": "normalized",
"InputSchema": "normalized: preserved as JSON for debug panel rendering",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.ProviderDefinedTool → normalizedTool",
typ: reflect.TypeFor[fantasy.ProviderDefinedTool](),
fields: fieldDisposition{
"ID": "normalized",
"Name": "normalized",
"Args": "skipped: provider-specific configuration not needed for debug panel",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Every exported field on the fantasy type must be
// registered as "normalized" or "skipped: <reason>".
for i := range tt.typ.NumField() {
field := tt.typ.Field(i)
if !field.IsExported() {
continue
}
disposition, ok := tt.fields[field.Name]
if !ok {
require.Failf(t, "unregistered field",
"%s.%s is not in the coverage map — "+
"add it as \"normalized\" or \"skipped: <reason>\"",
tt.typ.Name(), field.Name)
}
require.NotEmptyf(t, disposition,
"%s.%s has an empty disposition — "+
"use \"normalized\" or \"skipped: <reason>\"",
tt.typ.Name(), field.Name)
}
// Catch stale entries that reference removed fields.
for name := range tt.fields {
found := false
for i := range tt.typ.NumField() {
if tt.typ.Field(i).Name == name {
found = true
break
}
}
require.Truef(t, found,
"stale coverage entry %s.%s — "+
"field no longer exists in fantasy, remove it",
tt.typ.Name(), name)
}
})
}
}
@@ -1,764 +0,0 @@
package chatdebug
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/testutil"
)
type testError struct{ message string }
func (e *testError) Error() string { return e.message }
func TestDebugModel_Provider(t *testing.T) {
t.Parallel()
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
model := &debugModel{inner: inner}
require.Equal(t, inner.Provider(), model.Provider())
}
func TestDebugModel_Model(t *testing.T) {
t.Parallel()
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
model := &debugModel{inner: inner}
require.Equal(t, inner.Model(), model.Model())
}
func TestDebugModel_Disabled(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
svc := NewService(db, testutil.Logger(t), nil)
respWant := &fantasy.Response{FinishReason: fantasy.FinishReasonStop}
inner := &chattest.FakeModel{
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
_, ok := StepFromContext(ctx)
require.False(t, ok)
require.Nil(t, attemptSinkFromContext(ctx))
return respWant, nil
},
}
model := &debugModel{
inner: inner,
svc: svc,
opts: RecorderOptions{
ChatID: chatID,
OwnerID: ownerID,
},
}
resp, err := model.Generate(context.Background(), fantasy.Call{})
require.NoError(t, err)
require.Same(t, respWant, resp)
}
func TestDebugModel_Generate(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
call := fantasy.Call{
Prompt: fantasy.Prompt{fantasy.NewUserMessage("hello")},
MaxOutputTokens: int64Ptr(128),
Temperature: float64Ptr(0.25),
}
respWant := &fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.TextContent{Text: "hello"},
fantasy.ToolCallContent{ToolCallID: "tool-1", ToolName: "tool", Input: `{}`},
fantasy.SourceContent{ID: "source-1", Title: "docs", URL: "https://example.com"},
},
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{InputTokens: 10, OutputTokens: 4, TotalTokens: 14},
Warnings: []fantasy.CallWarning{{Message: "warning"}},
}
svc := NewService(db, testutil.Logger(t), nil)
inner := &chattest.FakeModel{
GenerateFn: func(ctx context.Context, got fantasy.Call) (*fantasy.Response, error) {
require.Equal(t, call, got)
stepCtx, ok := StepFromContext(ctx)
require.True(t, ok)
require.Equal(t, runID, stepCtx.RunID)
require.Equal(t, chatID, stepCtx.ChatID)
require.Equal(t, int32(1), stepCtx.StepNumber)
require.Equal(t, OperationGenerate, stepCtx.Operation)
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
require.NotNil(t, attemptSinkFromContext(ctx))
return respWant, nil
},
}
model := &debugModel{
inner: inner,
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.Generate(ctx, call)
require.NoError(t, err)
require.Same(t, respWant, resp)
}
func TestDebugModel_GeneratePersistsAttemptsWithoutResponseClose(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.JSONEq(t, `{"message":"hello","api_key":"super-secret"}`,
string(body))
require.Equal(t, "Bearer top-secret", req.Header.Get("Authorization"))
rw.Header().Set("Content-Type", "application/json")
rw.Header().Set("X-API-Key", "response-secret")
rw.WriteHeader(http.StatusCreated)
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
}))
defer server.Close()
svc := NewService(db, testutil.Logger(t), nil)
inner := &chattest.FakeModel{
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
server.URL,
strings.NewReader(`{"message":"hello","api_key":"super-secret"}`),
)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer top-secret")
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
require.NoError(t, resp.Body.Close())
return &fantasy.Response{FinishReason: fantasy.FinishReasonStop}, nil
},
}
model := &debugModel{
inner: inner,
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.Generate(ctx, fantasy.Call{})
require.NoError(t, err)
require.NotNil(t, resp)
}
func TestDebugModel_GenerateError(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
wantErr := &testError{message: "boom"}
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
GenerateFn: func(context.Context, fantasy.Call) (*fantasy.Response, error) {
return nil, wantErr
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.Generate(ctx, fantasy.Call{})
require.Nil(t, resp)
require.ErrorIs(t, err, wantErr)
}
func TestStepStatusForError(t *testing.T) {
t.Parallel()
t.Run("Canceled", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusInterrupted, stepStatusForError(context.Canceled))
})
t.Run("DeadlineExceeded", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusInterrupted, stepStatusForError(context.DeadlineExceeded))
})
t.Run("OtherError", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusError, stepStatusForError(xerrors.New("boom")))
})
}
func TestDebugModel_Stream(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
errPart := xerrors.New("chunk failed")
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hel"},
{Type: fantasy.StreamPartTypeToolCall, ID: "tool-call-1", ToolCallName: "tool"},
{Type: fantasy.StreamPartTypeSource, ID: "source-1", URL: "https://example.com", Title: "docs"},
{Type: fantasy.StreamPartTypeWarnings, Warnings: []fantasy.CallWarning{{Message: "w1"}, {Message: "w2"}}},
{Type: fantasy.StreamPartTypeError, Error: errPart},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 8, OutputTokens: 3, TotalTokens: 11}},
}
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
stepCtx, ok := StepFromContext(ctx)
require.True(t, ok)
require.Equal(t, runID, stepCtx.RunID)
require.Equal(t, chatID, stepCtx.ChatID)
require.Equal(t, int32(1), stepCtx.StepNumber)
require.Equal(t, OperationStream, stepCtx.Operation)
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
require.NotNil(t, attemptSinkFromContext(ctx))
return partsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.NoError(t, err)
got := make([]fantasy.StreamPart, 0, len(parts))
for part := range seq {
got = append(got, part)
}
require.Equal(t, parts, got)
}
func TestDebugModel_StreamObject(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
parts := []fantasy.ObjectStreamPart{
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ob"},
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ject"},
{Type: fantasy.ObjectStreamPartTypeObject, Object: map[string]any{"value": "object"}},
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7}},
}
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
stepCtx, ok := StepFromContext(ctx)
require.True(t, ok)
require.Equal(t, runID, stepCtx.RunID)
require.Equal(t, chatID, stepCtx.ChatID)
require.Equal(t, int32(1), stepCtx.StepNumber)
require.Equal(t, OperationStream, stepCtx.Operation)
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
require.NotNil(t, attemptSinkFromContext(ctx))
return objectPartsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
require.NoError(t, err)
got := make([]fantasy.ObjectStreamPart, 0, len(parts))
for part := range seq {
got = append(got, part)
}
require.Equal(t, parts, got)
}
// TestDebugModel_StreamCompletedAfterFinish verifies that when a consumer
// stops iteration after receiving a finish part, the step is marked as
// completed rather than interrupted.
func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
runID := uuid.New()
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
}
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return partsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.NoError(t, err)
// Consumer reads the finish part then breaks — this should still be
// considered a completed stream, not interrupted.
var handle *stepHandle
for part := range seq {
if part.Type == fantasy.StreamPartTypeFinish {
break
}
}
// The step handle is on the model's last beginStep call; verify
// status via the internal handle state by calling beginStep directly.
// Since the model wrapper already finalized the handle, just verify
// we consumed something. The real assertion is that the finalize
// path chose StatusCompleted (tested via handle.status below).
_ = handle // handle is not directly accessible, but we can verify via a fresh step
// Verify by running a second stream where we inspect the handle.
runID2 := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID2) })
ctx2 := ContextWithRun(context.Background(), &RunContext{RunID: runID2, ChatID: chatID})
h, _ := beginStep(ctx2, svc, RecorderOptions{ChatID: chatID}, OperationStream, nil)
require.NotNil(t, h)
// The handle starts with zero status; simulate what the wrapper does
// when consumer breaks after finish.
h.finish(ctx2, StatusCompleted, nil, nil, nil, nil)
h.mu.Lock()
require.Equal(t, StatusCompleted, h.status)
h.mu.Unlock()
}
// TestDebugModel_StreamInterruptedBeforeFinish verifies that when a consumer
// stops iteration before receiving a finish part, the step is marked as
// interrupted.
func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
runID := uuid.New()
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
{Type: fantasy.StreamPartTypeTextDelta, Delta: " world"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}
svc := NewService(db, testutil.Logger(t), nil)
var capturedHandle *stepHandle
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return partsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.NoError(t, err)
// Consumer reads the first delta then breaks before finish.
count := 0
for range seq {
count++
if count == 1 {
break
}
}
require.Equal(t, 1, count)
_ = capturedHandle
}
func TestDebugModel_StreamRejectsNilSequence(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
runID := uuid.New()
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
var nilStream fantasy.StreamResponse
return nilStream, nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.Nil(t, seq)
require.ErrorIs(t, err, ErrNilModelResult)
}
func TestDebugModel_StreamObjectRejectsNilSequence(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
runID := uuid.New()
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamObjectFn: func(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
var nilStream fantasy.ObjectStreamResponse
return nilStream, nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
require.Nil(t, seq)
require.ErrorIs(t, err, ErrNilModelResult)
}
func TestDebugModel_StreamEarlyStop(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "first"},
{Type: fantasy.StreamPartTypeTextDelta, Delta: "second"},
}
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
return partsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.NoError(t, err)
count := 0
for part := range seq {
require.Equal(t, parts[0], part)
count++
break
}
require.Equal(t, 1, count)
}
func TestStreamErrorStatus(t *testing.T) {
t.Parallel()
t.Run("CancellationBecomesInterrupted", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.Canceled))
})
t.Run("DeadlineExceededBecomesInterrupted", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.DeadlineExceeded))
})
t.Run("NilErrorBecomesError", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusError, streamErrorStatus(StatusCompleted, nil))
})
t.Run("ExistingErrorWins", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusError, streamErrorStatus(StatusError, context.Canceled))
})
}
func objectPartsToSeq(parts []fantasy.ObjectStreamPart) fantasy.ObjectStreamResponse {
return func(yield func(fantasy.ObjectStreamPart) bool) {
for _, part := range parts {
if !yield(part) {
return
}
}
}
}
func partsToSeq(parts []fantasy.StreamPart) fantasy.StreamResponse {
return func(yield func(fantasy.StreamPart) bool) {
for _, part := range parts {
if !yield(part) {
return
}
}
}
}
func TestDebugModel_GenerateObject(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
call := fantasy.ObjectCall{
Prompt: fantasy.Prompt{fantasy.NewUserMessage("summarize")},
SchemaName: "Summary",
MaxOutputTokens: int64Ptr(256),
}
respWant := &fantasy.ObjectResponse{
RawText: `{"title":"test"}`,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
}
svc := NewService(db, testutil.Logger(t), nil)
inner := &chattest.FakeModel{
GenerateObjectFn: func(ctx context.Context, got fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
require.Equal(t, call, got)
stepCtx, ok := StepFromContext(ctx)
require.True(t, ok)
require.Equal(t, runID, stepCtx.RunID)
require.Equal(t, chatID, stepCtx.ChatID)
require.Equal(t, OperationGenerate, stepCtx.Operation)
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
require.NotNil(t, attemptSinkFromContext(ctx))
return respWant, nil
},
}
model := &debugModel{
inner: inner,
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.GenerateObject(ctx, call)
require.NoError(t, err)
require.Same(t, respWant, resp)
}
func TestDebugModel_GenerateObjectError(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
runID := uuid.New()
wantErr := &testError{message: "object boom"}
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return nil, wantErr
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
require.Nil(t, resp)
require.ErrorIs(t, err, wantErr)
}
func TestDebugModel_GenerateObjectRejectsNilResponse(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
runID := uuid.New()
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return nil, nil //nolint:nilnil // Intentionally testing nil response handling.
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
require.Nil(t, resp)
require.ErrorIs(t, err, ErrNilModelResult)
}
func TestWrapStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
// Create a context that we cancel after the stream finishes.
ctx, cancel := context.WithCancel(context.Background())
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
}
seq := wrapStreamSeq(ctx, handle, partsToSeq(parts))
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
for range seq {
}
// Cancel the context after the stream has been fully consumed
// and finalized. The status should remain completed.
cancel()
handle.mu.Lock()
status := handle.status
handle.mu.Unlock()
require.Equal(t, StatusCompleted, status)
}
func TestWrapObjectStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
ctx, cancel := context.WithCancel(context.Background())
parts := []fantasy.ObjectStreamPart{
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "obj"},
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 3, OutputTokens: 1, TotalTokens: 4}},
}
seq := wrapObjectStreamSeq(ctx, handle, objectPartsToSeq(parts))
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
for range seq {
}
cancel()
handle.mu.Lock()
status := handle.status
handle.mu.Unlock()
require.Equal(t, StatusCompleted, status)
}
func TestWrapStreamSeq_DroppedStreamFinalizedOnCtxCancel(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
ctx, cancel := context.WithCancel(context.Background())
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}
// Create the wrapped stream but never iterate it.
_ = wrapStreamSeq(ctx, handle, partsToSeq(parts))
// Cancel the context — the AfterFunc safety net should finalize
// the step as interrupted.
cancel()
// AfterFunc fires asynchronously; give it a moment.
require.Eventually(t, func() bool {
handle.mu.Lock()
defer handle.mu.Unlock()
return handle.status == StatusInterrupted
}, testutil.WaitShort, testutil.IntervalFast)
}
func int64Ptr(v int64) *int64 { return &v }
func float64Ptr(v float64) *float64 { return &v }
@@ -1,379 +0,0 @@
package chatdebug //nolint:testpackage // Uses unexported normalization helpers.
import (
"context"
"strings"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
func TestNormalizeCall_PreservesToolSchemasAndMessageToolPayloads(t *testing.T) {
t.Parallel()
payload := normalizeCall(fantasy.Call{
Prompt: fantasy.Prompt{
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.ToolCallPart{
ToolCallID: "call-search",
ToolName: "search_docs",
Input: `{"query":"debug panel"}`,
},
},
},
{
Role: fantasy.MessageRoleTool,
Content: []fantasy.MessagePart{
fantasy.ToolResultPart{
ToolCallID: "call-search",
Output: fantasy.ToolResultOutputContentText{
Text: `{"matches":["model.go","DebugStepCard.tsx"]}`,
},
},
},
},
},
Tools: []fantasy.Tool{
fantasy.FunctionTool{
Name: "search_docs",
Description: "Searches documentation.",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{"type": "string"},
},
"required": []string{"query"},
},
},
},
})
require.Len(t, payload.Tools, 1)
require.True(t, payload.Tools[0].HasInputSchema)
require.JSONEq(t, `{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`,
string(payload.Tools[0].InputSchema))
require.Len(t, payload.Messages, 2)
require.Equal(t, "tool-call", payload.Messages[0].Parts[0].Type)
require.Equal(t, `{"query":"debug panel"}`, payload.Messages[0].Parts[0].Arguments)
require.Equal(t, "tool-result", payload.Messages[1].Parts[0].Type)
require.Equal(t,
`{"matches":["model.go","DebugStepCard.tsx"]}`,
payload.Messages[1].Parts[0].Result,
)
}
func TestNormalizers_SkipTypedNilInterfaceValues(t *testing.T) {
t.Parallel()
t.Run("MessageParts", func(t *testing.T) {
t.Parallel()
var nilPart *fantasy.TextPart
parts := normalizeMessageParts([]fantasy.MessagePart{
nilPart,
fantasy.TextPart{Text: "hello"},
})
require.Len(t, parts, 1)
require.Equal(t, "text", parts[0].Type)
require.Equal(t, "hello", parts[0].Text)
})
t.Run("Tools", func(t *testing.T) {
t.Parallel()
var nilTool *fantasy.FunctionTool
tools := normalizeTools([]fantasy.Tool{
nilTool,
fantasy.FunctionTool{Name: "search_docs"},
})
require.Len(t, tools, 1)
require.Equal(t, "function", tools[0].Type)
require.Equal(t, "search_docs", tools[0].Name)
})
t.Run("ContentParts", func(t *testing.T) {
t.Parallel()
var nilContent *fantasy.TextContent
content := normalizeContentParts(fantasy.ResponseContent{
nilContent,
fantasy.TextContent{Text: "hello"},
})
require.Len(t, content, 1)
require.Equal(t, "text", content[0].Type)
require.Equal(t, "hello", content[0].Text)
})
}
func TestAppendNormalizedStreamContent_PreservesOrderAndCanonicalTypes(t *testing.T) {
t.Parallel()
var content []normalizedContentPart
streamDebugBytes := 0
for _, part := range []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "before "},
{Type: fantasy.StreamPartTypeToolCall, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{"query":"debug"}`},
{Type: fantasy.StreamPartTypeToolResult, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{"matches":1}`},
{Type: fantasy.StreamPartTypeTextDelta, Delta: "after"},
} {
content = appendNormalizedStreamContent(content, part, &streamDebugBytes)
}
require.Equal(t, []normalizedContentPart{
{Type: "text", Text: "before "},
{Type: "tool-call", ToolCallID: "call-1", ToolName: "search_docs", Arguments: `{"query":"debug"}`, InputLength: len(`{"query":"debug"}`)},
{Type: "tool-result", ToolCallID: "call-1", ToolName: "search_docs", Result: `{"matches":1}`},
{Type: "text", Text: "after"},
}, content)
}
func TestAppendNormalizedStreamContent_GlobalTextCap(t *testing.T) {
t.Parallel()
streamDebugBytes := 0
long := strings.Repeat("a", maxStreamDebugTextBytes)
var content []normalizedContentPart
for _, part := range []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: long},
{Type: fantasy.StreamPartTypeToolCall, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{}`},
{Type: fantasy.StreamPartTypeTextDelta, Delta: "tail"},
} {
content = appendNormalizedStreamContent(content, part, &streamDebugBytes)
}
require.Len(t, content, 2)
require.Equal(t, strings.Repeat("a", maxStreamDebugTextBytes), content[0].Text)
require.Equal(t, "tool-call", content[1].Type)
require.Equal(t, maxStreamDebugTextBytes, streamDebugBytes)
}
func TestWrapStreamSeq_SourceCountExcludesToolResults(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
seq := wrapStreamSeq(context.Background(), handle, partsToSeq([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolResult, ID: "tool-1", ToolCallName: "search_docs"},
{Type: fantasy.StreamPartTypeSource, ID: "source-1", URL: "https://example.com", Title: "docs"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}))
partCount := 0
for range seq {
partCount++
}
require.Equal(t, 3, partCount)
metadata, ok := handle.metadata.(map[string]any)
require.True(t, ok)
summary, ok := metadata["stream_summary"].(streamSummary)
require.True(t, ok)
require.Equal(t, 1, summary.SourceCount)
}
func TestWrapObjectStreamSeq_UsesStructuredOutputPayload(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
usage := fantasy.Usage{InputTokens: 3, OutputTokens: 2, TotalTokens: 5}
seq := wrapObjectStreamSeq(context.Background(), handle, objectPartsToSeq([]fantasy.ObjectStreamPart{
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ob"},
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ject"},
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: usage},
}))
partCount := 0
for range seq {
partCount++
}
require.Equal(t, 3, partCount)
resp, ok := handle.response.(normalizedObjectResponsePayload)
require.True(t, ok)
require.Equal(t, normalizedObjectResponsePayload{
RawTextLength: len("object"),
FinishReason: string(fantasy.FinishReasonStop),
Usage: normalizeUsage(usage),
StructuredOutput: true,
}, resp)
}
func TestNormalizeResponse_UsesCanonicalToolTypes(t *testing.T) {
t.Parallel()
payload := normalizeResponse(&fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.ToolCallContent{
ToolCallID: "call-calc",
ToolName: "calculator",
Input: `{"operation":"add","operands":[2,2]}`,
},
fantasy.ToolResultContent{
ToolCallID: "call-calc",
ToolName: "calculator",
Result: fantasy.ToolResultOutputContentText{Text: `{"sum":4}`},
},
},
})
require.Len(t, payload.Content, 2)
require.Equal(t, "tool-call", payload.Content[0].Type)
require.Equal(t, "tool-result", payload.Content[1].Type)
}
func TestBoundText_RespectsDocumentedRuneLimit(t *testing.T) {
t.Parallel()
runes := make([]rune, MaxMessagePartTextLength+5)
for i := range runes {
runes[i] = 'a'
}
input := string(runes)
got := boundText(input)
require.Equal(t, MaxMessagePartTextLength, len([]rune(got)))
require.Equal(t, '…', []rune(got)[len([]rune(got))-1])
}
func TestNormalizeToolResultOutput(t *testing.T) {
t.Parallel()
t.Run("TextValue", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentText{Text: "hello"})
require.Equal(t, "hello", got)
})
t.Run("TextPointer", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentText{Text: "hello"})
require.Equal(t, "hello", got)
})
t.Run("TextPointerNil", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentText)(nil))
require.Equal(t, "", got)
})
t.Run("ErrorValue", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentError{
Error: xerrors.New("tool failed"),
})
require.Equal(t, "tool failed", got)
})
t.Run("ErrorValueNilError", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentError{Error: nil})
require.Equal(t, "", got)
})
t.Run("ErrorPointer", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentError{
Error: xerrors.New("ptr fail"),
})
require.Equal(t, "ptr fail", got)
})
t.Run("ErrorPointerNil", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentError)(nil))
require.Equal(t, "", got)
})
t.Run("ErrorPointerNilError", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentError{Error: nil})
require.Equal(t, "", got)
})
t.Run("MediaWithText", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{
Text: "caption",
MediaType: "image/png",
})
require.Equal(t, "caption", got)
})
t.Run("MediaWithoutText", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{
MediaType: "image/png",
})
require.Equal(t, "[media output: image/png]", got)
})
t.Run("MediaWithoutTextOrType", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{})
require.Equal(t, "[media output]", got)
})
t.Run("MediaPointerNil", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentMedia)(nil))
require.Equal(t, "", got)
})
t.Run("MediaPointerWithText", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentMedia{
Text: "ptr caption",
MediaType: "image/jpeg",
})
require.Equal(t, "ptr caption", got)
})
t.Run("NilOutput", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(nil)
require.Equal(t, "", got)
})
t.Run("DefaultJSON", func(t *testing.T) {
t.Parallel()
// An unexpected type falls through to the default JSON
// marshal branch.
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentText{
Text: "fallback",
})
require.Equal(t, "fallback", got)
})
}
func TestNormalizeResponse_PreservesToolCallArguments(t *testing.T) {
t.Parallel()
payload := normalizeResponse(&fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.ToolCallContent{
ToolCallID: "call-calc",
ToolName: "calculator",
Input: `{"operation":"add","operands":[2,2]}`,
},
},
})
require.Len(t, payload.Content, 1)
require.Equal(t, "call-calc", payload.Content[0].ToolCallID)
require.Equal(t, "calculator", payload.Content[0].ToolName)
require.JSONEq(t,
`{"operation":"add","operands":[2,2]}`,
payload.Content[0].Arguments,
)
require.Equal(t, len(`{"operation":"add","operands":[2,2]}`), payload.Content[0].InputLength)
}
-225
View File
@@ -1,225 +0,0 @@
package chatdebug
import (
"context"
"regexp"
"strings"
"sync"
"sync/atomic"
"unicode/utf8"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/pubsub"
)
// This branch-02 compatibility shim forward-declares recorder, service,
// and summary symbols that land in later stacked branches. Delete this
// file once recorder.go, service.go, and summary.go are available here.
// RecorderOptions identifies the chat/model context for debug recording.
type RecorderOptions struct {
ChatID uuid.UUID
OwnerID uuid.UUID
Provider string
Model string
}
// Service is a placeholder for the later chat debug persistence service.
type Service struct{}
// NewService constructs the branch-02 placeholder chat debug service.
func NewService(_ database.Store, _ slog.Logger, _ pubsub.Pubsub) *Service {
return &Service{}
}
type attemptSink struct{}
type attemptSinkKey struct{}
func withAttemptSink(ctx context.Context, sink *attemptSink) context.Context {
if sink == nil {
panic("chatdebug: nil attemptSink")
}
return context.WithValue(ctx, attemptSinkKey{}, sink)
}
func attemptSinkFromContext(ctx context.Context) *attemptSink {
sink, _ := ctx.Value(attemptSinkKey{}).(*attemptSink)
return sink
}
var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32
// runRefCounts tracks how many live RunContext instances reference each
// RunID. Cleanup of shared state (step counters) is deferred until the
// last RunContext for a given RunID is garbage collected.
var runRefCounts sync.Map // map[uuid.UUID]*atomic.Int32
func trackRunRef(runID uuid.UUID) {
val, _ := runRefCounts.LoadOrStore(runID, &atomic.Int32{})
counter := val.(*atomic.Int32)
counter.Add(1)
}
// releaseRunRef decrements the reference count for runID and cleans up
// shared state when the last reference is released.
func releaseRunRef(runID uuid.UUID) {
val, ok := runRefCounts.Load(runID)
if !ok {
return
}
counter := val.(*atomic.Int32)
if counter.Add(-1) <= 0 {
runRefCounts.Delete(runID)
stepCounters.Delete(runID)
}
}
func nextStepNumber(runID uuid.UUID) int32 {
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
counter, ok := val.(*atomic.Int32)
if !ok {
panic("chatdebug: invalid step counter type")
}
return counter.Add(1)
}
// CleanupStepCounter removes per-run step counter and reference count
// state. This is used by tests and later stacked branches that have a
// real run lifecycle.
func CleanupStepCounter(runID uuid.UUID) {
stepCounters.Delete(runID)
runRefCounts.Delete(runID)
}
type stepHandle struct {
stepCtx *StepContext
sink *attemptSink
mu sync.Mutex
status Status
response any
usage any
err any
metadata any
}
func beginStep(
ctx context.Context,
svc *Service,
opts RecorderOptions,
op Operation,
_ any,
) (*stepHandle, context.Context) {
if svc == nil {
return nil, ctx
}
rc, ok := RunFromContext(ctx)
if !ok || rc.RunID == uuid.Nil {
return nil, ctx
}
if holder, reuseStep := reuseHolderFromContext(ctx); reuseStep {
holder.mu.Lock()
defer holder.mu.Unlock()
// Only reuse the cached handle if it belongs to the same run.
// A different RunContext means a new logical run, so we must
// create a fresh step to avoid cross-run attribution.
if holder.handle != nil && holder.handle.stepCtx.RunID == rc.RunID {
enriched := ContextWithStep(ctx, holder.handle.stepCtx)
enriched = withAttemptSink(enriched, holder.handle.sink)
return holder.handle, enriched
}
handle, enriched := newStepHandle(ctx, rc, opts, op)
holder.handle = handle
return handle, enriched
}
return newStepHandle(ctx, rc, opts, op)
}
func newStepHandle(
ctx context.Context,
rc *RunContext,
opts RecorderOptions,
op Operation,
) (*stepHandle, context.Context) {
if rc == nil || rc.RunID == uuid.Nil {
return nil, ctx
}
chatID := opts.ChatID
if chatID == uuid.Nil {
chatID = rc.ChatID
}
handle := &stepHandle{
stepCtx: &StepContext{
StepID: uuid.New(),
RunID: rc.RunID,
ChatID: chatID,
StepNumber: nextStepNumber(rc.RunID),
Operation: op,
HistoryTipMessageID: rc.HistoryTipMessageID,
},
sink: &attemptSink{},
}
enriched := ContextWithStep(ctx, handle.stepCtx)
enriched = withAttemptSink(enriched, handle.sink)
return handle, enriched
}
func (h *stepHandle) finish(
_ context.Context,
status Status,
response any,
usage any,
err any,
metadata any,
) {
if h == nil || h.stepCtx == nil {
return
}
// Guard with a mutex so concurrent callers (e.g. retried stream
// wrappers sharing a reused handle) don't race. Unlike sync.Once,
// later retries are allowed to overwrite earlier failure results so
// the step reflects the final outcome.
h.mu.Lock()
defer h.mu.Unlock()
h.status = status
h.response = response
h.usage = usage
h.err = err
h.metadata = metadata
}
// whitespaceRun matches one or more consecutive whitespace characters.
var whitespaceRun = regexp.MustCompile(`\s+`)
// TruncateLabel whitespace-normalizes and truncates text to maxLen runes.
// Returns "" if input is empty or whitespace-only.
func TruncateLabel(text string, maxLen int) string {
if maxLen < 0 {
maxLen = 0
}
normalized := strings.TrimSpace(whitespaceRun.ReplaceAllString(text, " "))
if normalized == "" || maxLen == 0 {
return ""
}
if utf8.RuneCountInString(normalized) <= maxLen {
return normalized
}
if maxLen == 1 {
return "…"
}
// Truncate to leave room for the trailing ellipsis within maxLen.
runes := []rune(normalized)
return string(runes[:maxLen-1]) + "…"
}
@@ -1,90 +0,0 @@
package chatdebug
import (
"context"
"net/http"
"testing"
"unicode/utf8"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func TestBeginStep_SkipsNilRunID(t *testing.T) {
t.Parallel()
ctx := ContextWithRun(context.Background(), &RunContext{ChatID: uuid.New()})
handle, enriched := beginStep(ctx, &Service{}, RecorderOptions{ChatID: uuid.New()}, OperationGenerate, nil)
require.Nil(t, handle)
require.Equal(t, ctx, enriched)
}
func TestNewStepHandle_SkipsNilRunID(t *testing.T) {
t.Parallel()
ctx := context.Background()
handle, enriched := newStepHandle(ctx, &RunContext{ChatID: uuid.New()}, RecorderOptions{ChatID: uuid.New()}, OperationGenerate)
require.Nil(t, handle)
require.Equal(t, ctx, enriched)
}
func TestTruncateLabel(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
maxLen int
want string
}{
{name: "Empty", input: "", maxLen: 10, want: ""},
{name: "WhitespaceOnly", input: " \t\n ", maxLen: 10, want: ""},
{name: "ShortText", input: "hello world", maxLen: 20, want: "hello world"},
{name: "ExactLength", input: "abcde", maxLen: 5, want: "abcde"},
{name: "LongTextTruncated", input: "abcdefghij", maxLen: 5, want: "abcd…"},
{name: "NegativeMaxLen", input: "hello", maxLen: -1, want: ""},
{name: "ZeroMaxLen", input: "hello", maxLen: 0, want: ""},
{name: "SingleRuneLimit", input: "hello", maxLen: 1, want: "…"},
{name: "MultipleWhitespaceRuns", input: " hello world \t again ", maxLen: 100, want: "hello world again"},
{name: "UnicodeRunes", input: "こんにちは世界", maxLen: 3, want: "こん…"},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := TruncateLabel(tc.input, tc.maxLen)
require.Equal(t, tc.want, got)
require.LessOrEqual(t, utf8.RuneCountInString(got), maxInt(tc.maxLen, 0))
})
}
}
func maxInt(a, b int) int {
if a > b {
return a
}
return b
}
// RedactedValue replaces sensitive values in debug payloads.
const RedactedValue = "[REDACTED]"
// RecordingTransport is the branch-02 placeholder HTTP recording transport.
type RecordingTransport struct {
Base http.RoundTripper
}
var _ http.RoundTripper = (*RecordingTransport)(nil)
func (t *RecordingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req == nil {
panic("chatdebug: nil request")
}
base := t.Base
if base == nil {
base = http.DefaultTransport
}
return base.RoundTrip(req)
}
-137
View File
@@ -1,137 +0,0 @@
package chatdebug
import "github.com/google/uuid"
// RunKind identifies the kind of debug run being recorded.
type RunKind string
const (
// KindChatTurn records a standard chat turn.
KindChatTurn RunKind = "chat_turn"
// KindTitleGeneration records title generation for a chat.
KindTitleGeneration RunKind = "title_generation"
// KindQuickgen records quick-generation workflows.
KindQuickgen RunKind = "quickgen"
// KindCompaction records history compaction workflows.
KindCompaction RunKind = "compaction"
)
// AllRunKinds contains every RunKind value. Update this when
// adding new constants above.
var AllRunKinds = []RunKind{
KindChatTurn,
KindTitleGeneration,
KindQuickgen,
KindCompaction,
}
// Status identifies lifecycle state shared by runs and steps.
type Status string
const (
// StatusInProgress indicates work is still running.
StatusInProgress Status = "in_progress"
// StatusCompleted indicates work finished successfully.
StatusCompleted Status = "completed"
// StatusError indicates work finished with an error.
StatusError Status = "error"
// StatusInterrupted indicates work was canceled or interrupted.
StatusInterrupted Status = "interrupted"
)
// AllStatuses contains every Status value. Update this when
// adding new constants above.
var AllStatuses = []Status{
StatusInProgress,
StatusCompleted,
StatusError,
StatusInterrupted,
}
// Operation identifies the model operation a step performed.
type Operation string
const (
// OperationStream records a streaming model operation.
OperationStream Operation = "stream"
// OperationGenerate records a non-streaming generation operation.
OperationGenerate Operation = "generate"
)
// AllOperations contains every Operation value. Update this when
// adding new constants above.
var AllOperations = []Operation{
OperationStream,
OperationGenerate,
}
// RunContext carries identity and metadata for a debug run.
type RunContext struct {
RunID uuid.UUID
ChatID uuid.UUID
RootChatID uuid.UUID // Zero means not set.
ParentChatID uuid.UUID // Zero means not set.
ModelConfigID uuid.UUID // Zero means not set.
TriggerMessageID int64 // Zero means not set.
HistoryTipMessageID int64 // Zero means not set.
Kind RunKind
Provider string
Model string
}
// StepContext carries identity and metadata for a debug step.
type StepContext struct {
StepID uuid.UUID
RunID uuid.UUID
ChatID uuid.UUID
StepNumber int32
Operation Operation
HistoryTipMessageID int64 // Zero means not set.
}
// Attempt captures a single HTTP round trip made during a step.
type Attempt struct {
Number int `json:"number"`
Status string `json:"status,omitempty"`
Method string `json:"method,omitempty"`
URL string `json:"url,omitempty"`
Path string `json:"path,omitempty"`
StartedAt string `json:"started_at,omitempty"`
FinishedAt string `json:"finished_at,omitempty"`
RequestHeaders map[string]string `json:"request_headers,omitempty"`
RequestBody []byte `json:"request_body,omitempty"`
ResponseStatus int `json:"response_status,omitempty"`
ResponseHeaders map[string]string `json:"response_headers,omitempty"`
ResponseBody []byte `json:"response_body,omitempty"`
Error string `json:"error,omitempty"`
DurationMs int64 `json:"duration_ms"`
RetryClassification string `json:"retry_classification,omitempty"`
RetryDelayMs int64 `json:"retry_delay_ms,omitempty"`
}
// EventKind identifies the type of pubsub debug event.
type EventKind string
const (
// EventKindRunUpdate publishes a run mutation.
EventKindRunUpdate EventKind = "run_update"
// EventKindStepUpdate publishes a step mutation.
EventKindStepUpdate EventKind = "step_update"
// EventKindFinalize publishes a finalization signal.
EventKindFinalize EventKind = "finalize"
// EventKindDelete publishes a deletion signal.
EventKindDelete EventKind = "delete"
)
// DebugEvent is the lightweight pubsub envelope for chat debug updates.
type DebugEvent struct {
Kind EventKind `json:"kind"`
ChatID uuid.UUID `json:"chat_id"`
RunID uuid.UUID `json:"run_id"`
StepID uuid.UUID `json:"step_id"`
}
// PubsubChannel returns the chat-scoped pubsub channel for debug events.
func PubsubChannel(chatID uuid.UUID) string {
return "chat_debug:" + chatID.String()
}
-54
View File
@@ -1,54 +0,0 @@
package chatdebug_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/codersdk"
)
// toStrings converts a typed string slice to []string for comparison.
func toStrings[T ~string](values []T) []string {
out := make([]string, len(values))
for i, v := range values {
out[i] = string(v)
}
return out
}
// TestTypesMatchSDK verifies that every chatdebug constant has a
// corresponding codersdk constant with the same string value.
// If this test fails you probably added a constant to one package
// but forgot to update the other.
func TestTypesMatchSDK(t *testing.T) {
t.Parallel()
t.Run("RunKind", func(t *testing.T) {
t.Parallel()
require.ElementsMatch(t,
toStrings(chatdebug.AllRunKinds),
toStrings(codersdk.AllChatDebugRunKinds),
"chatdebug.AllRunKinds and codersdk.AllChatDebugRunKinds have diverged",
)
})
t.Run("Status", func(t *testing.T) {
t.Parallel()
require.ElementsMatch(t,
toStrings(chatdebug.AllStatuses),
toStrings(codersdk.AllChatDebugStatuses),
"chatdebug.AllStatuses and codersdk.AllChatDebugStatuses have diverged",
)
})
t.Run("Operation", func(t *testing.T) {
t.Parallel()
require.ElementsMatch(t,
toStrings(chatdebug.AllOperations),
toStrings(codersdk.AllChatDebugStepOperations),
"chatdebug.AllOperations and codersdk.AllChatDebugStepOperations have diverged",
)
})
}
+94 -49
View File
@@ -18,7 +18,6 @@ import (
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
@@ -42,9 +41,9 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
t.Parallel()
var capturedCall fantasy.Call
model := &chattest.FakeModel{
ProviderName: fantasyanthropic.Name,
StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: fantasyanthropic.Name,
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
capturedCall = call
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
@@ -104,9 +103,9 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
func TestProcessStepStream_AnthropicUsageMatchesFinalDelta(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{
ProviderName: fantasyanthropic.Name,
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: fantasyanthropic.Name,
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "cached response"},
@@ -161,9 +160,9 @@ func TestRun_OnRetryEnrichesProvider(t *testing.T) {
var records []retryRecord
calls := 0
model := &chattest.FakeModel{
ProviderName: "openai",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "openai",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
calls++
if calls == 1 {
return nil, xerrors.New("received status 429 from upstream")
@@ -287,9 +286,9 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
attempts := 0
attemptCause := make(chan error, 1)
var retries []chatretry.ClassifiedError
model := &chattest.FakeModel{
ProviderName: "openai",
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "openai",
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
attempts++
if attempts == 1 {
<-ctx.Done()
@@ -365,9 +364,9 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
attempts := 0
attemptCause := make(chan error, 1)
var retries []chatretry.ClassifiedError
model := &chattest.FakeModel{
ProviderName: "openai",
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "openai",
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
attempts++
if attempts == 1 {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
@@ -448,9 +447,9 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
retried := false
firstPartYielded := make(chan struct{}, 1)
continueStream := make(chan struct{})
model := &chattest.FakeModel{
ProviderName: "openai",
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "openai",
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
attempts++
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) {
@@ -527,9 +526,9 @@ func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) {
t.Parallel()
attemptReleased := make(chan struct{})
model := &chattest.FakeModel{
ProviderName: "openai",
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "openai",
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
go func() {
<-ctx.Done()
close(attemptReleased)
@@ -584,9 +583,9 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
attempts := 0
attemptCause := make(chan error, 1)
var retries []chatretry.ClassifiedError
model := &chattest.FakeModel{
ProviderName: "openai",
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "openai",
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
attempts++
if attempts == 1 {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
@@ -649,9 +648,9 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
t.Parallel()
started := make(chan struct{})
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
parts := []fantasy.StreamPart{
{
@@ -763,6 +762,52 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
"interrupted tool should have no call timestamp (never reached StreamPartTypeToolCall)")
}
type loopTestModel struct {
provider string
model string
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
streamFn func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
}
func (m *loopTestModel) Provider() string {
if m.provider != "" {
return m.provider
}
return "fake"
}
func (m *loopTestModel) Model() string {
if m.model != "" {
return m.model
}
return "fake"
}
func (m *loopTestModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
if m.generateFn != nil {
return m.generateFn(ctx, call)
}
return &fantasy.Response{}, nil
}
func (m *loopTestModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
if m.streamFn != nil {
return m.streamFn(ctx, call)
}
return streamFromParts([]fantasy.StreamPart{{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonStop,
}}), nil
}
func (*loopTestModel) GenerateObject(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return nil, xerrors.New("not implemented")
}
func (*loopTestModel) StreamObject(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
return nil, xerrors.New("not implemented")
}
func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse {
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
for _, part := range parts {
@@ -815,9 +860,9 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
var streamCalls int
var secondCallPrompt []fantasy.Message
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCalls
streamCalls++
@@ -927,9 +972,9 @@ func TestRun_ParallelToolExecutionTimestamps(t *testing.T) {
var mu sync.Mutex
var streamCalls int
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCalls
streamCalls++
@@ -1019,9 +1064,9 @@ func TestRun_ParallelToolExecutionTimestamps(t *testing.T) {
func TestRun_PersistStepErrorPropagates(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello"},
@@ -1058,9 +1103,9 @@ func TestRun_ShutdownDuringToolExecutionReturnsContextCanceled(t *testing.T) {
toolStarted := make(chan struct{})
// Model returns a single tool call, then finishes.
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-block", ToolCallName: "blocking_tool"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-block", Delta: `{}`},
@@ -1316,9 +1361,9 @@ func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
toolStarted := make(chan struct{})
// Model returns a completed tool call in the stream.
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "calling tool"},
@@ -1426,9 +1471,9 @@ func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
func TestRun_ProviderExecutedToolResultTimestamps(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
// Simulate a provider-executed tool call and result
// (e.g. Anthropic web search) followed by a text
// response — all in a single stream.
@@ -1496,9 +1541,9 @@ func TestRun_ProviderExecutedToolResultTimestamps(t *testing.T) {
func TestRun_PersistStepInterruptedFallback(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello world"},
+35 -36
View File
@@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/codersdk"
)
@@ -23,9 +22,9 @@ func TestRun_Compaction(t *testing.T) {
var persistedCompaction CompactionResult
const summaryText = "summary text for compaction"
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
@@ -40,7 +39,7 @@ func TestRun_Compaction(t *testing.T) {
},
}), nil
},
GenerateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
require.NotEmpty(t, call.Prompt)
lastPrompt := call.Prompt[len(call.Prompt)-1]
require.Equal(t, fantasy.MessageRoleUser, lastPrompt.Role)
@@ -108,9 +107,9 @@ func TestRun_Compaction(t *testing.T) {
// and the tool-result part publishes after Persist.
var callOrder []string
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
@@ -125,7 +124,7 @@ func TestRun_Compaction(t *testing.T) {
},
}), nil
},
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
callOrder = append(callOrder, "generate")
return &fantasy.Response{
Content: []fantasy.Content{
@@ -190,9 +189,9 @@ func TestRun_Compaction(t *testing.T) {
publishCalled := false
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{
Type: fantasy.StreamPartTypeFinish,
@@ -241,9 +240,9 @@ func TestRun_Compaction(t *testing.T) {
const summaryText = "compacted summary"
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCallCount
streamCallCount++
@@ -288,7 +287,7 @@ func TestRun_Compaction(t *testing.T) {
}), nil
}
},
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
@@ -347,9 +346,9 @@ func TestRun_Compaction(t *testing.T) {
const summaryText = "compacted summary for skip test"
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCallCount
streamCallCount++
@@ -394,7 +393,7 @@ func TestRun_Compaction(t *testing.T) {
}), nil
}
},
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
@@ -443,9 +442,9 @@ func TestRun_Compaction(t *testing.T) {
t.Run("ErrorsAreReported", func(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{
Type: fantasy.StreamPartTypeFinish,
@@ -456,7 +455,7 @@ func TestRun_Compaction(t *testing.T) {
},
}), nil
},
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return nil, xerrors.New("generate failed")
},
}
@@ -512,9 +511,9 @@ func TestRun_Compaction(t *testing.T) {
textMessage(fantasy.MessageRoleUser, "compacted user"),
}
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCallCount
streamCallCount++
@@ -557,7 +556,7 @@ func TestRun_Compaction(t *testing.T) {
}), nil
}
},
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
@@ -618,9 +617,9 @@ func TestRun_Compaction(t *testing.T) {
const summaryText = "post-run compacted summary"
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCallCount
streamCallCount++
@@ -660,7 +659,7 @@ func TestRun_Compaction(t *testing.T) {
}), nil
}
},
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
@@ -724,9 +723,9 @@ func TestRun_Compaction(t *testing.T) {
// The LLM calls a dynamic tool. Usage is above the
// compaction threshold so compaction should fire even
// though the chatloop exits via ErrDynamicToolCall.
model := &chattest.FakeModel{
ProviderName: "fake",
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "my_dynamic_tool"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"query": "test"}`},
@@ -747,7 +746,7 @@ func TestRun_Compaction(t *testing.T) {
},
}), nil
},
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
-52
View File
@@ -1,52 +0,0 @@
package chattest
import (
"context"
"charm.land/fantasy"
)
// FakeModel is a configurable test double for fantasy.LanguageModel.
// When a method function is nil, the method returns a safe empty
// response.
type FakeModel struct {
ProviderName string
ModelName string
GenerateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
StreamFn func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
GenerateObjectFn func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error)
StreamObjectFn func(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error)
}
var _ fantasy.LanguageModel = (*FakeModel)(nil)
func (m *FakeModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
if m.GenerateFn == nil {
return &fantasy.Response{}, nil
}
return m.GenerateFn(ctx, call)
}
func (m *FakeModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
if m.StreamFn == nil {
return fantasy.StreamResponse(func(func(fantasy.StreamPart) bool) {}), nil
}
return m.StreamFn(ctx, call)
}
func (m *FakeModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
if m.GenerateObjectFn == nil {
return &fantasy.ObjectResponse{}, nil
}
return m.GenerateObjectFn(ctx, call)
}
func (m *FakeModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
if m.StreamObjectFn == nil {
return fantasy.ObjectStreamResponse(func(func(fantasy.ObjectStreamPart) bool) {}), nil
}
return m.StreamObjectFn(ctx, call)
}
func (m *FakeModel) Provider() string { return m.ProviderName }
func (m *FakeModel) Model() string { return m.ModelName }
+56 -9
View File
@@ -10,9 +10,9 @@ import (
"charm.land/fantasy"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/codersdk"
)
@@ -375,8 +375,8 @@ func Test_generateManualTitle_UsesTimeout(t *testing.T) {
),
}
model := &chattest.FakeModel{
GenerateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
model := &stubModel{
generateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
deadline, ok := ctx.Deadline()
require.True(t, ok, "manual title generation should set a deadline")
require.WithinDuration(
@@ -413,8 +413,8 @@ func Test_generateManualTitle_TruncatesFirstUserInput(t *testing.T) {
),
}
model := &chattest.FakeModel{
GenerateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
model := &stubModel{
generateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
require.Len(t, call.Prompt, 2)
systemText, ok := call.Prompt[0].Content[0].(fantasy.TextPart)
require.True(t, ok)
@@ -447,8 +447,8 @@ func Test_generateManualTitle_ReturnsUsageForEmptyNormalizedTitle(t *testing.T)
),
}
model := &chattest.FakeModel{
GenerateObjectFn: func(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
model := &stubModel{
generateObjectFn: func(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return &fantasy.ObjectResponse{
Object: map[string]any{"title": "\"\""},
Usage: fantasy.Usage{
@@ -504,8 +504,8 @@ func Test_selectPreferredConfiguredShortTextModelConfig(t *testing.T) {
func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
model := &stubModel{
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.TextContent{Text: " \"Quoted summary\" "},
@@ -520,6 +520,53 @@ func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) {
require.Equal(t, "Quoted summary", text)
}
type stubModel struct {
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
generateObjectFn func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error)
}
func (m *stubModel) Generate(
ctx context.Context,
call fantasy.Call,
) (*fantasy.Response, error) {
if m.generateFn == nil {
return nil, xerrors.New("generate not implemented")
}
return m.generateFn(ctx, call)
}
func (*stubModel) Stream(
context.Context,
fantasy.Call,
) (fantasy.StreamResponse, error) {
return nil, xerrors.New("stream not implemented")
}
func (m *stubModel) GenerateObject(
ctx context.Context,
call fantasy.ObjectCall,
) (*fantasy.ObjectResponse, error) {
if m.generateObjectFn == nil {
return nil, xerrors.New("generate object not implemented")
}
return m.generateObjectFn(ctx, call)
}
func (*stubModel) StreamObject(
context.Context,
fantasy.ObjectCall,
) (fantasy.ObjectStreamResponse, error) {
return nil, xerrors.New("stream object not implemented")
}
func (*stubModel) Provider() string {
return "test"
}
func (*stubModel) Model() string {
return "test"
}
func mustChatMessage(
t *testing.T,
role database.ChatMessageRole,
-142
View File
@@ -547,148 +547,6 @@ type UpdateChatDesktopEnabledRequest struct {
EnableDesktop bool `json:"enable_desktop"`
}
// ChatDebugLoggingAdminSettings describes the runtime admin setting
// that allows users to opt into chat debug logging.
type ChatDebugLoggingAdminSettings struct {
AllowUsers bool `json:"allow_users"`
ForcedByDeployment bool `json:"forced_by_deployment"`
}
// UserChatDebugLoggingSettings describes whether debug logging is
// active for the current user and whether the user may control it.
type UserChatDebugLoggingSettings struct {
DebugLoggingEnabled bool `json:"debug_logging_enabled"`
UserToggleAllowed bool `json:"user_toggle_allowed"`
ForcedByDeployment bool `json:"forced_by_deployment"`
}
// UpdateChatDebugLoggingAllowUsersRequest is the admin request to
// toggle whether users may opt into chat debug logging.
type UpdateChatDebugLoggingAllowUsersRequest struct {
AllowUsers bool `json:"allow_users"`
}
// UpdateUserChatDebugLoggingRequest is the per-user request to
// opt into or out of chat debug logging.
type UpdateUserChatDebugLoggingRequest struct {
DebugLoggingEnabled bool `json:"debug_logging_enabled"`
}
// ChatDebugStatus enumerates the lifecycle states shared by debug
// runs and steps. These values must match the literals used in
// FinalizeStaleChatDebugRows and all insert/update callers.
type ChatDebugStatus string
const (
ChatDebugStatusInProgress ChatDebugStatus = "in_progress"
ChatDebugStatusCompleted ChatDebugStatus = "completed"
ChatDebugStatusError ChatDebugStatus = "error"
ChatDebugStatusInterrupted ChatDebugStatus = "interrupted"
)
// AllChatDebugStatuses contains every ChatDebugStatus value.
// Update this when adding new constants above.
var AllChatDebugStatuses = []ChatDebugStatus{
ChatDebugStatusInProgress,
ChatDebugStatusCompleted,
ChatDebugStatusError,
ChatDebugStatusInterrupted,
}
// ChatDebugRunKind labels the operation that produced the debug
// run. Each value corresponds to a distinct call-site in chatd.
type ChatDebugRunKind string
const (
ChatDebugRunKindChatTurn ChatDebugRunKind = "chat_turn"
ChatDebugRunKindTitleGeneration ChatDebugRunKind = "title_generation"
ChatDebugRunKindQuickgen ChatDebugRunKind = "quickgen"
ChatDebugRunKindCompaction ChatDebugRunKind = "compaction"
)
// AllChatDebugRunKinds contains every ChatDebugRunKind value.
// Update this when adding new constants above.
var AllChatDebugRunKinds = []ChatDebugRunKind{
ChatDebugRunKindChatTurn,
ChatDebugRunKindTitleGeneration,
ChatDebugRunKindQuickgen,
ChatDebugRunKindCompaction,
}
// ChatDebugStepOperation labels the model interaction type for a
// debug step.
type ChatDebugStepOperation string
const (
ChatDebugStepOperationStream ChatDebugStepOperation = "stream"
ChatDebugStepOperationGenerate ChatDebugStepOperation = "generate"
)
// AllChatDebugStepOperations contains every ChatDebugStepOperation
// value. Update this when adding new constants above.
var AllChatDebugStepOperations = []ChatDebugStepOperation{
ChatDebugStepOperationStream,
ChatDebugStepOperationGenerate,
}
// ChatDebugRunSummary is a lightweight run entry for list endpoints.
type ChatDebugRunSummary struct {
ID uuid.UUID `json:"id" format:"uuid"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
Kind ChatDebugRunKind `json:"kind"`
Status ChatDebugStatus `json:"status"`
Provider *string `json:"provider,omitempty"`
Model *string `json:"model,omitempty"`
Summary map[string]any `json:"summary"`
StartedAt time.Time `json:"started_at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
FinishedAt *time.Time `json:"finished_at,omitempty" format:"date-time"`
}
// ChatDebugRun is the detailed run response including steps.
// This type is consumed by the run-detail handler added in a later
// PR in this stack; it is forward-declared here so that all SDK
// types live in the same schema-layer commit.
type ChatDebugRun struct {
ID uuid.UUID `json:"id" format:"uuid"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
TriggerMessageID *int64 `json:"trigger_message_id,omitempty"`
HistoryTipMessageID *int64 `json:"history_tip_message_id,omitempty"`
Kind ChatDebugRunKind `json:"kind"`
Status ChatDebugStatus `json:"status"`
Provider *string `json:"provider,omitempty"`
Model *string `json:"model,omitempty"`
Summary map[string]any `json:"summary"`
StartedAt time.Time `json:"started_at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
FinishedAt *time.Time `json:"finished_at,omitempty" format:"date-time"`
Steps []ChatDebugStep `json:"steps"`
}
// ChatDebugStep is a single step within a debug run.
type ChatDebugStep struct {
ID uuid.UUID `json:"id" format:"uuid"`
RunID uuid.UUID `json:"run_id" format:"uuid"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
StepNumber int32 `json:"step_number"`
Operation ChatDebugStepOperation `json:"operation"`
Status ChatDebugStatus `json:"status"`
HistoryTipMessageID *int64 `json:"history_tip_message_id,omitempty"`
AssistantMessageID *int64 `json:"assistant_message_id,omitempty"`
NormalizedRequest map[string]any `json:"normalized_request"`
NormalizedResponse map[string]any `json:"normalized_response,omitempty"`
Usage map[string]any `json:"usage,omitempty"`
Attempts []map[string]any `json:"attempts"`
Error map[string]any `json:"error,omitempty"`
Metadata map[string]any `json:"metadata"`
StartedAt time.Time `json:"started_at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
FinishedAt *time.Time `json:"finished_at,omitempty" format:"date-time"`
}
// DefaultChatWorkspaceTTL is the default TTL for chat workspaces.
// Zero means disabled — the template's own autostop setting applies.
const DefaultChatWorkspaceTTL = 0
+12 -12
View File
@@ -3624,16 +3624,6 @@ Write out the current server config as YAML to stdout.`,
YAML: "acquireBatchSize",
Hidden: true, // Hidden because most operators should not need to modify this.
},
{
Name: "Chat: Debug Logging Enabled",
Description: "Force chat debug logging on for every chat, bypassing the runtime admin and user opt-in settings.",
Flag: "chat-debug-logging-enabled",
Env: "CODER_CHAT_DEBUG_LOGGING_ENABLED",
Value: &c.AI.Chat.DebugLoggingEnabled,
Default: "false",
Group: &deploymentGroupChat,
YAML: "debugLoggingEnabled",
},
// AI Bridge Options
{
Name: "AI Bridge Enabled",
@@ -3811,6 +3801,16 @@ Write out the current server config as YAML to stdout.`,
Group: &deploymentGroupAIBridge,
YAML: "send_actor_headers",
},
{
Name: "AI Gateway Allow BYOK",
Description: "Allow users to bring their own LLM API keys or subscriptions. When disabled, only centralized key authentication is permitted.",
Flag: "ai-gateway-allow-byok",
Env: "CODER_AI_GATEWAY_ALLOW_BYOK",
Value: &c.AI.BridgeConfig.AllowBYOK,
Default: "true",
Group: &deploymentGroupAIBridge,
YAML: "allow_byok",
},
{
Name: "AI Bridge Circuit Breaker Enabled",
Description: "Enable the circuit breaker to protect against cascading failures from upstream AI provider rate limits (429, 503, 529 overloaded).",
@@ -4058,6 +4058,7 @@ type AIBridgeConfig struct {
RateLimit serpent.Int64 `json:"rate_limit" typescript:",notnull"`
StructuredLogging serpent.Bool `json:"structured_logging" typescript:",notnull"`
SendActorHeaders serpent.Bool `json:"send_actor_headers" typescript:",notnull"`
AllowBYOK serpent.Bool `json:"allow_byok" typescript:",notnull"`
// Circuit breaker protects against cascading failures from upstream AI
// provider rate limits (429, 503, 529 overloaded).
CircuitBreakerEnabled serpent.Bool `json:"circuit_breaker_enabled" typescript:",notnull"`
@@ -4100,8 +4101,7 @@ type AIBridgeProxyConfig struct {
}
type ChatConfig struct {
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
DebugLoggingEnabled serpent.Bool `json:"debug_logging_enabled" typescript:",notnull"`
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
}
type AIConfig struct {
+1 -2
View File
@@ -209,8 +209,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
"structured_logging": true
},
"chat": {
"acquire_batch_size": 0,
"debug_logging_enabled": true
"acquire_batch_size": 0
}
},
"allow_workspace_renames": true,
+7 -12
View File
@@ -1240,8 +1240,7 @@
"structured_logging": true
},
"chat": {
"acquire_batch_size": 0,
"debug_logging_enabled": true
"acquire_batch_size": 0
}
}
```
@@ -2022,17 +2021,15 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
```json
{
"acquire_batch_size": 0,
"debug_logging_enabled": true
"acquire_batch_size": 0
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|-------------------------|---------|----------|--------------|-------------|
| `acquire_batch_size` | integer | false | | |
| `debug_logging_enabled` | boolean | false | | |
| Name | Type | Required | Restrictions | Description |
|----------------------|---------|----------|--------------|-------------|
| `acquire_batch_size` | integer | false | | |
## codersdk.ChatRetentionDaysResponse
@@ -3264,8 +3261,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
"structured_logging": true
},
"chat": {
"acquire_batch_size": 0,
"debug_logging_enabled": true
"acquire_batch_size": 0
}
},
"allow_workspace_renames": true,
@@ -3843,8 +3839,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
"structured_logging": true
},
"chat": {
"acquire_batch_size": 0,
"debug_logging_enabled": true
"acquire_batch_size": 0
}
},
"allow_workspace_renames": true,
-11
View File
@@ -1702,17 +1702,6 @@ How often to reconcile workspace prebuilds state.
Hide AI tasks from the dashboard.
### --chat-debug-logging-enabled
| | |
|-------------|------------------------------------------------|
| Type | <code>bool</code> |
| Environment | <code>$CODER_CHAT_DEBUG_LOGGING_ENABLED</code> |
| YAML | <code>chat.debugLoggingEnabled</code> |
| Default | <code>false</code> |
Force chat debug logging on for every chat, bypassing the runtime admin and user opt-in settings.
### --aibridge-enabled
| | |
+4 -1
View File
@@ -48,9 +48,11 @@ type Server struct {
cancelFn func()
shutdownOnce sync.Once
allowBYOK bool
}
func New(ctx context.Context, pool Pooler, rpcDialer Dialer, logger slog.Logger, tracer trace.Tracer) (*Server, error) {
func New(ctx context.Context, pool Pooler, rpcDialer Dialer, logger slog.Logger, tracer trace.Tracer, allowBYOK bool) (*Server, error) {
if rpcDialer == nil {
return nil, xerrors.Errorf("nil rpcDialer given")
}
@@ -66,6 +68,7 @@ func New(ctx context.Context, pool Pooler, rpcDialer Dialer, logger slog.Logger,
initConnectionCh: make(chan struct{}),
requestBridgePool: pool,
allowBYOK: allowBYOK,
}
daemon.wg.Add(1)
@@ -190,7 +190,7 @@ func TestIntegration(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return aiBridgeClient, nil
}, logger, tracer)
}, logger, tracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(ctx)
@@ -393,7 +393,7 @@ func TestIntegrationWithMetrics(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(ctx, pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return aiBridgeClient, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(ctx)
@@ -508,7 +508,7 @@ func TestIntegrationCircuitBreaker(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(ctx, pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return aiBridgeClient, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(ctx)
+3 -3
View File
@@ -43,7 +43,7 @@ func newTestServer(t *testing.T) (*aibridged.Server, *mock.MockDRPCClient, *mock
pool,
func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
srv.Shutdown(context.Background())
@@ -441,7 +441,7 @@ func TestServeHTTP_ActorHeaders(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(testutil.Context(t, testutil.WaitShort))
@@ -545,7 +545,7 @@ func TestRouting(t *testing.T) {
// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, logger, testTracer)
}, logger, testTracer, true)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(testutil.Context(t, testutil.WaitShort))
+7
View File
@@ -56,6 +56,13 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
authMode = "byok"
}
if byok && !s.allowBYOK {
logger.Warn(ctx, "BYOK request rejected: not allowed by deployment configuration")
http.Error(rw, "Bring Your Own Key (BYOK) mode is not enabled. "+
"Contact your administrator to enable it with --ai-gateway-allow-byok.", http.StatusForbidden)
return
}
key := strings.TrimSpace(agplaibridge.ExtractAuthToken(r.Header))
if key == "" {
// Some clients (e.g. Claude) send a HEAD request
+1 -1
View File
@@ -86,7 +86,7 @@ func newAIBridgeDaemon(coderAPI *coderd.API) (*aibridged.Server, error) {
// Create daemon.
srv, err := aibridged.New(ctx, pool, func(dialCtx context.Context) (aibridged.DRPCClient, error) {
return coderAPI.CreateInMemoryAIBridgeServer(dialCtx)
}, logger, tracer)
}, logger, tracer, cfg.AllowBYOK.Value())
if err != nil {
return nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err)
}
-7
View File
@@ -212,13 +212,6 @@ AI BRIDGE PROXY OPTIONS:
certificates not trusted by the system. If not provided, the system
certificate pool is used.
CHAT OPTIONS:
Configure the background chat processing daemon.
--chat-debug-logging-enabled bool, $CODER_CHAT_DEBUG_LOGGING_ENABLED (default: false)
Force chat debug logging on for every chat, bypassing the runtime
admin and user opt-in settings.
CLIENT OPTIONS:
These options change the behavior of how clients interact with the Coder.
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
-151
View File
@@ -1236,7 +1236,6 @@ export const ChatCompactionThresholdKeyPrefix =
// From codersdk/deployment.go
export interface ChatConfig {
readonly acquire_batch_size: number;
readonly debug_logging_enabled: boolean;
}
// From codersdk/chats.go
@@ -1364,127 +1363,6 @@ export interface ChatCostUsersResponse {
readonly users: readonly ChatCostUserRollup[];
}
// From codersdk/chats.go
/**
* ChatDebugLoggingAdminSettings describes the runtime admin setting
* that allows users to opt into chat debug logging.
*/
export interface ChatDebugLoggingAdminSettings {
readonly allow_users: boolean;
readonly forced_by_deployment: boolean;
}
// From codersdk/chats.go
/**
* ChatDebugRun is the detailed run response including steps.
* This type is consumed by the run-detail handler added in a later
* PR in this stack; it is forward-declared here so that all SDK
* types live in the same schema-layer commit.
*/
export interface ChatDebugRun {
readonly id: string;
readonly chat_id: string;
readonly root_chat_id?: string;
readonly parent_chat_id?: string;
readonly model_config_id?: string;
readonly trigger_message_id?: number;
readonly history_tip_message_id?: number;
readonly kind: ChatDebugRunKind;
readonly status: ChatDebugStatus;
readonly provider?: string;
readonly model?: string;
// empty interface{} type, falling back to unknown
readonly summary: Record<string, unknown>;
readonly started_at: string;
readonly updated_at: string;
readonly finished_at?: string;
readonly steps: readonly ChatDebugStep[];
}
// From codersdk/chats.go
export type ChatDebugRunKind =
| "chat_turn"
| "compaction"
| "quickgen"
| "title_generation";
export const ChatDebugRunKinds: ChatDebugRunKind[] = [
"chat_turn",
"compaction",
"quickgen",
"title_generation",
];
// From codersdk/chats.go
/**
* ChatDebugRunSummary is a lightweight run entry for list endpoints.
*/
export interface ChatDebugRunSummary {
readonly id: string;
readonly chat_id: string;
readonly kind: ChatDebugRunKind;
readonly status: ChatDebugStatus;
readonly provider?: string;
readonly model?: string;
// empty interface{} type, falling back to unknown
readonly summary: Record<string, unknown>;
readonly started_at: string;
readonly updated_at: string;
readonly finished_at?: string;
}
// From codersdk/chats.go
export type ChatDebugStatus =
| "completed"
| "error"
| "in_progress"
| "interrupted";
export const ChatDebugStatuses: ChatDebugStatus[] = [
"completed",
"error",
"in_progress",
"interrupted",
];
// From codersdk/chats.go
/**
* ChatDebugStep is a single step within a debug run.
*/
export interface ChatDebugStep {
readonly id: string;
readonly run_id: string;
readonly chat_id: string;
readonly step_number: number;
readonly operation: ChatDebugStepOperation;
readonly status: ChatDebugStatus;
readonly history_tip_message_id?: number;
readonly assistant_message_id?: number;
// empty interface{} type, falling back to unknown
readonly normalized_request: Record<string, unknown>;
// empty interface{} type, falling back to unknown
readonly normalized_response?: Record<string, unknown>;
// empty interface{} type, falling back to unknown
readonly usage?: Record<string, unknown>;
// empty interface{} type, falling back to unknown
readonly attempts: readonly Record<string, unknown>[];
// empty interface{} type, falling back to unknown
readonly error?: Record<string, unknown>;
// empty interface{} type, falling back to unknown
readonly metadata: Record<string, unknown>;
readonly started_at: string;
readonly updated_at: string;
readonly finished_at?: string;
}
// From codersdk/chats.go
export type ChatDebugStepOperation = "generate" | "stream";
export const ChatDebugStepOperations: ChatDebugStepOperation[] = [
"generate",
"stream",
];
// From codersdk/chats.go
/**
* ChatDesktopEnabledResponse is the response for getting the desktop setting.
@@ -7482,15 +7360,6 @@ export interface UpdateAppearanceConfig {
readonly announcement_banners: readonly BannerConfig[];
}
// From codersdk/chats.go
/**
* UpdateChatDebugLoggingAllowUsersRequest is the admin request to
* toggle whether users may opt into chat debug logging.
*/
export interface UpdateChatDebugLoggingAllowUsersRequest {
readonly allow_users: boolean;
}
// From codersdk/chats.go
/**
* UpdateChatDesktopEnabledRequest is the request to update the desktop setting.
@@ -7794,15 +7663,6 @@ export interface UpdateUserChatCompactionThresholdRequest {
readonly threshold_percent: number;
}
// From codersdk/chats.go
/**
* UpdateUserChatDebugLoggingRequest is the per-user request to
* opt into or out of chat debug logging.
*/
export interface UpdateUserChatDebugLoggingRequest {
readonly debug_logging_enabled: boolean;
}
// From codersdk/notifications.go
export interface UpdateUserNotificationPreferences {
readonly template_disabled_map: Record<string, boolean>;
@@ -8102,17 +7962,6 @@ export interface UserChatCustomPrompt {
readonly custom_prompt: string;
}
// From codersdk/chats.go
/**
* UserChatDebugLoggingSettings describes whether debug logging is
* active for the current user and whether the user may control it.
*/
export interface UserChatDebugLoggingSettings {
readonly debug_logging_enabled: boolean;
readonly user_toggle_allowed: boolean;
readonly forced_by_deployment: boolean;
}
// From codersdk/chats.go
/**
* UserChatProviderConfig is a summary of a provider that allows