feat(coderd/x/chatd): wire debug logging into chat lifecycle

Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
Thomas Kosiewski
2026-04-08 22:31:39 +00:00
parent cf9847a799
commit 76f342f2fc
11 changed files with 1237 additions and 74 deletions
+515 -55
View File
@@ -34,6 +34,7 @@ import (
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
@@ -129,6 +130,7 @@ type Server struct {
pubsub pubsub.Pubsub
webpushDispatcher webpush.Dispatcher
providerAPIKeys chatprovider.ProviderAPIKeys
debugSvc *chatdebug.Service
configCache *chatConfigCache
configCacheUnsubscribe func()
@@ -1210,7 +1212,10 @@ func (p *Server) EditMessage(
return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
}
var result EditMessageResult
var (
result EditMessageResult
editedMsg database.ChatMessage
)
txErr := p.db.InTx(func(tx database.Store) error {
lockedChat, err := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
if err != nil {
@@ -1221,17 +1226,17 @@ func (p *Server) EditMessage(
return limitErr
}
existing, err := tx.GetChatMessageByID(ctx, opts.EditedMessageID)
editedMsg, err = tx.GetChatMessageByID(ctx, opts.EditedMessageID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrEditedMessageNotFound
}
return xerrors.Errorf("get edited message: %w", err)
}
if existing.ChatID != opts.ChatID {
if editedMsg.ChatID != opts.ChatID {
return ErrEditedMessageNotFound
}
if existing.Role != database.ChatMessageRoleUser {
if editedMsg.Role != database.ChatMessageRoleUser {
return ErrEditedMessageNotUser
}
@@ -1258,8 +1263,8 @@ func (p *Server) EditMessage(
appendChatMessage(&msgParams, newChatMessage(
database.ChatMessageRoleUser,
content,
existing.Visibility,
existing.ModelConfigID.UUID,
editedMsg.Visibility,
editedMsg.ModelConfigID.UUID,
chatprompt.CurrentContentVersion,
).withCreatedBy(opts.CreatedBy))
newMessages, err := insertChatMessageWithStore(ctx, tx, msgParams)
@@ -1302,6 +1307,26 @@ func (p *Server) EditMessage(
})
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
// Best-effort debug row cleanup. We do not wait for the active
// worker to stop because activeChats is process-local and would
// not cover multi-replica deployments. Any rows that survive
// this pass are caught by the periodic stale-finalization sweep.
if p.debugSvc != nil {
cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer cleanupCancel()
if _, err := p.debugSvc.DeleteAfterMessageID(
cleanupCtx,
opts.ChatID,
editedMsg.ID-1,
); err != nil {
p.logger.Warn(ctx, "failed to delete chat debug rows after edit",
slog.F("chat_id", opts.ChatID),
slog.F("edited_message_id", editedMsg.ID),
slog.Error(err),
)
}
}
p.signalWake()
return result, nil
@@ -1316,46 +1341,67 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
return xerrors.New("chat_id is required")
}
statusChat := chat
interrupted := false
var archivedChats []database.Chat
var (
archivedChats []database.Chat
interruptedChats []database.Chat
)
if err := p.db.InTx(func(tx database.Store) error {
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
if err != nil {
if _, err := tx.GetChatByIDForUpdate(ctx, chat.ID); err != nil {
return xerrors.Errorf("lock chat for archive: %w", err)
}
statusChat = lockedChat
// We do not call setChatWaiting here because it intentionally preserves
// pending chats so queued-message promotion can win. Archiving is a
// harder stop: both pending and running chats must transition to waiting.
if lockedChat.Status == database.ChatStatusPending || lockedChat.Status == database.ChatStatusRunning {
statusChat, err = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
var err error
archivedChats, err = tx.ArchiveChatByID(ctx, chat.ID)
if err != nil {
return xerrors.Errorf("archive chat: %w", err)
}
for i, archivedChat := range archivedChats {
if archivedChat.Status != database.ChatStatusPending &&
archivedChat.Status != database.ChatStatusRunning {
continue
}
updatedChat, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: archivedChat.ID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
if err != nil {
return xerrors.Errorf("set chat waiting before archive: %w", err)
if updateErr != nil {
return xerrors.Errorf("set archived chat waiting before cleanup: %w", updateErr)
}
interrupted = true
}
archivedChats, err = tx.ArchiveChatByID(ctx, chat.ID)
if err != nil {
return xerrors.Errorf("archive chat: %w", err)
archivedChats[i] = updatedChat
interruptedChats = append(interruptedChats, updatedChat)
}
return nil
}, nil); err != nil {
return err
}
if interrupted {
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
for _, interruptedChat := range interruptedChats {
p.publishStatus(interruptedChat.ID, interruptedChat.Status, interruptedChat.WorkerID)
p.publishChatPubsubEvent(interruptedChat, codersdk.ChatWatchEventKindStatusChange, nil)
}
// Best-effort debug row cleanup — no process-local wait so this
// works correctly across replicas. If an active goroutine writes
// new debug rows after the delete, FinalizeStale will mark them
// as interrupted. Those orphaned rows are harmless because the
// chat itself is archived and no longer served through the API.
if p.debugSvc != nil {
for _, archivedChat := range archivedChats {
cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
if _, err := p.debugSvc.DeleteByChatID(cleanupCtx, archivedChat.ID); err != nil {
p.logger.Warn(ctx, "failed to delete chat debug rows after archive",
slog.F("chat_id", archivedChat.ID),
slog.Error(err),
)
}
cleanupCancel()
}
}
p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted)
@@ -1818,6 +1864,8 @@ func (p *Server) InterruptChat(
}
}
// Debug runs are finalized in the execution path when the owning
// goroutine observes cancellation, so we do not mutate debug state here.
updatedChat, err := p.setChatWaiting(ctx, chat.ID)
if err != nil {
p.logger.Error(ctx, "failed to mark chat as waiting",
@@ -2058,7 +2106,23 @@ func (p *Server) regenerateChatTitleWithStore(
return database.Chat{}, err
}
title, usage, err := generateManualTitle(ctx, messages, model)
debugEnabled := p.debugSvc != nil && p.debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
titleCtx := ctx
titleModel := model
finishDebugRun := func(error) {}
if debugEnabled {
titleCtx, titleModel, finishDebugRun = p.prepareManualTitleDebugRun(
ctx,
chat,
modelConfig,
keys,
messages,
model,
)
}
title, usage, err := generateManualTitle(titleCtx, messages, titleModel)
finishDebugRun(err)
if err != nil {
wrappedErr := xerrors.Errorf("generate manual title: %w", err)
if usage == (fantasy.Usage{}) {
@@ -2096,6 +2160,177 @@ func (p *Server) regenerateChatTitleWithStore(
return updatedChat, nil
}
func (p *Server) prepareManualTitleDebugRun(
ctx context.Context,
chat database.Chat,
modelConfig database.ChatModelConfig,
keys chatprovider.ProviderAPIKeys,
messages []database.ChatMessage,
fallbackModel fantasy.LanguageModel,
) (context.Context, fantasy.LanguageModel, func(error)) {
titleCtx := ctx
titleModel := fallbackModel
finishDebugRun := func(error) {}
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
debugModel, debugModelErr := chatprovider.ModelFromConfig(
modelConfig.Provider,
modelConfig.Model,
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
httpClient,
)
switch {
case debugModelErr != nil:
p.logger.Warn(ctx, "failed to create debug-aware manual title model",
slog.F("chat_id", chat.ID),
slog.F("provider", modelConfig.Provider),
slog.F("model", modelConfig.Model),
slog.Error(debugModelErr),
)
case debugModel == nil:
p.logger.Warn(ctx, "manual title debug model creation returned nil",
slog.F("chat_id", chat.ID),
slog.F("provider", modelConfig.Provider),
slog.F("model", modelConfig.Model),
)
default:
titleModel = chatdebug.WrapModel(debugModel, p.debugSvc, chatdebug.RecorderOptions{
ChatID: chat.ID,
OwnerID: chat.OwnerID,
Provider: modelConfig.Provider,
Model: modelConfig.Model,
})
}
var historyTipMessageID int64
if len(messages) > 0 {
historyTipMessageID = messages[len(messages)-1].ID
}
// Derive a first_message label from the first user message.
var firstUserLabel string
for _, msg := range messages {
if msg.Role == database.ChatMessageRoleUser {
if parts, parseErr := chatprompt.ParseContent(msg); parseErr == nil {
firstUserLabel = contentBlocksToText(parts)
}
break
}
}
if firstUserLabel == "" {
firstUserLabel = "Title generation"
}
seedSummary := chatdebug.SeedSummary(
chatdebug.TruncateLabel(firstUserLabel, chatdebug.MaxLabelLength),
)
createRunCtx, createRunCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
debugRun, createRunErr := p.debugSvc.CreateRun(createRunCtx, chatdebug.CreateRunParams{
ChatID: chat.ID,
ModelConfigID: modelConfig.ID,
Provider: modelConfig.Provider,
Model: modelConfig.Model,
Kind: chatdebug.KindTitleGeneration,
Status: chatdebug.StatusInProgress,
HistoryTipMessageID: historyTipMessageID,
TriggerMessageID: 0,
Summary: seedSummary,
})
createRunCancel()
if createRunErr != nil {
p.logger.Warn(ctx, "failed to create manual title debug run",
slog.F("chat_id", chat.ID),
slog.F("provider", modelConfig.Provider),
slog.F("model", modelConfig.Model),
slog.Error(createRunErr),
)
return titleCtx, titleModel, finishDebugRun
}
runContext := chatdebugRunContext(debugRun)
titleCtx = chatdebug.ContextWithRun(titleCtx, &runContext)
finishDebugRun = func(generateErr error) {
status := chatdebug.StatusCompleted
switch {
case generateErr == nil:
// keep completed
case errors.Is(generateErr, context.Canceled):
status = chatdebug.StatusInterrupted
default:
status = chatdebug.StatusError
}
finalSummary := seedSummary
aggCtx, aggCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer aggCancel()
if aggregated, aggErr := p.debugSvc.AggregateRunSummary(
aggCtx,
debugRun.ID,
seedSummary,
); aggErr != nil {
p.logger.Warn(ctx, "failed to aggregate debug run summary",
slog.F("chat_id", chat.ID),
slog.F("run_id", debugRun.ID),
slog.Error(aggErr),
)
} else {
finalSummary = aggregated
}
updateRunCtx, updateRunCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer updateRunCancel()
_, updateRunErr := p.debugSvc.UpdateRun(updateRunCtx, chatdebug.UpdateRunParams{
ID: debugRun.ID,
ChatID: debugRun.ChatID,
Status: status,
Summary: finalSummary,
FinishedAt: time.Now(),
})
if updateRunErr != nil {
p.logger.Warn(ctx, "failed to finalize manual title debug run",
slog.F("chat_id", chat.ID),
slog.F("run_id", debugRun.ID),
slog.Error(updateRunErr),
)
}
chatdebug.CleanupStepCounter(debugRun.ID)
}
return titleCtx, titleModel, finishDebugRun
}
func chatdebugRunContext(run database.ChatDebugRun) chatdebug.RunContext {
runContext := chatdebug.RunContext{
RunID: run.ID,
ChatID: run.ChatID,
Kind: chatdebug.RunKind(run.Kind),
}
if run.RootChatID.Valid {
runContext.RootChatID = run.RootChatID.UUID
}
if run.ParentChatID.Valid {
runContext.ParentChatID = run.ParentChatID.UUID
}
if run.ModelConfigID.Valid {
runContext.ModelConfigID = run.ModelConfigID.UUID
}
if run.TriggerMessageID.Valid {
runContext.TriggerMessageID = run.TriggerMessageID.Int64
}
if run.HistoryTipMessageID.Valid {
runContext.HistoryTipMessageID = run.HistoryTipMessageID.Int64
}
if run.Provider.Valid {
runContext.Provider = run.Provider.String
}
if run.Model.Valid {
runContext.Model = run.Model.String
}
return runContext
}
func (p *Server) resolveManualTitleModel(
ctx context.Context,
store database.Store,
@@ -2122,6 +2357,7 @@ func (p *Server) resolveManualTitleModel(
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
if err != nil {
p.logger.Debug(ctx, "manual title preferred model unavailable",
@@ -2154,6 +2390,7 @@ func (p *Server) resolveFallbackManualTitleModel(
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
if err != nil {
return nil, database.ChatModelConfig{}, xerrors.Errorf(
@@ -2688,6 +2925,7 @@ type Config struct {
StartWorkspace chattool.StartWorkspaceFn
Pubsub pubsub.Pubsub
ProviderAPIKeys chatprovider.ProviderAPIKeys
AlwaysEnableDebugLogs bool
WebpushDispatcher webpush.Dispatcher
UsageTracker *workspacestats.UsageTracker
Clock quartz.Clock
@@ -2734,6 +2972,14 @@ func New(cfg Config) *Server {
workerID = uuid.New()
}
debugSvc := chatdebug.NewService(
cfg.Database,
cfg.Logger.Named("chatdebug"),
cfg.Pubsub,
chatdebug.WithAlwaysEnable(cfg.AlwaysEnableDebugLogs),
)
debugSvc.SetStaleAfter(inFlightChatStaleAfter)
p := &Server{
cancel: cancel,
closed: make(chan struct{}),
@@ -2749,6 +2995,7 @@ func New(cfg Config) *Server {
pubsub: cfg.Pubsub,
webpushDispatcher: cfg.WebpushDispatcher,
providerAPIKeys: cfg.ProviderAPIKeys,
debugSvc: debugSvc,
pendingChatAcquireInterval: pendingChatAcquireInterval,
maxChatsPerAcquire: maxChatsPerAcquire,
inFlightChatStaleAfter: inFlightChatStaleAfter,
@@ -2797,6 +3044,12 @@ func (p *Server) start(ctx context.Context) {
// Recover stale chats on startup and periodically thereafter
// to handle chats orphaned by crashed or redeployed workers.
p.recoverStaleChats(ctx)
if p.debugSvc != nil {
_, err := p.debugSvc.FinalizeStale(ctx)
if err != nil {
p.logger.Warn(ctx, "failed to finalize stale chat debug rows", slog.Error(err))
}
}
// Single heartbeat loop for all chats on this replica.
go p.heartbeatLoop(ctx)
@@ -2826,6 +3079,11 @@ func (p *Server) start(ctx context.Context) {
p.processOnce(ctx)
case <-staleTicker.C:
p.recoverStaleChats(ctx)
if p.debugSvc != nil {
if _, err := p.debugSvc.FinalizeStale(ctx); err != nil {
p.logger.Warn(ctx, "failed to finalize stale chat debug rows", slog.Error(err))
}
}
}
}
}
@@ -4337,6 +4595,10 @@ type runChatResult struct {
PushSummaryModel fantasy.LanguageModel
ProviderKeys chatprovider.ProviderAPIKeys
PendingDynamicToolCalls []chatloop.PendingToolCall
FallbackProvider string
FallbackModel string
TriggerMessageID int64
HistoryTipMessageID int64
}
func (p *Server) runChat(
@@ -4347,11 +4609,14 @@ func (p *Server) runChat(
) (runChatResult, error) {
result := runChatResult{}
var (
model fantasy.LanguageModel
modelConfig database.ChatModelConfig
providerKeys chatprovider.ProviderAPIKeys
callConfig codersdk.ChatModelCallConfig
messages []database.ChatMessage
model fantasy.LanguageModel
modelConfig database.ChatModelConfig
providerKeys chatprovider.ProviderAPIKeys
callConfig codersdk.ChatModelCallConfig
messages []database.ChatMessage
debugEnabled bool
debugProvider string
debugModel string
)
// Load MCP server configs and user tokens in parallel with
@@ -4364,7 +4629,7 @@ func (p *Server) runChat(
var g errgroup.Group
g.Go(func() error {
var err error
model, modelConfig, providerKeys, err = p.resolveChatModel(ctx, chat)
model, modelConfig, providerKeys, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat)
if err != nil {
return err
}
@@ -4422,23 +4687,31 @@ func (p *Server) runChat(
chainInfo := resolveChainMode(messages)
result.PushSummaryModel = model
result.ProviderKeys = providerKeys
result.FallbackProvider = modelConfig.Provider
result.FallbackModel = modelConfig.Model
// Fire title generation asynchronously so it doesn't block the
// chat response. It uses a detached context so it can finish
// even after the chat processing context is canceled.
// Snapshot the original chat model so the goroutine doesn't
// race with the model = cuModel reassignment below.
// Snapshot ctx before the goroutine to avoid a data race with
// the ctx = runCtx reassignment later in the main goroutine.
titleModel := result.PushSummaryModel
titleCtx := context.WithoutCancel(ctx)
p.inflight.Add(1)
go func() {
defer p.inflight.Done()
p.maybeGenerateChatTitle(
context.WithoutCancel(ctx),
titleCtx,
chat,
messages,
modelConfig.Provider,
modelConfig.Model,
titleModel,
providerKeys,
generatedTitle,
logger,
p.debugSvc,
)
}()
@@ -4677,6 +4950,13 @@ func (p *Server) runChat(
var finalAssistantText string
var pendingDynamicCalls []chatloop.PendingToolCall
compactionHistoryTipMessageID := int64(0)
if len(messages) > 0 {
compactionHistoryTipMessageID = messages[len(messages)-1].ID
}
var compactionOptions *chatloop.CompactionOptions
persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error {
// If the chat context has been canceled, bail out before
// inserting any messages. We distinguish the cause so that
@@ -4889,6 +5169,12 @@ func (p *Server) runChat(
for _, msg := range insertedMessages {
p.publishMessage(chat.ID, msg)
}
if len(insertedMessages) > 0 {
compactionHistoryTipMessageID = insertedMessages[len(insertedMessages)-1].ID
if compactionOptions != nil {
compactionOptions.HistoryTipMessageID = compactionHistoryTipMessageID
}
}
// Do NOT clear the stream buffer here. Cross-replica
// relay subscribers may still need to snapshot buffered
@@ -4918,9 +5204,10 @@ func (p *Server) runChat(
effectiveThreshold = override
thresholdSource = "user_override"
}
compactionOptions := &chatloop.CompactionOptions{
ThresholdPercent: effectiveThreshold,
ContextLimit: modelConfig.ContextLimit,
compactionOptions = &chatloop.CompactionOptions{
ThresholdPercent: effectiveThreshold,
ContextLimit: modelConfig.ContextLimit,
HistoryTipMessageID: compactionHistoryTipMessageID,
Persist: func(
persistCtx context.Context,
result chatloop.CompactionResult,
@@ -4956,7 +5243,16 @@ func (p *Server) runChat(
if isComputerUse {
// Override model for computer use subagent.
cuModel, cuErr := chatprovider.ModelFromConfig(
resolvedProvider, resolvedModel, resolveErr := chatprovider.ResolveModelWithProviderHint(
chattool.ComputerUseModelName,
chattool.ComputerUseModelProvider,
)
if resolveErr != nil {
return result, xerrors.Errorf("resolve computer use model metadata: %w", resolveErr)
}
cuModel, cuDebugEnabled, cuErr := p.newDebugAwareModelFromConfig(
ctx,
chat,
chattool.ComputerUseModelProvider,
chattool.ComputerUseModelName,
providerKeys,
@@ -4967,6 +5263,13 @@ func (p *Server) runChat(
return result, xerrors.Errorf("resolve computer use model: %w", cuErr)
}
model = cuModel
debugEnabled = cuDebugEnabled
debugProvider = resolvedProvider
debugModel = resolvedModel
}
if debugEnabled {
compactionOptions.DebugSvc = p.debugSvc
compactionOptions.ChatID = chat.ID
}
tools := []fantasy.AgentTool{
@@ -5183,7 +5486,132 @@ func (p *Server) runChat(
)
prompt = filterPromptForChainMode(prompt, chainInfo)
}
err = chatloop.Run(ctx, chatloop.RunOptions{
var loopErr error
triggerMessageID := int64(0)
var triggerLabel string
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == database.ChatMessageRoleUser {
triggerMessageID = messages[i].ID
if parts, parseErr := chatprompt.ParseContent(messages[i]); parseErr == nil {
triggerLabel = contentBlocksToText(parts)
}
break
}
}
historyTipMessageID := int64(0)
if len(messages) > 0 {
historyTipMessageID = messages[len(messages)-1].ID
}
result.TriggerMessageID = triggerMessageID
result.HistoryTipMessageID = historyTipMessageID
if debugEnabled {
seedSummary := chatdebug.SeedSummary(
chatdebug.TruncateLabel(triggerLabel, chatdebug.MaxLabelLength),
)
rootChatID := uuid.Nil
if chat.RootChatID.Valid {
rootChatID = chat.RootChatID.UUID
}
parentChatID := uuid.Nil
if chat.ParentChatID.Valid {
parentChatID = chat.ParentChatID.UUID
}
run, createRunErr := p.debugSvc.CreateRun(ctx, chatdebug.CreateRunParams{
ChatID: chat.ID,
RootChatID: rootChatID,
ParentChatID: parentChatID,
ModelConfigID: modelConfig.ID,
TriggerMessageID: triggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindChatTurn,
Status: chatdebug.StatusInProgress,
Provider: debugProvider,
Model: debugModel,
Summary: seedSummary,
})
if createRunErr != nil {
logger.Warn(ctx, "failed to create chat debug run",
slog.F("chat_id", chat.ID),
slog.Error(createRunErr),
)
} else {
runCtx := chatdebug.ContextWithRun(ctx, &chatdebug.RunContext{
RunID: run.ID,
ChatID: chat.ID,
RootChatID: rootChatID,
ParentChatID: parentChatID,
ModelConfigID: modelConfig.ID,
TriggerMessageID: triggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindChatTurn,
Provider: debugProvider,
Model: debugModel,
})
defer func() {
panicValue := recover()
var status chatdebug.Status
switch {
case panicValue != nil:
status = chatdebug.StatusError
case loopErr == nil:
status = chatdebug.StatusCompleted
case errors.Is(loopErr, chatloop.ErrInterrupted),
errors.Is(loopErr, context.Canceled):
status = chatdebug.StatusInterrupted
case errors.Is(loopErr, chatloop.ErrDynamicToolCall):
// Dynamic tool calls are a successful pause;
// the run completed its model round-trip.
status = chatdebug.StatusCompleted
default:
status = chatdebug.StatusError
}
finalSummary := seedSummary
aggCtx, aggCancel := context.WithTimeout(context.WithoutCancel(runCtx), 5*time.Second)
defer aggCancel()
if aggregated, aggErr := p.debugSvc.AggregateRunSummary(
aggCtx,
run.ID,
seedSummary,
); aggErr != nil {
logger.Warn(ctx, "failed to aggregate debug run summary",
slog.F("chat_id", chat.ID),
slog.F("run_id", run.ID),
slog.Error(aggErr),
)
} else {
finalSummary = aggregated
}
updateRunCtx, updateRunCancel := context.WithTimeout(context.WithoutCancel(runCtx), 5*time.Second)
defer updateRunCancel()
if _, updateRunErr := p.debugSvc.UpdateRun(
updateRunCtx,
chatdebug.UpdateRunParams{
ID: run.ID,
ChatID: chat.ID,
Status: status,
Summary: finalSummary,
FinishedAt: time.Now(),
},
); updateRunErr != nil {
logger.Warn(ctx, "failed to finalize chat debug run",
slog.F("chat_id", chat.ID),
slog.F("run_id", run.ID),
slog.Error(updateRunErr),
)
}
chatdebug.CleanupStepCounter(run.ID)
if panicValue != nil {
panic(panicValue)
}
}()
ctx = runCtx
}
}
loopErr = chatloop.Run(ctx, chatloop.RunOptions{
Model: model,
Messages: prompt,
Tools: tools, MaxSteps: maxChatSteps,
@@ -5215,6 +5643,13 @@ func (p *Server) runChat(
if err != nil {
return nil, xerrors.Errorf("reload chat messages: %w", err)
}
compactionHistoryTipMessageID = 0
if len(reloadedMsgs) > 0 {
compactionHistoryTipMessageID = reloadedMsgs[len(reloadedMsgs)-1].ID
}
if compactionOptions != nil {
compactionOptions.HistoryTipMessageID = compactionHistoryTipMessageID
}
reloadedPrompt, err := chatprompt.ConvertMessagesWithFiles(reloadCtx, reloadedMsgs, p.chatFileResolver(), logger)
if err != nil {
return nil, xerrors.Errorf("convert reloaded messages: %w", err)
@@ -5271,7 +5706,7 @@ func (p *Server) runChat(
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
},
})
if errors.Is(err, chatloop.ErrDynamicToolCall) {
if errors.Is(loopErr, chatloop.ErrDynamicToolCall) {
// The stream event is published in processChat's
// defer after the DB status transitions to
// requires_action, preventing a race where a fast
@@ -5280,9 +5715,9 @@ func (p *Server) runChat(
result.PendingDynamicToolCalls = pendingDynamicCalls
return result, nil
}
if err != nil {
classified := chaterror.Classify(err).WithProvider(model.Provider())
return result, chaterror.WithClassification(err, classified)
if loopErr != nil {
classified := chaterror.Classify(loopErr).WithProvider(model.Provider())
return result, chaterror.WithClassification(loopErr, classified)
}
result.FinalAssistantText = finalAssistantText
return result, nil
@@ -5446,10 +5881,15 @@ func (p *Server) persistChatContextSummary(
func (p *Server) resolveChatModel(
ctx context.Context,
chat database.Chat,
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
var dbConfig database.ChatModelConfig
var keys chatprovider.ProviderAPIKeys
) (
model fantasy.LanguageModel,
dbConfig database.ChatModelConfig,
keys chatprovider.ProviderAPIKeys,
debugEnabled bool,
resolvedProvider string,
resolvedModel string,
err error,
) {
var g errgroup.Group
g.Go(func() error {
var err error
@@ -5468,19 +5908,34 @@ func (p *Server) resolveChatModel(
return nil
})
if err := g.Wait(); err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", err
}
model, err := chatprovider.ModelFromConfig(
dbConfig.Provider, dbConfig.Model, keys, chatprovider.UserAgent(),
resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint(
dbConfig.Model,
dbConfig.Provider,
)
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
"resolve model metadata: %w", err,
)
}
model, debugEnabled, err = p.newDebugAwareModelFromConfig(
ctx,
chat,
dbConfig.Provider,
dbConfig.Model,
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
)
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
"create model: %w", err,
)
}
return model, dbConfig, keys, nil
return model, dbConfig, keys, debugEnabled, resolvedProvider, resolvedModel, nil
}
func (p *Server) resolveUserProviderAPIKeys(
@@ -6162,9 +6617,14 @@ func (p *Server) maybeSendPushNotification(
pushCtx,
chat,
assistantText,
runResult.FallbackProvider,
runResult.FallbackModel,
runResult.PushSummaryModel,
runResult.ProviderKeys,
logger,
p.debugSvc,
runResult.TriggerMessageID,
runResult.HistoryTipMessageID,
); summary != "" {
pushBody = summary
}
+64
View File
@@ -0,0 +1,64 @@
package chatd
import (
"context"
"net/http"
"charm.land/fantasy"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
)
func (p *Server) newDebugAwareModelFromConfig(
ctx context.Context,
chat database.Chat,
providerHint string,
modelName string,
providerKeys chatprovider.ProviderAPIKeys,
userAgent string,
extraHeaders map[string]string,
) (fantasy.LanguageModel, bool, error) {
provider, resolvedModel, err := chatprovider.ResolveModelWithProviderHint(modelName, providerHint)
if err != nil {
return nil, false, err
}
debugEnabled := p.debugSvc != nil && p.debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
var httpClient *http.Client
if debugEnabled {
httpClient = &http.Client{Transport: &chatdebug.RecordingTransport{}}
}
model, err := chatprovider.ModelFromConfig(
provider,
resolvedModel,
providerKeys,
userAgent,
extraHeaders,
httpClient,
)
if err != nil {
return nil, debugEnabled, err
}
if model == nil {
return nil, debugEnabled, xerrors.Errorf(
"create model for %s/%s returned nil",
provider,
resolvedModel,
)
}
if !debugEnabled {
return model, false, nil
}
return chatdebug.WrapModel(model, p.debugSvc, chatdebug.RecorderOptions{
ChatID: chat.ID,
OwnerID: chat.OwnerID,
Provider: provider,
Model: resolvedModel,
}), true, nil
}
+9
View File
@@ -33,6 +33,15 @@ import (
"github.com/coder/quartz"
)
// TestWaitForActiveChatStop and TestWaitForActiveChatStop_WaitsForReplacementRun
// were removed along with the process-local activeChats mechanism.
// Debug cleanup is now best-effort; stale finalization handles orphaned rows.
// TestArchiveChatWaitsForActiveChatStop and
// TestArchiveChatWaitsForEveryInterruptedChat were removed along with
// the process-local activeChats mechanism. Archive cleanup is now
// best-effort; stale finalization handles any orphaned rows.
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
t.Parallel()
+1 -1
View File
@@ -346,7 +346,7 @@ func (s *Service) CreateStep(
}
return database.ChatDebugStep{}, xerrors.Errorf(
"failed to create debug step after %d attempts (run_id=%s)",
"chatdebug: failed to create step after %d retries (run %s)",
maxCreateStepRetries, params.RunID,
)
}
+3 -1
View File
@@ -20,6 +20,7 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
@@ -368,7 +369,8 @@ func Run(ctx context.Context, opts RunOptions) error {
}
var result stepResult
err := chatretry.Retry(ctx, func(retryCtx context.Context) error {
stepCtx := chatdebug.ReuseStep(ctx)
err := chatretry.Retry(stepCtx, func(retryCtx context.Context) error {
attempt, streamErr := guardedStream(
retryCtx,
opts.Model.Provider(),
+97 -1
View File
@@ -7,8 +7,10 @@ import (
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/codersdk"
)
@@ -46,6 +48,9 @@ type CompactionOptions struct {
SystemSummaryPrefix string
Timeout time.Duration
Persist func(context.Context, CompactionResult) error
DebugSvc *chatdebug.Service
ChatID uuid.UUID
HistoryTipMessageID int64
// ToolCallID and ToolName identify the synthetic tool call
// used to represent compaction in the message stream.
@@ -269,6 +274,92 @@ func shouldCompact(contextTokens, contextLimit int64, thresholdPercent int32) (f
return usagePercent, usagePercent >= float64(thresholdPercent)
}
func startCompactionDebugRun(
ctx context.Context,
options CompactionOptions,
) (context.Context, func(error)) {
if options.DebugSvc == nil || options.ChatID == uuid.Nil {
return ctx, func(error) {}
}
parentRun, ok := chatdebug.RunFromContext(ctx)
if !ok {
return ctx, func(error) {}
}
historyTipMessageID := options.HistoryTipMessageID
if historyTipMessageID == 0 {
historyTipMessageID = parentRun.HistoryTipMessageID
}
run, err := options.DebugSvc.CreateRun(ctx, chatdebug.CreateRunParams{
ChatID: options.ChatID,
RootChatID: parentRun.RootChatID,
ParentChatID: parentRun.ParentChatID,
ModelConfigID: parentRun.ModelConfigID,
TriggerMessageID: parentRun.TriggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindCompaction,
Status: chatdebug.StatusInProgress,
Provider: parentRun.Provider,
Model: parentRun.Model,
})
if err != nil {
// Debug instrumentation must not surface as a compaction failure.
return ctx, func(error) {}
}
compactionCtx := chatdebug.ContextWithRun(ctx, &chatdebug.RunContext{
RunID: run.ID,
ChatID: options.ChatID,
RootChatID: parentRun.RootChatID,
ParentChatID: parentRun.ParentChatID,
ModelConfigID: parentRun.ModelConfigID,
TriggerMessageID: parentRun.TriggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindCompaction,
Provider: parentRun.Provider,
Model: parentRun.Model,
})
return compactionCtx, func(runErr error) {
status := chatdebug.StatusCompleted
if runErr != nil {
status = chatdebug.StatusError
if xerrors.Is(runErr, ErrInterrupted) || xerrors.Is(runErr, context.Canceled) {
status = chatdebug.StatusInterrupted
}
}
finalizeCtx, finalizeCancel := context.WithTimeout(
context.WithoutCancel(compactionCtx),
5*time.Second,
)
defer finalizeCancel()
finalSummary := map[string]any(nil)
if aggregated, aggErr := options.DebugSvc.AggregateRunSummary(
finalizeCtx,
run.ID,
nil,
); aggErr == nil {
finalSummary = aggregated
}
// Debug instrumentation must not surface as a compaction failure.
_, _ = options.DebugSvc.UpdateRun(
finalizeCtx,
chatdebug.UpdateRunParams{
ID: run.ID,
ChatID: options.ChatID,
Status: status,
Summary: finalSummary,
FinishedAt: time.Now(),
},
)
chatdebug.CleanupStepCounter(run.ID)
}
}
// generateCompactionSummary asks the model to summarize the
// conversation so far. The provided messages should contain the
// complete history (system prompt, user/assistant turns, tool
@@ -279,7 +370,7 @@ func generateCompactionSummary(
model fantasy.LanguageModel,
messages []fantasy.Message,
options CompactionOptions,
) (string, error) {
) (summary string, err error) {
summaryPrompt := make([]fantasy.Message, 0, len(messages)+1)
summaryPrompt = append(summaryPrompt, messages...)
summaryPrompt = append(summaryPrompt, fantasy.Message{
@@ -293,6 +384,11 @@ func generateCompactionSummary(
summaryCtx, cancel := context.WithTimeout(ctx, options.Timeout)
defer cancel()
summaryCtx, finishDebugRun := startCompactionDebugRun(summaryCtx, options)
defer func() {
finishDebugRun(err)
}()
response, err := model.Generate(summaryCtx, fantasy.Call{
Prompt: summaryPrompt,
ToolChoice: &toolChoice,
+151
View File
@@ -2,17 +2,168 @@ package chatloop //nolint:testpackage // Uses internal symbols.
import (
"context"
"encoding/json"
"sync"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestStartCompactionDebugRun_DoesNotReportDebugErrors(t *testing.T) {
t.Parallel()
newParentContext := func(chatID uuid.UUID) context.Context {
return chatdebug.ContextWithRun(context.Background(), &chatdebug.RunContext{
RunID: uuid.New(),
ChatID: chatID,
RootChatID: uuid.New(),
ParentChatID: uuid.New(),
ModelConfigID: uuid.New(),
TriggerMessageID: 41,
HistoryTipMessageID: 42,
Kind: chatdebug.KindChatTurn,
Provider: "fake-provider",
Model: "fake-model",
})
}
t.Run("CreateRun", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
chatID := uuid.New()
reportedErr := make(chan error, 1)
db.EXPECT().InsertChatDebugRun(
gomock.Any(),
gomock.AssignableToTypeOf(database.InsertChatDebugRunParams{}),
).Return(database.ChatDebugRun{}, xerrors.New("insert compaction debug run"))
ctx := newParentContext(chatID)
compactionCtx, finish := startCompactionDebugRun(ctx, CompactionOptions{
DebugSvc: svc,
ChatID: chatID,
OnError: func(err error) {
reportedErr <- err
},
})
require.Same(t, ctx, compactionCtx)
finish(nil)
select {
case err := <-reportedErr:
t.Fatalf("unexpected OnError callback: %v", err)
default:
}
})
t.Run("FinalizeRunAggregatesSummary", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
chatID := uuid.New()
runID := uuid.New()
usageJSON, err := json.Marshal(fantasy.Usage{InputTokens: 7, OutputTokens: 3})
require.NoError(t, err)
attemptsJSON, err := json.Marshal([]chatdebug.Attempt{{
Status: "completed",
Method: "POST",
Path: "/v1/messages",
}})
require.NoError(t, err)
db.EXPECT().InsertChatDebugRun(
gomock.Any(),
gomock.AssignableToTypeOf(database.InsertChatDebugRunParams{}),
).Return(database.ChatDebugRun{ //nolint:exhaustruct // Test only needs IDs.
ID: runID,
ChatID: chatID,
}, nil)
db.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), runID).Return([]database.ChatDebugStep{{
ID: uuid.New(),
RunID: runID,
ChatID: chatID,
Status: string(chatdebug.StatusCompleted),
Usage: pqtype.NullRawMessage{RawMessage: usageJSON, Valid: true},
Attempts: attemptsJSON,
}}, nil)
db.EXPECT().UpdateChatDebugRun(
gomock.Any(),
gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{}),
).DoAndReturn(func(_ context.Context, params database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
require.Equal(t, chatID, params.ChatID)
require.Equal(t, runID, params.ID)
require.True(t, params.Summary.Valid)
require.JSONEq(t, `{"endpoint_label":"POST /v1/messages","step_count":1,"total_input_tokens":7,"total_output_tokens":3}`,
string(params.Summary.RawMessage))
return database.ChatDebugRun{ID: runID, ChatID: chatID}, nil
})
ctx := newParentContext(chatID)
compactionCtx, finish := startCompactionDebugRun(ctx, CompactionOptions{
DebugSvc: svc,
ChatID: chatID,
})
require.NotSame(t, ctx, compactionCtx)
finish(nil)
})
t.Run("FinalizeRun", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
chatID := uuid.New()
reportedErr := make(chan error, 1)
runID := uuid.New()
db.EXPECT().InsertChatDebugRun(
gomock.Any(),
gomock.AssignableToTypeOf(database.InsertChatDebugRunParams{}),
).Return(database.ChatDebugRun{ //nolint:exhaustruct // Test only needs IDs.
ID: runID,
ChatID: chatID,
}, nil)
db.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), runID).Return(nil, xerrors.New("aggregate compaction debug run"))
db.EXPECT().UpdateChatDebugRun(
gomock.Any(),
gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{}),
).Return(database.ChatDebugRun{}, xerrors.New("finalize compaction debug run"))
ctx := newParentContext(chatID)
compactionCtx, finish := startCompactionDebugRun(ctx, CompactionOptions{
DebugSvc: svc,
ChatID: chatID,
OnError: func(err error) {
reportedErr <- err
},
})
require.NotSame(t, ctx, compactionCtx)
finish(nil)
select {
case err := <-reportedErr:
t.Fatalf("unexpected OnError callback: %v", err)
default:
}
})
}
func TestRun_Compaction(t *testing.T) {
t.Parallel()
+28 -1
View File
@@ -2,6 +2,7 @@ package chatprovider
import (
"context"
"net/http"
"sort"
"strings"
@@ -1114,13 +1115,15 @@ func CoderHeadersFromIDs(
// language model client using the provided provider credentials. The
// userAgent is sent as the User-Agent header on every outgoing LLM
// API request. extraHeaders, when non-nil, are sent as additional
// HTTP headers on every request.
// HTTP headers on every request. httpClient, when non-nil, is used for
// all provider HTTP requests.
func ModelFromConfig(
providerHint string,
modelName string,
providerKeys ProviderAPIKeys,
userAgent string,
extraHeaders map[string]string,
httpClient *http.Client,
) (fantasy.LanguageModel, error) {
provider, modelID, err := ResolveModelWithProviderHint(modelName, providerHint)
if err != nil {
@@ -1146,6 +1149,9 @@ func ModelFromConfig(
if baseURL != "" {
options = append(options, fantasyanthropic.WithBaseURL(baseURL))
}
if httpClient != nil {
options = append(options, fantasyanthropic.WithHTTPClient(httpClient))
}
providerClient, err = fantasyanthropic.New(options...)
case fantasyazure.Name:
if baseURL == "" {
@@ -1160,6 +1166,9 @@ func ModelFromConfig(
if len(extraHeaders) > 0 {
azureOpts = append(azureOpts, fantasyazure.WithHeaders(extraHeaders))
}
if httpClient != nil {
azureOpts = append(azureOpts, fantasyazure.WithHTTPClient(httpClient))
}
providerClient, err = fantasyazure.New(azureOpts...)
case fantasybedrock.Name:
bedrockOpts := []fantasybedrock.Option{
@@ -1169,6 +1178,9 @@ func ModelFromConfig(
if len(extraHeaders) > 0 {
bedrockOpts = append(bedrockOpts, fantasybedrock.WithHeaders(extraHeaders))
}
if httpClient != nil {
bedrockOpts = append(bedrockOpts, fantasybedrock.WithHTTPClient(httpClient))
}
providerClient, err = fantasybedrock.New(bedrockOpts...)
case fantasygoogle.Name:
options := []fantasygoogle.Option{
@@ -1181,6 +1193,9 @@ func ModelFromConfig(
if baseURL != "" {
options = append(options, fantasygoogle.WithBaseURL(baseURL))
}
if httpClient != nil {
options = append(options, fantasygoogle.WithHTTPClient(httpClient))
}
providerClient, err = fantasygoogle.New(options...)
case fantasyopenai.Name:
options := []fantasyopenai.Option{
@@ -1194,6 +1209,9 @@ func ModelFromConfig(
if baseURL != "" {
options = append(options, fantasyopenai.WithBaseURL(baseURL))
}
if httpClient != nil {
options = append(options, fantasyopenai.WithHTTPClient(httpClient))
}
providerClient, err = fantasyopenai.New(options...)
case fantasyopenaicompat.Name:
options := []fantasyopenaicompat.Option{
@@ -1206,6 +1224,9 @@ func ModelFromConfig(
if baseURL != "" {
options = append(options, fantasyopenaicompat.WithBaseURL(baseURL))
}
if httpClient != nil {
options = append(options, fantasyopenaicompat.WithHTTPClient(httpClient))
}
providerClient, err = fantasyopenaicompat.New(options...)
case fantasyopenrouter.Name:
routerOpts := []fantasyopenrouter.Option{
@@ -1215,6 +1236,9 @@ func ModelFromConfig(
if len(extraHeaders) > 0 {
routerOpts = append(routerOpts, fantasyopenrouter.WithHeaders(extraHeaders))
}
if httpClient != nil {
routerOpts = append(routerOpts, fantasyopenrouter.WithHTTPClient(httpClient))
}
providerClient, err = fantasyopenrouter.New(routerOpts...)
case fantasyvercel.Name:
options := []fantasyvercel.Option{
@@ -1227,6 +1251,9 @@ func ModelFromConfig(
if baseURL != "" {
options = append(options, fantasyvercel.WithBaseURL(baseURL))
}
if httpClient != nil {
options = append(options, fantasyvercel.WithHTTPClient(httpClient))
}
providerClient, err = fantasyvercel.New(options...)
default:
return nil, xerrors.Errorf("unsupported model provider %q", provider)
@@ -181,6 +181,12 @@ func TestResolveUserProviderKeys(t *testing.T) {
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestReasoningEffortFromChat(t *testing.T) {
t.Parallel()
@@ -777,7 +783,7 @@ func TestModelFromConfig_ExtraHeaders(t *testing.T) {
BaseURLByProvider: map[string]string{"openai": serverURL},
}
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), headers)
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), headers, nil)
require.NoError(t, err)
_, err = model.Generate(ctx, fantasy.Call{
@@ -808,7 +814,7 @@ func TestModelFromConfig_ExtraHeaders(t *testing.T) {
BaseURLByProvider: map[string]string{"anthropic": serverURL},
}
model, err := chatprovider.ModelFromConfig("anthropic", "claude-sonnet-4-20250514", keys, chatprovider.UserAgent(), headers)
model, err := chatprovider.ModelFromConfig("anthropic", "claude-sonnet-4-20250514", keys, chatprovider.UserAgent(), headers, nil)
require.NoError(t, err)
_, err = model.Generate(ctx, fantasy.Call{
@@ -844,7 +850,7 @@ func TestModelFromConfig_NilExtraHeaders(t *testing.T) {
BaseURLByProvider: map[string]string{"openai": serverURL},
}
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), nil)
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), nil, nil)
require.NoError(t, err)
_, err = model.Generate(ctx, fantasy.Call{
@@ -859,6 +865,48 @@ func TestModelFromConfig_NilExtraHeaders(t *testing.T) {
_ = testutil.TryReceive(ctx, t, called)
}
func TestModelFromConfig_HTTPClient(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
called := make(chan struct{})
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
assert.Equal(t, "true", req.Header.Get("X-Test-Transport"))
close(called)
return chattest.OpenAINonStreamingResponse("hello")
})
keys := chatprovider.ProviderAPIKeys{
ByProvider: map[string]string{"openai": "test-key"},
BaseURLByProvider: map[string]string{"openai": serverURL},
}
client := &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
cloned := req.Clone(req.Context())
cloned.Header = req.Header.Clone()
cloned.Header.Set("X-Test-Transport", "true")
return http.DefaultTransport.RoundTrip(cloned)
})}
model, err := chatprovider.ModelFromConfig(
"openai",
"gpt-4",
keys,
chatprovider.UserAgent(),
nil,
client,
)
require.NoError(t, err)
_, err = model.Generate(ctx, fantasy.Call{
Prompt: []fantasy.Message{{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
}},
})
require.NoError(t, err)
_ = testutil.TryReceive(ctx, t, called)
}
func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
t.Parallel()
@@ -48,7 +48,7 @@ func TestModelFromConfig_UserAgent(t *testing.T) {
BaseURLByProvider: map[string]string{"openai": serverURL},
}
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, expectedUA, nil)
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, expectedUA, nil, nil)
require.NoError(t, err)
// Make a real call so Fantasy sends an HTTP request to the
+317 -11
View File
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"time"
@@ -21,6 +22,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
@@ -105,35 +107,173 @@ func (p *Server) maybeGenerateChatTitle(
ctx context.Context,
chat database.Chat,
messages []database.ChatMessage,
fallbackProvider string,
fallbackModelName string,
fallbackModel fantasy.LanguageModel,
keys chatprovider.ProviderAPIKeys,
generatedTitle *generatedChatTitle,
logger slog.Logger,
debugSvc *chatdebug.Service,
) {
input, ok := titleInput(chat, messages)
if !ok {
return
}
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
type candidateDescriptor struct {
provider string
model string
lm fantasy.LanguageModel
}
// Build candidate list: preferred lightweight models first,
// then the user's chat model as last resort.
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
candidates := make([]candidateDescriptor, 0, len(preferredTitleModels)+1)
for _, c := range preferredTitleModels {
m, err := chatprovider.ModelFromConfig(
c.provider, c.model, keys, chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
if err == nil {
candidates = append(candidates, m)
candidates = append(candidates, candidateDescriptor{
provider: c.provider,
model: c.model,
lm: m,
})
}
}
candidates = append(candidates, fallbackModel)
candidates = append(candidates, candidateDescriptor{
provider: fallbackProvider,
model: fallbackModelName,
lm: fallbackModel,
})
var historyTipMessageID int64
if len(messages) > 0 {
historyTipMessageID = messages[len(messages)-1].ID
}
var triggerMessageID int64
for _, message := range messages {
if message.Role == database.ChatMessageRoleUser {
triggerMessageID = message.ID
break
}
}
seedSummary := chatdebug.SeedSummary(
chatdebug.TruncateLabel(input, chatdebug.MaxLabelLength),
)
var lastErr error
for _, model := range candidates {
title, err := generateTitle(titleCtx, model, input)
for _, candidate := range candidates {
candidateModel := candidate.lm
candidateCtx := titleCtx
var debugRun *database.ChatDebugRun
if debugEnabled {
run, err := debugSvc.CreateRun(titleCtx, chatdebug.CreateRunParams{
ChatID: chat.ID,
TriggerMessageID: triggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindTitleGeneration,
Status: chatdebug.StatusInProgress,
Provider: candidate.provider,
Model: candidate.model,
Summary: seedSummary,
})
if err != nil {
logger.Warn(ctx, "failed to create title debug run",
slog.F("chat_id", chat.ID),
slog.F("provider", candidate.provider),
slog.F("model", candidate.model),
slog.Error(err),
)
} else {
debugRun = &run
candidateCtx = chatdebug.ContextWithRun(
candidateCtx,
&chatdebug.RunContext{
RunID: run.ID,
ChatID: chat.ID,
TriggerMessageID: triggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindTitleGeneration,
Provider: candidate.provider,
Model: candidate.model,
},
)
debugModel, err := newQuickgenDebugModel(
chat,
keys,
debugSvc,
candidate.provider,
candidate.model,
)
if err != nil {
logger.Warn(ctx, "failed to build title debug model",
slog.F("chat_id", chat.ID),
slog.F("provider", candidate.provider),
slog.F("model", candidate.model),
slog.Error(err),
)
} else {
candidateModel = debugModel
}
}
}
title, err := generateTitle(candidateCtx, candidateModel, input)
if debugRun != nil {
status := chatdebug.StatusCompleted
switch {
case err == nil:
// keep completed
case errors.Is(err, context.Canceled):
status = chatdebug.StatusInterrupted
default:
status = chatdebug.StatusError
}
finalizeCtx, finalizeCancel := context.WithTimeout(
context.WithoutCancel(ctx), 10*time.Second,
)
finalSummary := seedSummary
if aggregated, aggErr := debugSvc.AggregateRunSummary(
finalizeCtx,
debugRun.ID,
seedSummary,
); aggErr != nil {
logger.Warn(ctx, "failed to aggregate debug run summary",
slog.F("chat_id", chat.ID),
slog.F("run_id", debugRun.ID),
slog.Error(aggErr),
)
} else {
finalSummary = aggregated
}
if _, updateErr := debugSvc.UpdateRun(
finalizeCtx,
chatdebug.UpdateRunParams{
ID: debugRun.ID,
ChatID: chat.ID,
Status: status,
Summary: finalSummary,
FinishedAt: time.Now(),
},
); updateErr != nil {
logger.Warn(ctx, "failed to finalize title debug run",
slog.F("chat_id", chat.ID),
slog.F("run_id", debugRun.ID),
slog.Error(updateErr),
)
}
chatdebug.CleanupStepCounter(debugRun.ID)
finalizeCancel()
}
if err != nil {
lastErr = err
logger.Debug(ctx, "title model candidate failed",
@@ -171,6 +311,41 @@ func (p *Server) maybeGenerateChatTitle(
}
}
func newQuickgenDebugModel(
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
debugSvc *chatdebug.Service,
provider string,
model string,
) (fantasy.LanguageModel, error) {
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
debugModel, err := chatprovider.ModelFromConfig(
provider,
model,
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
httpClient,
)
if err != nil {
return nil, err
}
if debugModel == nil {
return nil, xerrors.Errorf(
"create model for %s/%s returned nil",
provider,
model,
)
}
return chatdebug.WrapModel(debugModel, debugSvc, chatdebug.RecorderOptions{
ChatID: chat.ID,
OwnerID: chat.OwnerID,
Provider: provider,
Model: model,
}), nil
}
// generateTitle calls the model with a title-generation system prompt
// and returns the normalized result. It retries transient LLM errors
// (rate limits, overloaded, etc.) with exponential backoff.
@@ -571,30 +746,160 @@ func generatePushSummary(
ctx context.Context,
chat database.Chat,
assistantText string,
fallbackProvider string,
fallbackModelName string,
fallbackModel fantasy.LanguageModel,
keys chatprovider.ProviderAPIKeys,
logger slog.Logger,
debugSvc *chatdebug.Service,
triggerMessageID int64,
historyTipMessageID int64,
) string {
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
summaryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
assistantText = truncateRunes(assistantText, maxConversationContextRunes)
input := "Chat title: " + chat.Title + "\n\nAgent's last message:\n" + assistantText
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
type candidateDescriptor struct {
provider string
model string
lm fantasy.LanguageModel
}
candidates := make([]candidateDescriptor, 0, len(preferredTitleModels)+1)
for _, c := range preferredTitleModels {
m, err := chatprovider.ModelFromConfig(
c.provider, c.model, keys, chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
if err == nil {
candidates = append(candidates, m)
candidates = append(candidates, candidateDescriptor{
provider: c.provider,
model: c.model,
lm: m,
})
}
}
candidates = append(candidates, fallbackModel)
candidates = append(candidates, candidateDescriptor{
provider: fallbackProvider,
model: fallbackModelName,
lm: fallbackModel,
})
for _, model := range candidates {
summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
pushSeedSummary := chatdebug.SeedSummary("Push summary")
for _, candidate := range candidates {
candidateModel := candidate.lm
candidateCtx := summaryCtx
var debugRun *database.ChatDebugRun
if debugEnabled {
run, err := debugSvc.CreateRun(summaryCtx, chatdebug.CreateRunParams{
ChatID: chat.ID,
TriggerMessageID: triggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindQuickgen,
Status: chatdebug.StatusInProgress,
Provider: candidate.provider,
Model: candidate.model,
Summary: pushSeedSummary,
})
if err != nil {
logger.Warn(ctx, "failed to create quickgen debug run",
slog.F("chat_id", chat.ID),
slog.F("provider", candidate.provider),
slog.F("model", candidate.model),
slog.Error(err),
)
} else {
debugRun = &run
candidateCtx = chatdebug.ContextWithRun(
candidateCtx,
&chatdebug.RunContext{
RunID: run.ID,
ChatID: chat.ID,
TriggerMessageID: triggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindQuickgen,
Provider: candidate.provider,
Model: candidate.model,
},
)
debugModel, err := newQuickgenDebugModel(
chat,
keys,
debugSvc,
candidate.provider,
candidate.model,
)
if err != nil {
logger.Warn(ctx, "failed to build quickgen debug model",
slog.F("chat_id", chat.ID),
slog.F("provider", candidate.provider),
slog.F("model", candidate.model),
slog.Error(err),
)
} else {
candidateModel = debugModel
}
}
}
summary, err := generateShortText(
candidateCtx,
candidateModel,
pushSummaryPrompt,
input,
)
if debugRun != nil {
status := chatdebug.StatusCompleted
switch {
case err == nil:
// keep completed
case errors.Is(err, context.Canceled):
status = chatdebug.StatusInterrupted
default:
status = chatdebug.StatusError
}
finalizeCtx, finalizeCancel := context.WithTimeout(
context.WithoutCancel(ctx), 10*time.Second,
)
finalSummary := pushSeedSummary
if aggregated, aggErr := debugSvc.AggregateRunSummary(
finalizeCtx,
debugRun.ID,
pushSeedSummary,
); aggErr != nil {
logger.Warn(ctx, "failed to aggregate debug run summary",
slog.F("chat_id", chat.ID),
slog.F("run_id", debugRun.ID),
slog.Error(aggErr),
)
} else {
finalSummary = aggregated
}
if _, updateErr := debugSvc.UpdateRun(
finalizeCtx,
chatdebug.UpdateRunParams{
ID: debugRun.ID,
ChatID: chat.ID,
Status: status,
Summary: finalSummary,
FinishedAt: time.Now(),
},
); updateErr != nil {
logger.Warn(ctx, "failed to finalize quickgen debug run",
slog.F("chat_id", chat.ID),
slog.F("run_id", debugRun.ID),
slog.Error(updateErr),
)
}
chatdebug.CleanupStepCounter(debugRun.ID)
finalizeCancel()
}
if err != nil {
logger.Debug(ctx, "push summary model candidate failed",
slog.Error(err),
@@ -610,7 +915,8 @@ func generatePushSummary(
// generateShortText calls a model with a system prompt and user
// input, returning a cleaned-up short text response. It reuses the
// same retry logic as title generation.
// same retry logic as title generation. Retries can therefore
// produce multiple debug steps for a single quickgen run.
func generateShortText(
ctx context.Context,
model fantasy.LanguageModel,