feat(coderd/x/chatd): wire debug logging into chat lifecycle
Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
+515
-55
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
@@ -129,6 +130,7 @@ type Server struct {
|
||||
pubsub pubsub.Pubsub
|
||||
webpushDispatcher webpush.Dispatcher
|
||||
providerAPIKeys chatprovider.ProviderAPIKeys
|
||||
debugSvc *chatdebug.Service
|
||||
configCache *chatConfigCache
|
||||
configCacheUnsubscribe func()
|
||||
|
||||
@@ -1210,7 +1212,10 @@ func (p *Server) EditMessage(
|
||||
return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
|
||||
}
|
||||
|
||||
var result EditMessageResult
|
||||
var (
|
||||
result EditMessageResult
|
||||
editedMsg database.ChatMessage
|
||||
)
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
|
||||
if err != nil {
|
||||
@@ -1221,17 +1226,17 @@ func (p *Server) EditMessage(
|
||||
return limitErr
|
||||
}
|
||||
|
||||
existing, err := tx.GetChatMessageByID(ctx, opts.EditedMessageID)
|
||||
editedMsg, err = tx.GetChatMessageByID(ctx, opts.EditedMessageID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrEditedMessageNotFound
|
||||
}
|
||||
return xerrors.Errorf("get edited message: %w", err)
|
||||
}
|
||||
if existing.ChatID != opts.ChatID {
|
||||
if editedMsg.ChatID != opts.ChatID {
|
||||
return ErrEditedMessageNotFound
|
||||
}
|
||||
if existing.Role != database.ChatMessageRoleUser {
|
||||
if editedMsg.Role != database.ChatMessageRoleUser {
|
||||
return ErrEditedMessageNotUser
|
||||
}
|
||||
|
||||
@@ -1258,8 +1263,8 @@ func (p *Server) EditMessage(
|
||||
appendChatMessage(&msgParams, newChatMessage(
|
||||
database.ChatMessageRoleUser,
|
||||
content,
|
||||
existing.Visibility,
|
||||
existing.ModelConfigID.UUID,
|
||||
editedMsg.Visibility,
|
||||
editedMsg.ModelConfigID.UUID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
).withCreatedBy(opts.CreatedBy))
|
||||
newMessages, err := insertChatMessageWithStore(ctx, tx, msgParams)
|
||||
@@ -1302,6 +1307,26 @@ func (p *Server) EditMessage(
|
||||
})
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
|
||||
// Best-effort debug row cleanup. We do not wait for the active
|
||||
// worker to stop because activeChats is process-local and would
|
||||
// not cover multi-replica deployments. Any rows that survive
|
||||
// this pass are caught by the periodic stale-finalization sweep.
|
||||
if p.debugSvc != nil {
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
defer cleanupCancel()
|
||||
if _, err := p.debugSvc.DeleteAfterMessageID(
|
||||
cleanupCtx,
|
||||
opts.ChatID,
|
||||
editedMsg.ID-1,
|
||||
); err != nil {
|
||||
p.logger.Warn(ctx, "failed to delete chat debug rows after edit",
|
||||
slog.F("chat_id", opts.ChatID),
|
||||
slog.F("edited_message_id", editedMsg.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -1316,46 +1341,67 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
return xerrors.New("chat_id is required")
|
||||
}
|
||||
|
||||
statusChat := chat
|
||||
interrupted := false
|
||||
var archivedChats []database.Chat
|
||||
var (
|
||||
archivedChats []database.Chat
|
||||
interruptedChats []database.Chat
|
||||
)
|
||||
if err := p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
||||
if err != nil {
|
||||
if _, err := tx.GetChatByIDForUpdate(ctx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("lock chat for archive: %w", err)
|
||||
}
|
||||
statusChat = lockedChat
|
||||
|
||||
// We do not call setChatWaiting here because it intentionally preserves
|
||||
// pending chats so queued-message promotion can win. Archiving is a
|
||||
// harder stop: both pending and running chats must transition to waiting.
|
||||
if lockedChat.Status == database.ChatStatusPending || lockedChat.Status == database.ChatStatusRunning {
|
||||
statusChat, err = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
var err error
|
||||
archivedChats, err = tx.ArchiveChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
}
|
||||
|
||||
for i, archivedChat := range archivedChats {
|
||||
if archivedChat.Status != database.ChatStatusPending &&
|
||||
archivedChat.Status != database.ChatStatusRunning {
|
||||
continue
|
||||
}
|
||||
|
||||
updatedChat, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: archivedChat.ID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set chat waiting before archive: %w", err)
|
||||
if updateErr != nil {
|
||||
return xerrors.Errorf("set archived chat waiting before cleanup: %w", updateErr)
|
||||
}
|
||||
interrupted = true
|
||||
}
|
||||
|
||||
archivedChats, err = tx.ArchiveChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
archivedChats[i] = updatedChat
|
||||
interruptedChats = append(interruptedChats, updatedChat)
|
||||
}
|
||||
return nil
|
||||
}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if interrupted {
|
||||
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
|
||||
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
for _, interruptedChat := range interruptedChats {
|
||||
p.publishStatus(interruptedChat.ID, interruptedChat.Status, interruptedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(interruptedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
// Best-effort debug row cleanup — no process-local wait so this
|
||||
// works correctly across replicas. If an active goroutine writes
|
||||
// new debug rows after the delete, FinalizeStale will mark them
|
||||
// as interrupted. Those orphaned rows are harmless because the
|
||||
// chat itself is archived and no longer served through the API.
|
||||
if p.debugSvc != nil {
|
||||
for _, archivedChat := range archivedChats {
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
if _, err := p.debugSvc.DeleteByChatID(cleanupCtx, archivedChat.ID); err != nil {
|
||||
p.logger.Warn(ctx, "failed to delete chat debug rows after archive",
|
||||
slog.F("chat_id", archivedChat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
cleanupCancel()
|
||||
}
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted)
|
||||
@@ -1818,6 +1864,8 @@ func (p *Server) InterruptChat(
|
||||
}
|
||||
}
|
||||
|
||||
// Debug runs are finalized in the execution path when the owning
|
||||
// goroutine observes cancellation, so we do not mutate debug state here.
|
||||
updatedChat, err := p.setChatWaiting(ctx, chat.ID)
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "failed to mark chat as waiting",
|
||||
@@ -2058,7 +2106,23 @@ func (p *Server) regenerateChatTitleWithStore(
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
title, usage, err := generateManualTitle(ctx, messages, model)
|
||||
debugEnabled := p.debugSvc != nil && p.debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
|
||||
titleCtx := ctx
|
||||
titleModel := model
|
||||
finishDebugRun := func(error) {}
|
||||
if debugEnabled {
|
||||
titleCtx, titleModel, finishDebugRun = p.prepareManualTitleDebugRun(
|
||||
ctx,
|
||||
chat,
|
||||
modelConfig,
|
||||
keys,
|
||||
messages,
|
||||
model,
|
||||
)
|
||||
}
|
||||
|
||||
title, usage, err := generateManualTitle(titleCtx, messages, titleModel)
|
||||
finishDebugRun(err)
|
||||
if err != nil {
|
||||
wrappedErr := xerrors.Errorf("generate manual title: %w", err)
|
||||
if usage == (fantasy.Usage{}) {
|
||||
@@ -2096,6 +2160,177 @@ func (p *Server) regenerateChatTitleWithStore(
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
func (p *Server) prepareManualTitleDebugRun(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
modelConfig database.ChatModelConfig,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
messages []database.ChatMessage,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
) (context.Context, fantasy.LanguageModel, func(error)) {
|
||||
titleCtx := ctx
|
||||
titleModel := fallbackModel
|
||||
finishDebugRun := func(error) {}
|
||||
|
||||
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
debugModel, debugModelErr := chatprovider.ModelFromConfig(
|
||||
modelConfig.Provider,
|
||||
modelConfig.Model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
httpClient,
|
||||
)
|
||||
switch {
|
||||
case debugModelErr != nil:
|
||||
p.logger.Warn(ctx, "failed to create debug-aware manual title model",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", modelConfig.Provider),
|
||||
slog.F("model", modelConfig.Model),
|
||||
slog.Error(debugModelErr),
|
||||
)
|
||||
case debugModel == nil:
|
||||
p.logger.Warn(ctx, "manual title debug model creation returned nil",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", modelConfig.Provider),
|
||||
slog.F("model", modelConfig.Model),
|
||||
)
|
||||
default:
|
||||
titleModel = chatdebug.WrapModel(debugModel, p.debugSvc, chatdebug.RecorderOptions{
|
||||
ChatID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
Provider: modelConfig.Provider,
|
||||
Model: modelConfig.Model,
|
||||
})
|
||||
}
|
||||
|
||||
var historyTipMessageID int64
|
||||
if len(messages) > 0 {
|
||||
historyTipMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
// Derive a first_message label from the first user message.
|
||||
var firstUserLabel string
|
||||
for _, msg := range messages {
|
||||
if msg.Role == database.ChatMessageRoleUser {
|
||||
if parts, parseErr := chatprompt.ParseContent(msg); parseErr == nil {
|
||||
firstUserLabel = contentBlocksToText(parts)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if firstUserLabel == "" {
|
||||
firstUserLabel = "Title generation"
|
||||
}
|
||||
seedSummary := chatdebug.SeedSummary(
|
||||
chatdebug.TruncateLabel(firstUserLabel, chatdebug.MaxLabelLength),
|
||||
)
|
||||
|
||||
createRunCtx, createRunCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
debugRun, createRunErr := p.debugSvc.CreateRun(createRunCtx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: modelConfig.ID,
|
||||
Provider: modelConfig.Provider,
|
||||
Model: modelConfig.Model,
|
||||
Kind: chatdebug.KindTitleGeneration,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
TriggerMessageID: 0,
|
||||
Summary: seedSummary,
|
||||
})
|
||||
createRunCancel()
|
||||
if createRunErr != nil {
|
||||
p.logger.Warn(ctx, "failed to create manual title debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", modelConfig.Provider),
|
||||
slog.F("model", modelConfig.Model),
|
||||
slog.Error(createRunErr),
|
||||
)
|
||||
return titleCtx, titleModel, finishDebugRun
|
||||
}
|
||||
|
||||
runContext := chatdebugRunContext(debugRun)
|
||||
titleCtx = chatdebug.ContextWithRun(titleCtx, &runContext)
|
||||
finishDebugRun = func(generateErr error) {
|
||||
status := chatdebug.StatusCompleted
|
||||
switch {
|
||||
case generateErr == nil:
|
||||
// keep completed
|
||||
case errors.Is(generateErr, context.Canceled):
|
||||
status = chatdebug.StatusInterrupted
|
||||
default:
|
||||
status = chatdebug.StatusError
|
||||
}
|
||||
|
||||
finalSummary := seedSummary
|
||||
aggCtx, aggCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
defer aggCancel()
|
||||
if aggregated, aggErr := p.debugSvc.AggregateRunSummary(
|
||||
aggCtx,
|
||||
debugRun.ID,
|
||||
seedSummary,
|
||||
); aggErr != nil {
|
||||
p.logger.Warn(ctx, "failed to aggregate debug run summary",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(aggErr),
|
||||
)
|
||||
} else {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
|
||||
updateRunCtx, updateRunCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
defer updateRunCancel()
|
||||
_, updateRunErr := p.debugSvc.UpdateRun(updateRunCtx, chatdebug.UpdateRunParams{
|
||||
ID: debugRun.ID,
|
||||
ChatID: debugRun.ChatID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
})
|
||||
if updateRunErr != nil {
|
||||
p.logger.Warn(ctx, "failed to finalize manual title debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(updateRunErr),
|
||||
)
|
||||
}
|
||||
chatdebug.CleanupStepCounter(debugRun.ID)
|
||||
}
|
||||
|
||||
return titleCtx, titleModel, finishDebugRun
|
||||
}
|
||||
|
||||
func chatdebugRunContext(run database.ChatDebugRun) chatdebug.RunContext {
|
||||
runContext := chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: run.ChatID,
|
||||
Kind: chatdebug.RunKind(run.Kind),
|
||||
}
|
||||
if run.RootChatID.Valid {
|
||||
runContext.RootChatID = run.RootChatID.UUID
|
||||
}
|
||||
if run.ParentChatID.Valid {
|
||||
runContext.ParentChatID = run.ParentChatID.UUID
|
||||
}
|
||||
if run.ModelConfigID.Valid {
|
||||
runContext.ModelConfigID = run.ModelConfigID.UUID
|
||||
}
|
||||
if run.TriggerMessageID.Valid {
|
||||
runContext.TriggerMessageID = run.TriggerMessageID.Int64
|
||||
}
|
||||
if run.HistoryTipMessageID.Valid {
|
||||
runContext.HistoryTipMessageID = run.HistoryTipMessageID.Int64
|
||||
}
|
||||
if run.Provider.Valid {
|
||||
runContext.Provider = run.Provider.String
|
||||
}
|
||||
if run.Model.Valid {
|
||||
runContext.Model = run.Model.String
|
||||
}
|
||||
return runContext
|
||||
}
|
||||
|
||||
func (p *Server) resolveManualTitleModel(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
@@ -2122,6 +2357,7 @@ func (p *Server) resolveManualTitleModel(
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "manual title preferred model unavailable",
|
||||
@@ -2154,6 +2390,7 @@ func (p *Server) resolveFallbackManualTitleModel(
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, xerrors.Errorf(
|
||||
@@ -2688,6 +2925,7 @@ type Config struct {
|
||||
StartWorkspace chattool.StartWorkspaceFn
|
||||
Pubsub pubsub.Pubsub
|
||||
ProviderAPIKeys chatprovider.ProviderAPIKeys
|
||||
AlwaysEnableDebugLogs bool
|
||||
WebpushDispatcher webpush.Dispatcher
|
||||
UsageTracker *workspacestats.UsageTracker
|
||||
Clock quartz.Clock
|
||||
@@ -2734,6 +2972,14 @@ func New(cfg Config) *Server {
|
||||
workerID = uuid.New()
|
||||
}
|
||||
|
||||
debugSvc := chatdebug.NewService(
|
||||
cfg.Database,
|
||||
cfg.Logger.Named("chatdebug"),
|
||||
cfg.Pubsub,
|
||||
chatdebug.WithAlwaysEnable(cfg.AlwaysEnableDebugLogs),
|
||||
)
|
||||
debugSvc.SetStaleAfter(inFlightChatStaleAfter)
|
||||
|
||||
p := &Server{
|
||||
cancel: cancel,
|
||||
closed: make(chan struct{}),
|
||||
@@ -2749,6 +2995,7 @@ func New(cfg Config) *Server {
|
||||
pubsub: cfg.Pubsub,
|
||||
webpushDispatcher: cfg.WebpushDispatcher,
|
||||
providerAPIKeys: cfg.ProviderAPIKeys,
|
||||
debugSvc: debugSvc,
|
||||
pendingChatAcquireInterval: pendingChatAcquireInterval,
|
||||
maxChatsPerAcquire: maxChatsPerAcquire,
|
||||
inFlightChatStaleAfter: inFlightChatStaleAfter,
|
||||
@@ -2797,6 +3044,12 @@ func (p *Server) start(ctx context.Context) {
|
||||
// Recover stale chats on startup and periodically thereafter
|
||||
// to handle chats orphaned by crashed or redeployed workers.
|
||||
p.recoverStaleChats(ctx)
|
||||
if p.debugSvc != nil {
|
||||
_, err := p.debugSvc.FinalizeStale(ctx)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to finalize stale chat debug rows", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Single heartbeat loop for all chats on this replica.
|
||||
go p.heartbeatLoop(ctx)
|
||||
@@ -2826,6 +3079,11 @@ func (p *Server) start(ctx context.Context) {
|
||||
p.processOnce(ctx)
|
||||
case <-staleTicker.C:
|
||||
p.recoverStaleChats(ctx)
|
||||
if p.debugSvc != nil {
|
||||
if _, err := p.debugSvc.FinalizeStale(ctx); err != nil {
|
||||
p.logger.Warn(ctx, "failed to finalize stale chat debug rows", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4337,6 +4595,10 @@ type runChatResult struct {
|
||||
PushSummaryModel fantasy.LanguageModel
|
||||
ProviderKeys chatprovider.ProviderAPIKeys
|
||||
PendingDynamicToolCalls []chatloop.PendingToolCall
|
||||
FallbackProvider string
|
||||
FallbackModel string
|
||||
TriggerMessageID int64
|
||||
HistoryTipMessageID int64
|
||||
}
|
||||
|
||||
func (p *Server) runChat(
|
||||
@@ -4347,11 +4609,14 @@ func (p *Server) runChat(
|
||||
) (runChatResult, error) {
|
||||
result := runChatResult{}
|
||||
var (
|
||||
model fantasy.LanguageModel
|
||||
modelConfig database.ChatModelConfig
|
||||
providerKeys chatprovider.ProviderAPIKeys
|
||||
callConfig codersdk.ChatModelCallConfig
|
||||
messages []database.ChatMessage
|
||||
model fantasy.LanguageModel
|
||||
modelConfig database.ChatModelConfig
|
||||
providerKeys chatprovider.ProviderAPIKeys
|
||||
callConfig codersdk.ChatModelCallConfig
|
||||
messages []database.ChatMessage
|
||||
debugEnabled bool
|
||||
debugProvider string
|
||||
debugModel string
|
||||
)
|
||||
|
||||
// Load MCP server configs and user tokens in parallel with
|
||||
@@ -4364,7 +4629,7 @@ func (p *Server) runChat(
|
||||
var g errgroup.Group
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
model, modelConfig, providerKeys, err = p.resolveChatModel(ctx, chat)
|
||||
model, modelConfig, providerKeys, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -4422,23 +4687,31 @@ func (p *Server) runChat(
|
||||
chainInfo := resolveChainMode(messages)
|
||||
result.PushSummaryModel = model
|
||||
result.ProviderKeys = providerKeys
|
||||
result.FallbackProvider = modelConfig.Provider
|
||||
result.FallbackModel = modelConfig.Model
|
||||
// Fire title generation asynchronously so it doesn't block the
|
||||
// chat response. It uses a detached context so it can finish
|
||||
// even after the chat processing context is canceled.
|
||||
// Snapshot the original chat model so the goroutine doesn't
|
||||
// race with the model = cuModel reassignment below.
|
||||
// Snapshot ctx before the goroutine to avoid a data race with
|
||||
// the ctx = runCtx reassignment later in the main goroutine.
|
||||
titleModel := result.PushSummaryModel
|
||||
titleCtx := context.WithoutCancel(ctx)
|
||||
p.inflight.Add(1)
|
||||
go func() {
|
||||
defer p.inflight.Done()
|
||||
p.maybeGenerateChatTitle(
|
||||
context.WithoutCancel(ctx),
|
||||
titleCtx,
|
||||
chat,
|
||||
messages,
|
||||
modelConfig.Provider,
|
||||
modelConfig.Model,
|
||||
titleModel,
|
||||
providerKeys,
|
||||
generatedTitle,
|
||||
logger,
|
||||
p.debugSvc,
|
||||
)
|
||||
}()
|
||||
|
||||
@@ -4677,6 +4950,13 @@ func (p *Server) runChat(
|
||||
var finalAssistantText string
|
||||
var pendingDynamicCalls []chatloop.PendingToolCall
|
||||
|
||||
compactionHistoryTipMessageID := int64(0)
|
||||
if len(messages) > 0 {
|
||||
compactionHistoryTipMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
var compactionOptions *chatloop.CompactionOptions
|
||||
|
||||
persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error {
|
||||
// If the chat context has been canceled, bail out before
|
||||
// inserting any messages. We distinguish the cause so that
|
||||
@@ -4889,6 +5169,12 @@ func (p *Server) runChat(
|
||||
for _, msg := range insertedMessages {
|
||||
p.publishMessage(chat.ID, msg)
|
||||
}
|
||||
if len(insertedMessages) > 0 {
|
||||
compactionHistoryTipMessageID = insertedMessages[len(insertedMessages)-1].ID
|
||||
if compactionOptions != nil {
|
||||
compactionOptions.HistoryTipMessageID = compactionHistoryTipMessageID
|
||||
}
|
||||
}
|
||||
|
||||
// Do NOT clear the stream buffer here. Cross-replica
|
||||
// relay subscribers may still need to snapshot buffered
|
||||
@@ -4918,9 +5204,10 @@ func (p *Server) runChat(
|
||||
effectiveThreshold = override
|
||||
thresholdSource = "user_override"
|
||||
}
|
||||
compactionOptions := &chatloop.CompactionOptions{
|
||||
ThresholdPercent: effectiveThreshold,
|
||||
ContextLimit: modelConfig.ContextLimit,
|
||||
compactionOptions = &chatloop.CompactionOptions{
|
||||
ThresholdPercent: effectiveThreshold,
|
||||
ContextLimit: modelConfig.ContextLimit,
|
||||
HistoryTipMessageID: compactionHistoryTipMessageID,
|
||||
Persist: func(
|
||||
persistCtx context.Context,
|
||||
result chatloop.CompactionResult,
|
||||
@@ -4956,7 +5243,16 @@ func (p *Server) runChat(
|
||||
|
||||
if isComputerUse {
|
||||
// Override model for computer use subagent.
|
||||
cuModel, cuErr := chatprovider.ModelFromConfig(
|
||||
resolvedProvider, resolvedModel, resolveErr := chatprovider.ResolveModelWithProviderHint(
|
||||
chattool.ComputerUseModelName,
|
||||
chattool.ComputerUseModelProvider,
|
||||
)
|
||||
if resolveErr != nil {
|
||||
return result, xerrors.Errorf("resolve computer use model metadata: %w", resolveErr)
|
||||
}
|
||||
cuModel, cuDebugEnabled, cuErr := p.newDebugAwareModelFromConfig(
|
||||
ctx,
|
||||
chat,
|
||||
chattool.ComputerUseModelProvider,
|
||||
chattool.ComputerUseModelName,
|
||||
providerKeys,
|
||||
@@ -4967,6 +5263,13 @@ func (p *Server) runChat(
|
||||
return result, xerrors.Errorf("resolve computer use model: %w", cuErr)
|
||||
}
|
||||
model = cuModel
|
||||
debugEnabled = cuDebugEnabled
|
||||
debugProvider = resolvedProvider
|
||||
debugModel = resolvedModel
|
||||
}
|
||||
if debugEnabled {
|
||||
compactionOptions.DebugSvc = p.debugSvc
|
||||
compactionOptions.ChatID = chat.ID
|
||||
}
|
||||
|
||||
tools := []fantasy.AgentTool{
|
||||
@@ -5183,7 +5486,132 @@ func (p *Server) runChat(
|
||||
)
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo)
|
||||
}
|
||||
err = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
|
||||
var loopErr error
|
||||
triggerMessageID := int64(0)
|
||||
var triggerLabel string
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == database.ChatMessageRoleUser {
|
||||
triggerMessageID = messages[i].ID
|
||||
if parts, parseErr := chatprompt.ParseContent(messages[i]); parseErr == nil {
|
||||
triggerLabel = contentBlocksToText(parts)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
historyTipMessageID := int64(0)
|
||||
if len(messages) > 0 {
|
||||
historyTipMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
result.TriggerMessageID = triggerMessageID
|
||||
result.HistoryTipMessageID = historyTipMessageID
|
||||
if debugEnabled {
|
||||
seedSummary := chatdebug.SeedSummary(
|
||||
chatdebug.TruncateLabel(triggerLabel, chatdebug.MaxLabelLength),
|
||||
)
|
||||
rootChatID := uuid.Nil
|
||||
if chat.RootChatID.Valid {
|
||||
rootChatID = chat.RootChatID.UUID
|
||||
}
|
||||
parentChatID := uuid.Nil
|
||||
if chat.ParentChatID.Valid {
|
||||
parentChatID = chat.ParentChatID.UUID
|
||||
}
|
||||
run, createRunErr := p.debugSvc.CreateRun(ctx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
RootChatID: rootChatID,
|
||||
ParentChatID: parentChatID,
|
||||
ModelConfigID: modelConfig.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: debugProvider,
|
||||
Model: debugModel,
|
||||
Summary: seedSummary,
|
||||
})
|
||||
if createRunErr != nil {
|
||||
logger.Warn(ctx, "failed to create chat debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(createRunErr),
|
||||
)
|
||||
} else {
|
||||
runCtx := chatdebug.ContextWithRun(ctx, &chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
RootChatID: rootChatID,
|
||||
ParentChatID: parentChatID,
|
||||
ModelConfigID: modelConfig.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Provider: debugProvider,
|
||||
Model: debugModel,
|
||||
})
|
||||
defer func() {
|
||||
panicValue := recover()
|
||||
var status chatdebug.Status
|
||||
switch {
|
||||
case panicValue != nil:
|
||||
status = chatdebug.StatusError
|
||||
case loopErr == nil:
|
||||
status = chatdebug.StatusCompleted
|
||||
case errors.Is(loopErr, chatloop.ErrInterrupted),
|
||||
errors.Is(loopErr, context.Canceled):
|
||||
status = chatdebug.StatusInterrupted
|
||||
case errors.Is(loopErr, chatloop.ErrDynamicToolCall):
|
||||
// Dynamic tool calls are a successful pause;
|
||||
// the run completed its model round-trip.
|
||||
status = chatdebug.StatusCompleted
|
||||
default:
|
||||
status = chatdebug.StatusError
|
||||
}
|
||||
|
||||
finalSummary := seedSummary
|
||||
aggCtx, aggCancel := context.WithTimeout(context.WithoutCancel(runCtx), 5*time.Second)
|
||||
defer aggCancel()
|
||||
if aggregated, aggErr := p.debugSvc.AggregateRunSummary(
|
||||
aggCtx,
|
||||
run.ID,
|
||||
seedSummary,
|
||||
); aggErr != nil {
|
||||
logger.Warn(ctx, "failed to aggregate debug run summary",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", run.ID),
|
||||
slog.Error(aggErr),
|
||||
)
|
||||
} else {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
|
||||
updateRunCtx, updateRunCancel := context.WithTimeout(context.WithoutCancel(runCtx), 5*time.Second)
|
||||
defer updateRunCancel()
|
||||
if _, updateRunErr := p.debugSvc.UpdateRun(
|
||||
updateRunCtx,
|
||||
chatdebug.UpdateRunParams{
|
||||
ID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
},
|
||||
); updateRunErr != nil {
|
||||
logger.Warn(ctx, "failed to finalize chat debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", run.ID),
|
||||
slog.Error(updateRunErr),
|
||||
)
|
||||
}
|
||||
chatdebug.CleanupStepCounter(run.ID)
|
||||
if panicValue != nil {
|
||||
panic(panicValue)
|
||||
}
|
||||
}()
|
||||
ctx = runCtx
|
||||
}
|
||||
}
|
||||
|
||||
loopErr = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
Model: model,
|
||||
Messages: prompt,
|
||||
Tools: tools, MaxSteps: maxChatSteps,
|
||||
@@ -5215,6 +5643,13 @@ func (p *Server) runChat(
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("reload chat messages: %w", err)
|
||||
}
|
||||
compactionHistoryTipMessageID = 0
|
||||
if len(reloadedMsgs) > 0 {
|
||||
compactionHistoryTipMessageID = reloadedMsgs[len(reloadedMsgs)-1].ID
|
||||
}
|
||||
if compactionOptions != nil {
|
||||
compactionOptions.HistoryTipMessageID = compactionHistoryTipMessageID
|
||||
}
|
||||
reloadedPrompt, err := chatprompt.ConvertMessagesWithFiles(reloadCtx, reloadedMsgs, p.chatFileResolver(), logger)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("convert reloaded messages: %w", err)
|
||||
@@ -5271,7 +5706,7 @@ func (p *Server) runChat(
|
||||
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
|
||||
},
|
||||
})
|
||||
if errors.Is(err, chatloop.ErrDynamicToolCall) {
|
||||
if errors.Is(loopErr, chatloop.ErrDynamicToolCall) {
|
||||
// The stream event is published in processChat's
|
||||
// defer after the DB status transitions to
|
||||
// requires_action, preventing a race where a fast
|
||||
@@ -5280,9 +5715,9 @@ func (p *Server) runChat(
|
||||
result.PendingDynamicToolCalls = pendingDynamicCalls
|
||||
return result, nil
|
||||
}
|
||||
if err != nil {
|
||||
classified := chaterror.Classify(err).WithProvider(model.Provider())
|
||||
return result, chaterror.WithClassification(err, classified)
|
||||
if loopErr != nil {
|
||||
classified := chaterror.Classify(loopErr).WithProvider(model.Provider())
|
||||
return result, chaterror.WithClassification(loopErr, classified)
|
||||
}
|
||||
result.FinalAssistantText = finalAssistantText
|
||||
return result, nil
|
||||
@@ -5446,10 +5881,15 @@ func (p *Server) persistChatContextSummary(
|
||||
func (p *Server) resolveChatModel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
|
||||
var dbConfig database.ChatModelConfig
|
||||
var keys chatprovider.ProviderAPIKeys
|
||||
|
||||
) (
|
||||
model fantasy.LanguageModel,
|
||||
dbConfig database.ChatModelConfig,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
debugEnabled bool,
|
||||
resolvedProvider string,
|
||||
resolvedModel string,
|
||||
err error,
|
||||
) {
|
||||
var g errgroup.Group
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
@@ -5468,19 +5908,34 @@ func (p *Server) resolveChatModel(
|
||||
return nil
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", err
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
dbConfig.Provider, dbConfig.Model, keys, chatprovider.UserAgent(),
|
||||
resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint(
|
||||
dbConfig.Model,
|
||||
dbConfig.Provider,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
|
||||
"resolve model metadata: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
model, debugEnabled, err = p.newDebugAwareModelFromConfig(
|
||||
ctx,
|
||||
chat,
|
||||
dbConfig.Provider,
|
||||
dbConfig.Model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
|
||||
"create model: %w", err,
|
||||
)
|
||||
}
|
||||
return model, dbConfig, keys, nil
|
||||
return model, dbConfig, keys, debugEnabled, resolvedProvider, resolvedModel, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveUserProviderAPIKeys(
|
||||
@@ -6162,9 +6617,14 @@ func (p *Server) maybeSendPushNotification(
|
||||
pushCtx,
|
||||
chat,
|
||||
assistantText,
|
||||
runResult.FallbackProvider,
|
||||
runResult.FallbackModel,
|
||||
runResult.PushSummaryModel,
|
||||
runResult.ProviderKeys,
|
||||
logger,
|
||||
p.debugSvc,
|
||||
runResult.TriggerMessageID,
|
||||
runResult.HistoryTipMessageID,
|
||||
); summary != "" {
|
||||
pushBody = summary
|
||||
}
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
)
|
||||
|
||||
func (p *Server) newDebugAwareModelFromConfig(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
providerHint string,
|
||||
modelName string,
|
||||
providerKeys chatprovider.ProviderAPIKeys,
|
||||
userAgent string,
|
||||
extraHeaders map[string]string,
|
||||
) (fantasy.LanguageModel, bool, error) {
|
||||
provider, resolvedModel, err := chatprovider.ResolveModelWithProviderHint(modelName, providerHint)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
debugEnabled := p.debugSvc != nil && p.debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
|
||||
|
||||
var httpClient *http.Client
|
||||
if debugEnabled {
|
||||
httpClient = &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
provider,
|
||||
resolvedModel,
|
||||
providerKeys,
|
||||
userAgent,
|
||||
extraHeaders,
|
||||
httpClient,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, debugEnabled, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, debugEnabled, xerrors.Errorf(
|
||||
"create model for %s/%s returned nil",
|
||||
provider,
|
||||
resolvedModel,
|
||||
)
|
||||
}
|
||||
if !debugEnabled {
|
||||
return model, false, nil
|
||||
}
|
||||
|
||||
return chatdebug.WrapModel(model, p.debugSvc, chatdebug.RecorderOptions{
|
||||
ChatID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
Provider: provider,
|
||||
Model: resolvedModel,
|
||||
}), true, nil
|
||||
}
|
||||
@@ -33,6 +33,15 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// TestWaitForActiveChatStop and TestWaitForActiveChatStop_WaitsForReplacementRun
|
||||
// were removed along with the process-local activeChats mechanism.
|
||||
// Debug cleanup is now best-effort; stale finalization handles orphaned rows.
|
||||
|
||||
// TestArchiveChatWaitsForActiveChatStop and
|
||||
// TestArchiveChatWaitsForEveryInterruptedChat were removed along with
|
||||
// the process-local activeChats mechanism. Archive cleanup is now
|
||||
// best-effort; stale finalization handles any orphaned rows.
|
||||
|
||||
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -346,7 +346,7 @@ func (s *Service) CreateStep(
|
||||
}
|
||||
|
||||
return database.ChatDebugStep{}, xerrors.Errorf(
|
||||
"failed to create debug step after %d attempts (run_id=%s)",
|
||||
"chatdebug: failed to create step after %d retries (run %s)",
|
||||
maxCreateStepRetries, params.RunID,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
@@ -368,7 +369,8 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
}
|
||||
|
||||
var result stepResult
|
||||
err := chatretry.Retry(ctx, func(retryCtx context.Context) error {
|
||||
stepCtx := chatdebug.ReuseStep(ctx)
|
||||
err := chatretry.Retry(stepCtx, func(retryCtx context.Context) error {
|
||||
attempt, streamErr := guardedStream(
|
||||
retryCtx,
|
||||
opts.Model.Provider(),
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -46,6 +48,9 @@ type CompactionOptions struct {
|
||||
SystemSummaryPrefix string
|
||||
Timeout time.Duration
|
||||
Persist func(context.Context, CompactionResult) error
|
||||
DebugSvc *chatdebug.Service
|
||||
ChatID uuid.UUID
|
||||
HistoryTipMessageID int64
|
||||
|
||||
// ToolCallID and ToolName identify the synthetic tool call
|
||||
// used to represent compaction in the message stream.
|
||||
@@ -269,6 +274,92 @@ func shouldCompact(contextTokens, contextLimit int64, thresholdPercent int32) (f
|
||||
return usagePercent, usagePercent >= float64(thresholdPercent)
|
||||
}
|
||||
|
||||
func startCompactionDebugRun(
|
||||
ctx context.Context,
|
||||
options CompactionOptions,
|
||||
) (context.Context, func(error)) {
|
||||
if options.DebugSvc == nil || options.ChatID == uuid.Nil {
|
||||
return ctx, func(error) {}
|
||||
}
|
||||
|
||||
parentRun, ok := chatdebug.RunFromContext(ctx)
|
||||
if !ok {
|
||||
return ctx, func(error) {}
|
||||
}
|
||||
|
||||
historyTipMessageID := options.HistoryTipMessageID
|
||||
if historyTipMessageID == 0 {
|
||||
historyTipMessageID = parentRun.HistoryTipMessageID
|
||||
}
|
||||
|
||||
run, err := options.DebugSvc.CreateRun(ctx, chatdebug.CreateRunParams{
|
||||
ChatID: options.ChatID,
|
||||
RootChatID: parentRun.RootChatID,
|
||||
ParentChatID: parentRun.ParentChatID,
|
||||
ModelConfigID: parentRun.ModelConfigID,
|
||||
TriggerMessageID: parentRun.TriggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindCompaction,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: parentRun.Provider,
|
||||
Model: parentRun.Model,
|
||||
})
|
||||
if err != nil {
|
||||
// Debug instrumentation must not surface as a compaction failure.
|
||||
return ctx, func(error) {}
|
||||
}
|
||||
|
||||
compactionCtx := chatdebug.ContextWithRun(ctx, &chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: options.ChatID,
|
||||
RootChatID: parentRun.RootChatID,
|
||||
ParentChatID: parentRun.ParentChatID,
|
||||
ModelConfigID: parentRun.ModelConfigID,
|
||||
TriggerMessageID: parentRun.TriggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindCompaction,
|
||||
Provider: parentRun.Provider,
|
||||
Model: parentRun.Model,
|
||||
})
|
||||
|
||||
return compactionCtx, func(runErr error) {
|
||||
status := chatdebug.StatusCompleted
|
||||
if runErr != nil {
|
||||
status = chatdebug.StatusError
|
||||
if xerrors.Is(runErr, ErrInterrupted) || xerrors.Is(runErr, context.Canceled) {
|
||||
status = chatdebug.StatusInterrupted
|
||||
}
|
||||
}
|
||||
finalizeCtx, finalizeCancel := context.WithTimeout(
|
||||
context.WithoutCancel(compactionCtx),
|
||||
5*time.Second,
|
||||
)
|
||||
defer finalizeCancel()
|
||||
|
||||
finalSummary := map[string]any(nil)
|
||||
if aggregated, aggErr := options.DebugSvc.AggregateRunSummary(
|
||||
finalizeCtx,
|
||||
run.ID,
|
||||
nil,
|
||||
); aggErr == nil {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
|
||||
// Debug instrumentation must not surface as a compaction failure.
|
||||
_, _ = options.DebugSvc.UpdateRun(
|
||||
finalizeCtx,
|
||||
chatdebug.UpdateRunParams{
|
||||
ID: run.ID,
|
||||
ChatID: options.ChatID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
},
|
||||
)
|
||||
chatdebug.CleanupStepCounter(run.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// generateCompactionSummary asks the model to summarize the
|
||||
// conversation so far. The provided messages should contain the
|
||||
// complete history (system prompt, user/assistant turns, tool
|
||||
@@ -279,7 +370,7 @@ func generateCompactionSummary(
|
||||
model fantasy.LanguageModel,
|
||||
messages []fantasy.Message,
|
||||
options CompactionOptions,
|
||||
) (string, error) {
|
||||
) (summary string, err error) {
|
||||
summaryPrompt := make([]fantasy.Message, 0, len(messages)+1)
|
||||
summaryPrompt = append(summaryPrompt, messages...)
|
||||
summaryPrompt = append(summaryPrompt, fantasy.Message{
|
||||
@@ -293,6 +384,11 @@ func generateCompactionSummary(
|
||||
summaryCtx, cancel := context.WithTimeout(ctx, options.Timeout)
|
||||
defer cancel()
|
||||
|
||||
summaryCtx, finishDebugRun := startCompactionDebugRun(summaryCtx, options)
|
||||
defer func() {
|
||||
finishDebugRun(err)
|
||||
}()
|
||||
|
||||
response, err := model.Generate(summaryCtx, fantasy.Call{
|
||||
Prompt: summaryPrompt,
|
||||
ToolChoice: &toolChoice,
|
||||
|
||||
@@ -2,17 +2,168 @@ package chatloop //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestStartCompactionDebugRun_DoesNotReportDebugErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newParentContext := func(chatID uuid.UUID) context.Context {
|
||||
return chatdebug.ContextWithRun(context.Background(), &chatdebug.RunContext{
|
||||
RunID: uuid.New(),
|
||||
ChatID: chatID,
|
||||
RootChatID: uuid.New(),
|
||||
ParentChatID: uuid.New(),
|
||||
ModelConfigID: uuid.New(),
|
||||
TriggerMessageID: 41,
|
||||
HistoryTipMessageID: 42,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Provider: "fake-provider",
|
||||
Model: "fake-model",
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("CreateRun", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
chatID := uuid.New()
|
||||
reportedErr := make(chan error, 1)
|
||||
|
||||
db.EXPECT().InsertChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.InsertChatDebugRunParams{}),
|
||||
).Return(database.ChatDebugRun{}, xerrors.New("insert compaction debug run"))
|
||||
|
||||
ctx := newParentContext(chatID)
|
||||
compactionCtx, finish := startCompactionDebugRun(ctx, CompactionOptions{
|
||||
DebugSvc: svc,
|
||||
ChatID: chatID,
|
||||
OnError: func(err error) {
|
||||
reportedErr <- err
|
||||
},
|
||||
})
|
||||
require.Same(t, ctx, compactionCtx)
|
||||
finish(nil)
|
||||
select {
|
||||
case err := <-reportedErr:
|
||||
t.Fatalf("unexpected OnError callback: %v", err)
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FinalizeRunAggregatesSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
chatID := uuid.New()
|
||||
runID := uuid.New()
|
||||
usageJSON, err := json.Marshal(fantasy.Usage{InputTokens: 7, OutputTokens: 3})
|
||||
require.NoError(t, err)
|
||||
attemptsJSON, err := json.Marshal([]chatdebug.Attempt{{
|
||||
Status: "completed",
|
||||
Method: "POST",
|
||||
Path: "/v1/messages",
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
db.EXPECT().InsertChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.InsertChatDebugRunParams{}),
|
||||
).Return(database.ChatDebugRun{ //nolint:exhaustruct // Test only needs IDs.
|
||||
ID: runID,
|
||||
ChatID: chatID,
|
||||
}, nil)
|
||||
db.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), runID).Return([]database.ChatDebugStep{{
|
||||
ID: uuid.New(),
|
||||
RunID: runID,
|
||||
ChatID: chatID,
|
||||
Status: string(chatdebug.StatusCompleted),
|
||||
Usage: pqtype.NullRawMessage{RawMessage: usageJSON, Valid: true},
|
||||
Attempts: attemptsJSON,
|
||||
}}, nil)
|
||||
db.EXPECT().UpdateChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{}),
|
||||
).DoAndReturn(func(_ context.Context, params database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
require.Equal(t, runID, params.ID)
|
||||
require.True(t, params.Summary.Valid)
|
||||
require.JSONEq(t, `{"endpoint_label":"POST /v1/messages","step_count":1,"total_input_tokens":7,"total_output_tokens":3}`,
|
||||
string(params.Summary.RawMessage))
|
||||
return database.ChatDebugRun{ID: runID, ChatID: chatID}, nil
|
||||
})
|
||||
|
||||
ctx := newParentContext(chatID)
|
||||
compactionCtx, finish := startCompactionDebugRun(ctx, CompactionOptions{
|
||||
DebugSvc: svc,
|
||||
ChatID: chatID,
|
||||
})
|
||||
require.NotSame(t, ctx, compactionCtx)
|
||||
finish(nil)
|
||||
})
|
||||
|
||||
t.Run("FinalizeRun", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
chatID := uuid.New()
|
||||
reportedErr := make(chan error, 1)
|
||||
runID := uuid.New()
|
||||
|
||||
db.EXPECT().InsertChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.InsertChatDebugRunParams{}),
|
||||
).Return(database.ChatDebugRun{ //nolint:exhaustruct // Test only needs IDs.
|
||||
ID: runID,
|
||||
ChatID: chatID,
|
||||
}, nil)
|
||||
db.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), runID).Return(nil, xerrors.New("aggregate compaction debug run"))
|
||||
db.EXPECT().UpdateChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{}),
|
||||
).Return(database.ChatDebugRun{}, xerrors.New("finalize compaction debug run"))
|
||||
|
||||
ctx := newParentContext(chatID)
|
||||
compactionCtx, finish := startCompactionDebugRun(ctx, CompactionOptions{
|
||||
DebugSvc: svc,
|
||||
ChatID: chatID,
|
||||
OnError: func(err error) {
|
||||
reportedErr <- err
|
||||
},
|
||||
})
|
||||
require.NotSame(t, ctx, compactionCtx)
|
||||
finish(nil)
|
||||
select {
|
||||
case err := <-reportedErr:
|
||||
t.Fatalf("unexpected OnError callback: %v", err)
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRun_Compaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package chatprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
@@ -1114,13 +1115,15 @@ func CoderHeadersFromIDs(
|
||||
// language model client using the provided provider credentials. The
|
||||
// userAgent is sent as the User-Agent header on every outgoing LLM
|
||||
// API request. extraHeaders, when non-nil, are sent as additional
|
||||
// HTTP headers on every request.
|
||||
// HTTP headers on every request. httpClient, when non-nil, is used for
|
||||
// all provider HTTP requests.
|
||||
func ModelFromConfig(
|
||||
providerHint string,
|
||||
modelName string,
|
||||
providerKeys ProviderAPIKeys,
|
||||
userAgent string,
|
||||
extraHeaders map[string]string,
|
||||
httpClient *http.Client,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
provider, modelID, err := ResolveModelWithProviderHint(modelName, providerHint)
|
||||
if err != nil {
|
||||
@@ -1146,6 +1149,9 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyanthropic.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasyanthropic.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyanthropic.New(options...)
|
||||
case fantasyazure.Name:
|
||||
if baseURL == "" {
|
||||
@@ -1160,6 +1166,9 @@ func ModelFromConfig(
|
||||
if len(extraHeaders) > 0 {
|
||||
azureOpts = append(azureOpts, fantasyazure.WithHeaders(extraHeaders))
|
||||
}
|
||||
if httpClient != nil {
|
||||
azureOpts = append(azureOpts, fantasyazure.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyazure.New(azureOpts...)
|
||||
case fantasybedrock.Name:
|
||||
bedrockOpts := []fantasybedrock.Option{
|
||||
@@ -1169,6 +1178,9 @@ func ModelFromConfig(
|
||||
if len(extraHeaders) > 0 {
|
||||
bedrockOpts = append(bedrockOpts, fantasybedrock.WithHeaders(extraHeaders))
|
||||
}
|
||||
if httpClient != nil {
|
||||
bedrockOpts = append(bedrockOpts, fantasybedrock.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasybedrock.New(bedrockOpts...)
|
||||
case fantasygoogle.Name:
|
||||
options := []fantasygoogle.Option{
|
||||
@@ -1181,6 +1193,9 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasygoogle.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasygoogle.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasygoogle.New(options...)
|
||||
case fantasyopenai.Name:
|
||||
options := []fantasyopenai.Option{
|
||||
@@ -1194,6 +1209,9 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyopenai.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasyopenai.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyopenai.New(options...)
|
||||
case fantasyopenaicompat.Name:
|
||||
options := []fantasyopenaicompat.Option{
|
||||
@@ -1206,6 +1224,9 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyopenaicompat.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasyopenaicompat.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyopenaicompat.New(options...)
|
||||
case fantasyopenrouter.Name:
|
||||
routerOpts := []fantasyopenrouter.Option{
|
||||
@@ -1215,6 +1236,9 @@ func ModelFromConfig(
|
||||
if len(extraHeaders) > 0 {
|
||||
routerOpts = append(routerOpts, fantasyopenrouter.WithHeaders(extraHeaders))
|
||||
}
|
||||
if httpClient != nil {
|
||||
routerOpts = append(routerOpts, fantasyopenrouter.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyopenrouter.New(routerOpts...)
|
||||
case fantasyvercel.Name:
|
||||
options := []fantasyvercel.Option{
|
||||
@@ -1227,6 +1251,9 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyvercel.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasyvercel.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyvercel.New(options...)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unsupported model provider %q", provider)
|
||||
|
||||
@@ -181,6 +181,12 @@ func TestResolveUserProviderKeys(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return fn(req)
|
||||
}
|
||||
|
||||
func TestReasoningEffortFromChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -777,7 +783,7 @@ func TestModelFromConfig_ExtraHeaders(t *testing.T) {
|
||||
BaseURLByProvider: map[string]string{"openai": serverURL},
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), headers)
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), headers, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(ctx, fantasy.Call{
|
||||
@@ -808,7 +814,7 @@ func TestModelFromConfig_ExtraHeaders(t *testing.T) {
|
||||
BaseURLByProvider: map[string]string{"anthropic": serverURL},
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig("anthropic", "claude-sonnet-4-20250514", keys, chatprovider.UserAgent(), headers)
|
||||
model, err := chatprovider.ModelFromConfig("anthropic", "claude-sonnet-4-20250514", keys, chatprovider.UserAgent(), headers, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(ctx, fantasy.Call{
|
||||
@@ -844,7 +850,7 @@ func TestModelFromConfig_NilExtraHeaders(t *testing.T) {
|
||||
BaseURLByProvider: map[string]string{"openai": serverURL},
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), nil)
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(ctx, fantasy.Call{
|
||||
@@ -859,6 +865,48 @@ func TestModelFromConfig_NilExtraHeaders(t *testing.T) {
|
||||
_ = testutil.TryReceive(ctx, t, called)
|
||||
}
|
||||
|
||||
func TestModelFromConfig_HTTPClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
called := make(chan struct{})
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
assert.Equal(t, "true", req.Header.Get("X-Test-Transport"))
|
||||
close(called)
|
||||
return chattest.OpenAINonStreamingResponse("hello")
|
||||
})
|
||||
|
||||
keys := chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{"openai": "test-key"},
|
||||
BaseURLByProvider: map[string]string{"openai": serverURL},
|
||||
}
|
||||
client := &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
cloned := req.Clone(req.Context())
|
||||
cloned.Header = req.Header.Clone()
|
||||
cloned.Header.Set("X-Test-Transport", "true")
|
||||
return http.DefaultTransport.RoundTrip(cloned)
|
||||
})}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
"openai",
|
||||
"gpt-4",
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
nil,
|
||||
client,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(ctx, fantasy.Call{
|
||||
Prompt: []fantasy.Message{{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_ = testutil.TryReceive(ctx, t, called)
|
||||
}
|
||||
|
||||
func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ func TestModelFromConfig_UserAgent(t *testing.T) {
|
||||
BaseURLByProvider: map[string]string{"openai": serverURL},
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, expectedUA, nil)
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, expectedUA, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make a real call so Fantasy sends an HTTP request to the
|
||||
|
||||
+317
-11
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
@@ -105,35 +107,173 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
messages []database.ChatMessage,
|
||||
fallbackProvider string,
|
||||
fallbackModelName string,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
generatedTitle *generatedChatTitle,
|
||||
logger slog.Logger,
|
||||
debugSvc *chatdebug.Service,
|
||||
) {
|
||||
input, ok := titleInput(chat, messages)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
|
||||
|
||||
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
type candidateDescriptor struct {
|
||||
provider string
|
||||
model string
|
||||
lm fantasy.LanguageModel
|
||||
}
|
||||
|
||||
// Build candidate list: preferred lightweight models first,
|
||||
// then the user's chat model as last resort.
|
||||
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
|
||||
candidates := make([]candidateDescriptor, 0, len(preferredTitleModels)+1)
|
||||
for _, c := range preferredTitleModels {
|
||||
m, err := chatprovider.ModelFromConfig(
|
||||
c.provider, c.model, keys, chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
candidates = append(candidates, m)
|
||||
candidates = append(candidates, candidateDescriptor{
|
||||
provider: c.provider,
|
||||
model: c.model,
|
||||
lm: m,
|
||||
})
|
||||
}
|
||||
}
|
||||
candidates = append(candidates, fallbackModel)
|
||||
candidates = append(candidates, candidateDescriptor{
|
||||
provider: fallbackProvider,
|
||||
model: fallbackModelName,
|
||||
lm: fallbackModel,
|
||||
})
|
||||
|
||||
var historyTipMessageID int64
|
||||
if len(messages) > 0 {
|
||||
historyTipMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
var triggerMessageID int64
|
||||
for _, message := range messages {
|
||||
if message.Role == database.ChatMessageRoleUser {
|
||||
triggerMessageID = message.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
seedSummary := chatdebug.SeedSummary(
|
||||
chatdebug.TruncateLabel(input, chatdebug.MaxLabelLength),
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
for _, model := range candidates {
|
||||
title, err := generateTitle(titleCtx, model, input)
|
||||
for _, candidate := range candidates {
|
||||
candidateModel := candidate.lm
|
||||
candidateCtx := titleCtx
|
||||
var debugRun *database.ChatDebugRun
|
||||
if debugEnabled {
|
||||
run, err := debugSvc.CreateRun(titleCtx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindTitleGeneration,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: candidate.provider,
|
||||
Model: candidate.model,
|
||||
Summary: seedSummary,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to create title debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", candidate.provider),
|
||||
slog.F("model", candidate.model),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
debugRun = &run
|
||||
candidateCtx = chatdebug.ContextWithRun(
|
||||
candidateCtx,
|
||||
&chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindTitleGeneration,
|
||||
Provider: candidate.provider,
|
||||
Model: candidate.model,
|
||||
},
|
||||
)
|
||||
debugModel, err := newQuickgenDebugModel(
|
||||
chat,
|
||||
keys,
|
||||
debugSvc,
|
||||
candidate.provider,
|
||||
candidate.model,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to build title debug model",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", candidate.provider),
|
||||
slog.F("model", candidate.model),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
candidateModel = debugModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
title, err := generateTitle(candidateCtx, candidateModel, input)
|
||||
if debugRun != nil {
|
||||
status := chatdebug.StatusCompleted
|
||||
switch {
|
||||
case err == nil:
|
||||
// keep completed
|
||||
case errors.Is(err, context.Canceled):
|
||||
status = chatdebug.StatusInterrupted
|
||||
default:
|
||||
status = chatdebug.StatusError
|
||||
}
|
||||
finalizeCtx, finalizeCancel := context.WithTimeout(
|
||||
context.WithoutCancel(ctx), 10*time.Second,
|
||||
)
|
||||
finalSummary := seedSummary
|
||||
if aggregated, aggErr := debugSvc.AggregateRunSummary(
|
||||
finalizeCtx,
|
||||
debugRun.ID,
|
||||
seedSummary,
|
||||
); aggErr != nil {
|
||||
logger.Warn(ctx, "failed to aggregate debug run summary",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(aggErr),
|
||||
)
|
||||
} else {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
if _, updateErr := debugSvc.UpdateRun(
|
||||
finalizeCtx,
|
||||
chatdebug.UpdateRunParams{
|
||||
ID: debugRun.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
},
|
||||
); updateErr != nil {
|
||||
logger.Warn(ctx, "failed to finalize title debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(updateErr),
|
||||
)
|
||||
}
|
||||
chatdebug.CleanupStepCounter(debugRun.ID)
|
||||
finalizeCancel()
|
||||
}
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
logger.Debug(ctx, "title model candidate failed",
|
||||
@@ -171,6 +311,41 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
}
|
||||
}
|
||||
|
||||
func newQuickgenDebugModel(
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
debugSvc *chatdebug.Service,
|
||||
provider string,
|
||||
model string,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
debugModel, err := chatprovider.ModelFromConfig(
|
||||
provider,
|
||||
model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
httpClient,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if debugModel == nil {
|
||||
return nil, xerrors.Errorf(
|
||||
"create model for %s/%s returned nil",
|
||||
provider,
|
||||
model,
|
||||
)
|
||||
}
|
||||
|
||||
return chatdebug.WrapModel(debugModel, debugSvc, chatdebug.RecorderOptions{
|
||||
ChatID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// generateTitle calls the model with a title-generation system prompt
|
||||
// and returns the normalized result. It retries transient LLM errors
|
||||
// (rate limits, overloaded, etc.) with exponential backoff.
|
||||
@@ -571,30 +746,160 @@ func generatePushSummary(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
assistantText string,
|
||||
fallbackProvider string,
|
||||
fallbackModelName string,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
logger slog.Logger,
|
||||
debugSvc *chatdebug.Service,
|
||||
triggerMessageID int64,
|
||||
historyTipMessageID int64,
|
||||
) string {
|
||||
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
|
||||
|
||||
summaryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
assistantText = truncateRunes(assistantText, maxConversationContextRunes)
|
||||
input := "Chat title: " + chat.Title + "\n\nAgent's last message:\n" + assistantText
|
||||
|
||||
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
|
||||
type candidateDescriptor struct {
|
||||
provider string
|
||||
model string
|
||||
lm fantasy.LanguageModel
|
||||
}
|
||||
|
||||
candidates := make([]candidateDescriptor, 0, len(preferredTitleModels)+1)
|
||||
for _, c := range preferredTitleModels {
|
||||
m, err := chatprovider.ModelFromConfig(
|
||||
c.provider, c.model, keys, chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
candidates = append(candidates, m)
|
||||
candidates = append(candidates, candidateDescriptor{
|
||||
provider: c.provider,
|
||||
model: c.model,
|
||||
lm: m,
|
||||
})
|
||||
}
|
||||
}
|
||||
candidates = append(candidates, fallbackModel)
|
||||
candidates = append(candidates, candidateDescriptor{
|
||||
provider: fallbackProvider,
|
||||
model: fallbackModelName,
|
||||
lm: fallbackModel,
|
||||
})
|
||||
|
||||
for _, model := range candidates {
|
||||
summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
|
||||
pushSeedSummary := chatdebug.SeedSummary("Push summary")
|
||||
|
||||
for _, candidate := range candidates {
|
||||
candidateModel := candidate.lm
|
||||
candidateCtx := summaryCtx
|
||||
var debugRun *database.ChatDebugRun
|
||||
if debugEnabled {
|
||||
run, err := debugSvc.CreateRun(summaryCtx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindQuickgen,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: candidate.provider,
|
||||
Model: candidate.model,
|
||||
Summary: pushSeedSummary,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to create quickgen debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", candidate.provider),
|
||||
slog.F("model", candidate.model),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
debugRun = &run
|
||||
candidateCtx = chatdebug.ContextWithRun(
|
||||
candidateCtx,
|
||||
&chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindQuickgen,
|
||||
Provider: candidate.provider,
|
||||
Model: candidate.model,
|
||||
},
|
||||
)
|
||||
debugModel, err := newQuickgenDebugModel(
|
||||
chat,
|
||||
keys,
|
||||
debugSvc,
|
||||
candidate.provider,
|
||||
candidate.model,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to build quickgen debug model",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", candidate.provider),
|
||||
slog.F("model", candidate.model),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
candidateModel = debugModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
summary, err := generateShortText(
|
||||
candidateCtx,
|
||||
candidateModel,
|
||||
pushSummaryPrompt,
|
||||
input,
|
||||
)
|
||||
if debugRun != nil {
|
||||
status := chatdebug.StatusCompleted
|
||||
switch {
|
||||
case err == nil:
|
||||
// keep completed
|
||||
case errors.Is(err, context.Canceled):
|
||||
status = chatdebug.StatusInterrupted
|
||||
default:
|
||||
status = chatdebug.StatusError
|
||||
}
|
||||
finalizeCtx, finalizeCancel := context.WithTimeout(
|
||||
context.WithoutCancel(ctx), 10*time.Second,
|
||||
)
|
||||
finalSummary := pushSeedSummary
|
||||
if aggregated, aggErr := debugSvc.AggregateRunSummary(
|
||||
finalizeCtx,
|
||||
debugRun.ID,
|
||||
pushSeedSummary,
|
||||
); aggErr != nil {
|
||||
logger.Warn(ctx, "failed to aggregate debug run summary",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(aggErr),
|
||||
)
|
||||
} else {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
if _, updateErr := debugSvc.UpdateRun(
|
||||
finalizeCtx,
|
||||
chatdebug.UpdateRunParams{
|
||||
ID: debugRun.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
},
|
||||
); updateErr != nil {
|
||||
logger.Warn(ctx, "failed to finalize quickgen debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(updateErr),
|
||||
)
|
||||
}
|
||||
chatdebug.CleanupStepCounter(debugRun.ID)
|
||||
finalizeCancel()
|
||||
}
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "push summary model candidate failed",
|
||||
slog.Error(err),
|
||||
@@ -610,7 +915,8 @@ func generatePushSummary(
|
||||
|
||||
// generateShortText calls a model with a system prompt and user
|
||||
// input, returning a cleaned-up short text response. It reuses the
|
||||
// same retry logic as title generation.
|
||||
// same retry logic as title generation. Retries can therefore
|
||||
// produce multiple debug steps for a single quickgen run.
|
||||
func generateShortText(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
|
||||
Reference in New Issue
Block a user