Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 72960aeb77 |
@@ -18,35 +18,35 @@ The 5.x era resolves years of module system ambiguity and cleans house on legacy
|
||||
|
||||
The left column reflects patterns still common before TypeScript 5.x. Write the right column instead. The "Since" column tells you the minimum TypeScript version required.
|
||||
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}()\[\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | -------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}() | [\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
|
||||
## New capabilities
|
||||
|
||||
|
||||
@@ -91,6 +91,12 @@ updates:
|
||||
emotion:
|
||||
patterns:
|
||||
- "@emotion*"
|
||||
exclude-patterns:
|
||||
- "jest-runner-eslint"
|
||||
jest:
|
||||
patterns:
|
||||
- "jest"
|
||||
- "@types/jest"
|
||||
vite:
|
||||
patterns:
|
||||
- "vite*"
|
||||
|
||||
+1036
-1139
File diff suppressed because it is too large
Load Diff
@@ -98,21 +98,6 @@ message Manifest {
|
||||
repeated WorkspaceApp apps = 11;
|
||||
repeated WorkspaceAgentMetadata.Description metadata = 12;
|
||||
repeated WorkspaceAgentDevcontainer devcontainers = 17;
|
||||
repeated WorkspaceSecret secrets = 19;
|
||||
}
|
||||
|
||||
// WorkspaceSecret is a secret included in the agent manifest
|
||||
// for injection into a workspace.
|
||||
message WorkspaceSecret {
|
||||
// Environment variable name to inject (e.g. "GITHUB_TOKEN").
|
||||
// Empty string means this secret is not injected as an env var.
|
||||
string env_name = 1;
|
||||
// File path to write the secret value to (e.g.
|
||||
// "~/.aws/credentials"). Empty string means this secret is not
|
||||
// written to a file.
|
||||
string file_path = 2;
|
||||
// The decrypted secret value.
|
||||
bytes value = 3;
|
||||
}
|
||||
|
||||
message WorkspaceAgentDevcontainer {
|
||||
|
||||
@@ -812,18 +812,12 @@ func TestPortableDesktop_IdleTimeout_StopsRecordings(t *testing.T) {
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout to trigger the stop-all.
|
||||
clk.Advance(idleTimeout).MustWait(ctx)
|
||||
clk.Advance(idleTimeout)
|
||||
|
||||
// Wait for the stop timer to be created, then release it.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Advance past the 15s stop timeout so the process is
|
||||
// forcibly killed. Without this the test depends on the real
|
||||
// shell handling SIGINT promptly, which is unreliable on
|
||||
// macOS CI runners (the flake in #1461).
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
|
||||
// The recording process should now be stopped.
|
||||
require.Eventually(t, func() bool {
|
||||
pd.mu.Lock()
|
||||
@@ -945,17 +939,11 @@ func TestPortableDesktop_IdleTimeout_MultipleRecordings(t *testing.T) {
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout.
|
||||
clk.Advance(idleTimeout).MustWait(ctx)
|
||||
clk.Advance(idleTimeout)
|
||||
|
||||
// Each idle monitor goroutine serializes on p.mu, so the
|
||||
// second stop timer is only created after the first stop
|
||||
// completes. Advance past the 15s stop timeout after each
|
||||
// release so the process is forcibly killed instead of
|
||||
// depending on SIGINT (unreliable on macOS — see #1461).
|
||||
// Wait for both stop timers.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Both recordings should be stopped.
|
||||
|
||||
-7
@@ -211,13 +211,6 @@ AI BRIDGE PROXY OPTIONS:
|
||||
certificates not trusted by the system. If not provided, the system
|
||||
certificate pool is used.
|
||||
|
||||
CHAT OPTIONS:
|
||||
Configure the background chat processing daemon.
|
||||
|
||||
--chat-debug-logging-enabled bool, $CODER_CHAT_DEBUG_LOGGING_ENABLED (default: false)
|
||||
Force chat debug logging on for every chat, bypassing the runtime
|
||||
admin and user opt-in settings.
|
||||
|
||||
CLIENT OPTIONS:
|
||||
These options change the behavior of how clients interact with the Coder.
|
||||
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
|
||||
|
||||
-4
@@ -757,10 +757,6 @@ chat:
|
||||
# How many pending chats a worker should acquire per polling cycle.
|
||||
# (default: 10, type: int)
|
||||
acquireBatchSize: 10
|
||||
# Force chat debug logging on for every chat, bypassing the runtime admin and user
|
||||
# opt-in settings.
|
||||
# (default: false, type: bool)
|
||||
debugLoggingEnabled: false
|
||||
aibridge:
|
||||
# Whether to start an in-memory aibridged instance.
|
||||
# (default: false, type: bool)
|
||||
|
||||
Generated
-3
@@ -14691,9 +14691,6 @@ const docTemplate = `{
|
||||
"properties": {
|
||||
"acquire_batch_size": {
|
||||
"type": "integer"
|
||||
},
|
||||
"debug_logging_enabled": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
Generated
-3
@@ -13204,9 +13204,6 @@
|
||||
"properties": {
|
||||
"acquire_batch_size": {
|
||||
"type": "integer"
|
||||
},
|
||||
"debug_logging_enabled": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -784,7 +784,6 @@ func New(options *Options) *API {
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
|
||||
ProviderAPIKeys: ChatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AlwaysEnableDebugLogs: options.DeploymentValues.AI.Chat.DebugLoggingEnabled.Value(),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
|
||||
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
|
||||
@@ -1183,10 +1182,6 @@ func New(options *Options) *API {
|
||||
r.Put("/system-prompt", api.putChatSystemPrompt)
|
||||
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
|
||||
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
||||
r.Get("/debug-logging", api.getChatDebugLogging)
|
||||
r.Put("/debug-logging", api.putChatDebugLogging)
|
||||
r.Get("/user-debug-logging", api.getUserChatDebugLogging)
|
||||
r.Put("/user-debug-logging", api.putUserChatDebugLogging)
|
||||
r.Get("/user-prompt", api.getUserChatCustomPrompt)
|
||||
r.Put("/user-prompt", api.putUserChatCustomPrompt)
|
||||
r.Get("/user-compaction-thresholds", api.getUserChatCompactionThresholds)
|
||||
@@ -1257,10 +1252,6 @@ func New(options *Options) *API {
|
||||
r.Delete("/", api.deleteChatQueuedMessage)
|
||||
r.Post("/promote", api.promoteChatQueuedMessage)
|
||||
})
|
||||
r.Route("/debug", func(r chi.Router) {
|
||||
r.Get("/runs", api.getChatDebugRuns)
|
||||
r.Get("/runs/{debugRun}", api.getChatDebugRun)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1525,14 +1525,6 @@ func chatMessageParts(m database.ChatMessage) ([]codersdk.ChatMessagePart, error
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
func nullUUIDPtr(v uuid.NullUUID) *uuid.UUID {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
value := v.UUID
|
||||
return &value
|
||||
}
|
||||
|
||||
func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
@@ -1541,22 +1533,6 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
func nullStringPtr(v sql.NullString) *string {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
value := v.String
|
||||
return &value
|
||||
}
|
||||
|
||||
func nullTimePtr(v sql.NullTime) *time.Time {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
value := v.Time
|
||||
return &value
|
||||
}
|
||||
|
||||
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
|
||||
// nil slices and maps to empty values for JSON serialization and
|
||||
// derives RootChatID from the parent chain when not explicitly set.
|
||||
@@ -1643,115 +1619,6 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database
|
||||
return chat
|
||||
}
|
||||
|
||||
func chatDebugAttempts(raw json.RawMessage) []map[string]any {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var attempts []map[string]any
|
||||
if err := json.Unmarshal(raw, &attempts); err != nil {
|
||||
return []map[string]any{{
|
||||
"error": "malformed attempts payload",
|
||||
"raw": string(raw),
|
||||
}}
|
||||
}
|
||||
return attempts
|
||||
}
|
||||
|
||||
// rawJSONObject deserializes a JSON object payload for debug display.
|
||||
// If the payload is malformed, it returns a map with "error" and "raw"
|
||||
// keys preserving the original content for diagnostics. Callers that
|
||||
// consume the result programmatically should check for the "error" key.
|
||||
func rawJSONObject(raw json.RawMessage) map[string]any {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var object map[string]any
|
||||
if err := json.Unmarshal(raw, &object); err != nil {
|
||||
return map[string]any{
|
||||
"error": "malformed debug payload",
|
||||
"raw": string(raw),
|
||||
}
|
||||
}
|
||||
return object
|
||||
}
|
||||
|
||||
func nullRawJSONObject(raw pqtype.NullRawMessage) map[string]any {
|
||||
if !raw.Valid {
|
||||
return nil
|
||||
}
|
||||
return rawJSONObject(raw.RawMessage)
|
||||
}
|
||||
|
||||
// ChatDebugRunSummary converts a database.ChatDebugRun to a
|
||||
// codersdk.ChatDebugRunSummary.
|
||||
func ChatDebugRunSummary(r database.ChatDebugRun) codersdk.ChatDebugRunSummary {
|
||||
return codersdk.ChatDebugRunSummary{
|
||||
ID: r.ID,
|
||||
ChatID: r.ChatID,
|
||||
Kind: codersdk.ChatDebugRunKind(r.Kind),
|
||||
Status: codersdk.ChatDebugStatus(r.Status),
|
||||
Provider: nullStringPtr(r.Provider),
|
||||
Model: nullStringPtr(r.Model),
|
||||
Summary: rawJSONObject(r.Summary),
|
||||
StartedAt: r.StartedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
FinishedAt: nullTimePtr(r.FinishedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ChatDebugStep converts a database.ChatDebugStep to a
|
||||
// codersdk.ChatDebugStep.
|
||||
func ChatDebugStep(s database.ChatDebugStep) codersdk.ChatDebugStep {
|
||||
return codersdk.ChatDebugStep{
|
||||
ID: s.ID,
|
||||
RunID: s.RunID,
|
||||
ChatID: s.ChatID,
|
||||
StepNumber: s.StepNumber,
|
||||
Operation: codersdk.ChatDebugStepOperation(s.Operation),
|
||||
Status: codersdk.ChatDebugStatus(s.Status),
|
||||
HistoryTipMessageID: nullInt64Ptr(s.HistoryTipMessageID),
|
||||
AssistantMessageID: nullInt64Ptr(s.AssistantMessageID),
|
||||
NormalizedRequest: rawJSONObject(s.NormalizedRequest),
|
||||
NormalizedResponse: nullRawJSONObject(s.NormalizedResponse),
|
||||
Usage: nullRawJSONObject(s.Usage),
|
||||
Attempts: chatDebugAttempts(s.Attempts),
|
||||
Error: nullRawJSONObject(s.Error),
|
||||
Metadata: rawJSONObject(s.Metadata),
|
||||
StartedAt: s.StartedAt,
|
||||
UpdatedAt: s.UpdatedAt,
|
||||
FinishedAt: nullTimePtr(s.FinishedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ChatDebugRunDetail converts a database.ChatDebugRun and its steps
|
||||
// to a codersdk.ChatDebugRun.
|
||||
func ChatDebugRunDetail(r database.ChatDebugRun, steps []database.ChatDebugStep) codersdk.ChatDebugRun {
|
||||
sdkSteps := make([]codersdk.ChatDebugStep, 0, len(steps))
|
||||
for _, s := range steps {
|
||||
sdkSteps = append(sdkSteps, ChatDebugStep(s))
|
||||
}
|
||||
return codersdk.ChatDebugRun{
|
||||
ID: r.ID,
|
||||
ChatID: r.ChatID,
|
||||
RootChatID: nullUUIDPtr(r.RootChatID),
|
||||
ParentChatID: nullUUIDPtr(r.ParentChatID),
|
||||
ModelConfigID: nullUUIDPtr(r.ModelConfigID),
|
||||
TriggerMessageID: nullInt64Ptr(r.TriggerMessageID),
|
||||
HistoryTipMessageID: nullInt64Ptr(r.HistoryTipMessageID),
|
||||
Kind: codersdk.ChatDebugRunKind(r.Kind),
|
||||
Status: codersdk.ChatDebugStatus(r.Status),
|
||||
Provider: nullStringPtr(r.Provider),
|
||||
Model: nullStringPtr(r.Model),
|
||||
Summary: rawJSONObject(r.Summary),
|
||||
StartedAt: r.StartedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
FinishedAt: nullTimePtr(r.FinishedAt),
|
||||
Steps: sdkSteps,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatRows converts a slice of database.GetChatsRow (which embeds
|
||||
// Chat plus HasUnread) to codersdk.Chat, looking up diff statuses
|
||||
// from the provided map. When diffStatusesByChatID is non-nil,
|
||||
|
||||
@@ -210,231 +210,6 @@ func TestTemplateVersionParameter_BadDescription(t *testing.T) {
|
||||
req.NotEmpty(sdk.DescriptionPlaintext, "broke the markdown parser with %v", desc)
|
||||
}
|
||||
|
||||
func TestChatDebugRunSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
startedAt := time.Now().UTC().Round(time.Second)
|
||||
finishedAt := startedAt.Add(5 * time.Second)
|
||||
|
||||
run := database.ChatDebugRun{
|
||||
ID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
Kind: "chat_turn",
|
||||
Status: "completed",
|
||||
Provider: sql.NullString{String: "openai", Valid: true},
|
||||
Model: sql.NullString{String: "gpt-4o", Valid: true},
|
||||
Summary: json.RawMessage(`{"step_count":3,"has_error":false}`),
|
||||
StartedAt: startedAt,
|
||||
UpdatedAt: finishedAt,
|
||||
FinishedAt: sql.NullTime{Time: finishedAt, Valid: true},
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugRunSummary(run)
|
||||
|
||||
require.Equal(t, run.ID, sdk.ID)
|
||||
require.Equal(t, run.ChatID, sdk.ChatID)
|
||||
require.Equal(t, codersdk.ChatDebugRunKindChatTurn, sdk.Kind)
|
||||
require.Equal(t, codersdk.ChatDebugStatusCompleted, sdk.Status)
|
||||
require.NotNil(t, sdk.Provider)
|
||||
require.Equal(t, "openai", *sdk.Provider)
|
||||
require.NotNil(t, sdk.Model)
|
||||
require.Equal(t, "gpt-4o", *sdk.Model)
|
||||
require.Equal(t, map[string]any{"step_count": float64(3), "has_error": false}, sdk.Summary)
|
||||
require.Equal(t, startedAt, sdk.StartedAt)
|
||||
require.Equal(t, finishedAt, sdk.UpdatedAt)
|
||||
require.NotNil(t, sdk.FinishedAt)
|
||||
require.Equal(t, finishedAt, *sdk.FinishedAt)
|
||||
}
|
||||
|
||||
func TestChatDebugRunSummary_NullableFieldsNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
run := database.ChatDebugRun{
|
||||
ID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
Kind: "title_generation",
|
||||
Status: "in_progress",
|
||||
Summary: json.RawMessage(`{}`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugRunSummary(run)
|
||||
|
||||
require.Nil(t, sdk.Provider, "NULL Provider should map to nil")
|
||||
require.Nil(t, sdk.Model, "NULL Model should map to nil")
|
||||
require.Nil(t, sdk.FinishedAt, "NULL FinishedAt should map to nil")
|
||||
}
|
||||
|
||||
func TestChatDebugStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
startedAt := time.Now().UTC().Round(time.Second)
|
||||
finishedAt := startedAt.Add(2 * time.Second)
|
||||
attempts := json.RawMessage(`[
|
||||
{
|
||||
"attempt_number": 1,
|
||||
"status": "completed",
|
||||
"raw_request": {"url": "https://example.com"},
|
||||
"raw_response": {"status": "200"},
|
||||
"duration_ms": 123,
|
||||
"started_at": "2026-03-01T10:00:01Z",
|
||||
"finished_at": "2026-03-01T10:00:02Z"
|
||||
}
|
||||
]`)
|
||||
step := database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "completed",
|
||||
NormalizedRequest: json.RawMessage(`{"messages":[]}`),
|
||||
Attempts: attempts,
|
||||
Metadata: json.RawMessage(`{"provider":"openai"}`),
|
||||
StartedAt: startedAt,
|
||||
UpdatedAt: finishedAt,
|
||||
FinishedAt: sql.NullTime{Time: finishedAt, Valid: true},
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugStep(step)
|
||||
|
||||
// Verify all scalar fields are mapped correctly.
|
||||
require.Equal(t, step.ID, sdk.ID)
|
||||
require.Equal(t, step.RunID, sdk.RunID)
|
||||
require.Equal(t, step.ChatID, sdk.ChatID)
|
||||
require.Equal(t, step.StepNumber, sdk.StepNumber)
|
||||
require.Equal(t, codersdk.ChatDebugStepOperationStream, sdk.Operation)
|
||||
require.Equal(t, codersdk.ChatDebugStatusCompleted, sdk.Status)
|
||||
require.Equal(t, startedAt, sdk.StartedAt)
|
||||
require.Equal(t, finishedAt, sdk.UpdatedAt)
|
||||
require.Equal(t, &finishedAt, sdk.FinishedAt)
|
||||
|
||||
// Verify JSON object fields are deserialized.
|
||||
require.NotNil(t, sdk.NormalizedRequest)
|
||||
require.Equal(t, map[string]any{"messages": []any{}}, sdk.NormalizedRequest)
|
||||
require.NotNil(t, sdk.Metadata)
|
||||
require.Equal(t, map[string]any{"provider": "openai"}, sdk.Metadata)
|
||||
|
||||
// Verify nullable fields are nil when the DB row has NULL values.
|
||||
require.Nil(t, sdk.HistoryTipMessageID, "NULL HistoryTipMessageID should map to nil")
|
||||
require.Nil(t, sdk.AssistantMessageID, "NULL AssistantMessageID should map to nil")
|
||||
require.Nil(t, sdk.NormalizedResponse, "NULL NormalizedResponse should map to nil")
|
||||
require.Nil(t, sdk.Usage, "NULL Usage should map to nil")
|
||||
require.Nil(t, sdk.Error, "NULL Error should map to nil")
|
||||
|
||||
// Verify attempts are preserved with all fields.
|
||||
require.Len(t, sdk.Attempts, 1)
|
||||
require.Equal(t, float64(1), sdk.Attempts[0]["attempt_number"])
|
||||
require.Equal(t, "completed", sdk.Attempts[0]["status"])
|
||||
require.Equal(t, float64(123), sdk.Attempts[0]["duration_ms"])
|
||||
require.Equal(t, map[string]any{"url": "https://example.com"}, sdk.Attempts[0]["raw_request"])
|
||||
require.Equal(t, map[string]any{"status": "200"}, sdk.Attempts[0]["raw_response"])
|
||||
}
|
||||
|
||||
func TestChatDebugStep_NullableFieldsPopulated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tipID := int64(42)
|
||||
asstID := int64(99)
|
||||
step := database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 2,
|
||||
Operation: "generate",
|
||||
Status: "completed",
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: tipID, Valid: true},
|
||||
AssistantMessageID: sql.NullInt64{Int64: asstID, Valid: true},
|
||||
NormalizedRequest: json.RawMessage(`{}`),
|
||||
NormalizedResponse: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"text":"hi"}`), Valid: true},
|
||||
Usage: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"tokens":10}`), Valid: true},
|
||||
Error: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"code":"rate_limit"}`), Valid: true},
|
||||
Attempts: json.RawMessage(`[]`),
|
||||
Metadata: json.RawMessage(`{}`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugStep(step)
|
||||
|
||||
require.NotNil(t, sdk.HistoryTipMessageID)
|
||||
require.Equal(t, tipID, *sdk.HistoryTipMessageID)
|
||||
require.NotNil(t, sdk.AssistantMessageID)
|
||||
require.Equal(t, asstID, *sdk.AssistantMessageID)
|
||||
require.NotNil(t, sdk.NormalizedResponse)
|
||||
require.Equal(t, map[string]any{"text": "hi"}, sdk.NormalizedResponse)
|
||||
require.NotNil(t, sdk.Usage)
|
||||
require.Equal(t, map[string]any{"tokens": float64(10)}, sdk.Usage)
|
||||
require.NotNil(t, sdk.Error)
|
||||
require.Equal(t, map[string]any{"code": "rate_limit"}, sdk.Error)
|
||||
}
|
||||
|
||||
func TestChatDebugStep_PreservesMalformedAttempts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
step := database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "completed",
|
||||
NormalizedRequest: json.RawMessage(`{"messages":[]}`),
|
||||
Attempts: json.RawMessage(`{"bad":true}`),
|
||||
Metadata: json.RawMessage(`{"provider":"openai"}`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugStep(step)
|
||||
require.Len(t, sdk.Attempts, 1)
|
||||
require.Equal(t, "malformed attempts payload", sdk.Attempts[0]["error"])
|
||||
require.Equal(t, `{"bad":true}`, sdk.Attempts[0]["raw"])
|
||||
}
|
||||
|
||||
func TestChatDebugRunSummary_PreservesMalformedSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
run := database.ChatDebugRun{
|
||||
ID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
Kind: "chat_turn",
|
||||
Status: "completed",
|
||||
Summary: json.RawMessage(`not-an-object`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugRunSummary(run)
|
||||
require.Equal(t, "malformed debug payload", sdk.Summary["error"])
|
||||
require.Equal(t, "not-an-object", sdk.Summary["raw"])
|
||||
}
|
||||
|
||||
func TestChatDebugStep_PreservesMalformedRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
step := database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "completed",
|
||||
NormalizedRequest: json.RawMessage(`[1,2,3]`),
|
||||
Attempts: json.RawMessage(`[]`),
|
||||
Metadata: json.RawMessage(`"just-a-string"`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugStep(step)
|
||||
require.Equal(t, "malformed debug payload", sdk.NormalizedRequest["error"])
|
||||
require.Equal(t, "[1,2,3]", sdk.NormalizedRequest["raw"])
|
||||
require.Equal(t, "malformed debug payload", sdk.Metadata["error"])
|
||||
require.Equal(t, `"just-a-string"`, sdk.Metadata["raw"])
|
||||
}
|
||||
|
||||
func TestAIBridgeInterception(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1860,28 +1860,6 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
|
||||
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteChatDebugDataAfterMessageID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteChatDebugDataByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -2369,14 +2347,6 @@ func (q *querier) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context,
|
||||
return q.db.FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt)
|
||||
}
|
||||
|
||||
func (q *querier) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
// Background sweep operates across all chats.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return database.FinalizeStaleChatDebugRowsRow{}, err
|
||||
}
|
||||
return q.db.FinalizeStaleChatDebugRows(ctx, updatedBefore)
|
||||
}
|
||||
|
||||
func (q *querier) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
|
||||
_, err := q.GetTemplateVersionByID(ctx, arg.TemplateVersionID)
|
||||
if err != nil {
|
||||
@@ -2585,59 +2555,6 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
|
||||
return q.db.GetChatCostSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
|
||||
// The allow-users flag is a deployment-wide setting read by any
|
||||
// authenticated chat user. We only require that an explicit actor
|
||||
// is present in the context so unauthenticated calls fail closed.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return false, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatDebugLoggingAllowUsers(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
|
||||
run, err := q.db.GetChatDebugRunByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
// Authorize via the owning chat.
|
||||
chat, err := q.db.GetChatByID(ctx, run.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
return run, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDebugRunsByChatID(ctx context.Context, arg database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatDebugRunsByChatID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
|
||||
run, err := q.db.GetChatDebugRunByID(ctx, runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Authorize via the owning chat.
|
||||
chat, err := q.db.GetChatByID(ctx, run.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatDebugStepsByRunID(ctx, runID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
// The desktop-enabled flag is a deployment-wide setting read by any
|
||||
// authenticated chat user and by chatd when deciding whether to expose
|
||||
@@ -4186,17 +4103,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
|
||||
return q.db.GetUserChatCustomPrompt(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return q.db.GetUserChatDebugLoggingEnabled(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -4943,33 +4849,6 @@ func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams)
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
return q.db.InsertChatDebugRun(ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatDebugStep creates a new step in a debug run. The underlying
|
||||
// SQL uses INSERT ... SELECT ... FROM chat_debug_runs to enforce that the
|
||||
// run exists and belongs to the specified chat. If the run_id is invalid
|
||||
// or the chat_id doesn't match, the INSERT produces 0 rows and SQLC
|
||||
// returns sql.ErrNoRows.
|
||||
func (q *querier) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
return q.db.InsertChatDebugStep(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
// Authorize create on chat resource scoped to the owner and org.
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
|
||||
@@ -5968,28 +5847,6 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
return q.db.UpdateChatDebugRun(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
return q.db.UpdateChatDebugStep(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
// The batch heartbeat is a system-level operation filtered by
|
||||
// worker_id. Authorization is enforced by the AsChatd context
|
||||
@@ -7222,13 +7079,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
|
||||
return q.db.UpsertBoundaryUsageStats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatDebugLoggingAllowUsers(ctx, allowUsers)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -7459,17 +7309,6 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return q.db.UpsertTemplateUsageStats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertUserChatDebugLoggingEnabled(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
|
||||
@@ -461,89 +461,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().DeleteChatQueuedMessage(gomock.Any(), args).Return(nil).AnyTimes()
|
||||
check.Args(args).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatDebugDataAfterMessageID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.DeleteChatDebugDataAfterMessageIDParams{ChatID: chat.ID, MessageID: 123}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatDebugDataAfterMessageID(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
}))
|
||||
s.Run("DeleteChatDebugDataByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatDebugDataByChatID(gomock.Any(), chat.ID).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
}))
|
||||
s.Run("FinalizeStaleChatDebugRows", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
updatedBefore := dbtime.Now()
|
||||
row := database.FinalizeStaleChatDebugRowsRow{RunsFinalized: 1, StepsFinalized: 2}
|
||||
dbm.EXPECT().FinalizeStaleChatDebugRows(gomock.Any(), updatedBefore).Return(row, nil).AnyTimes()
|
||||
check.Args(updatedBefore).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(row)
|
||||
}))
|
||||
s.Run("GetChatDebugLoggingAllowUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatDebugLoggingAllowUsers(gomock.Any()).Return(true, nil).AnyTimes()
|
||||
check.Args().Asserts().Returns(true)
|
||||
}))
|
||||
s.Run("GetChatDebugRunByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatDebugRunByID(gomock.Any(), run.ID).Return(run, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(run.ID).Asserts(chat, policy.ActionRead).Returns(run)
|
||||
}))
|
||||
s.Run("GetChatDebugRunsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
runs := []database.ChatDebugRun{{ID: uuid.New(), ChatID: chat.ID}}
|
||||
arg := database.GetChatDebugRunsByChatIDParams{ChatID: chat.ID, LimitVal: 100}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDebugRunsByChatID(gomock.Any(), arg).Return(runs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(runs)
|
||||
}))
|
||||
s.Run("GetChatDebugStepsByRunID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
|
||||
steps := []database.ChatDebugStep{{ID: uuid.New(), RunID: run.ID, ChatID: chat.ID}}
|
||||
dbm.EXPECT().GetChatDebugRunByID(gomock.Any(), run.ID).Return(run, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), run.ID).Return(steps, nil).AnyTimes()
|
||||
check.Args(run.ID).Asserts(chat, policy.ActionRead).Returns(steps)
|
||||
}))
|
||||
s.Run("InsertChatDebugRun", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.InsertChatDebugRunParams{ChatID: chat.ID, Kind: "chat_turn", Status: "in_progress"}
|
||||
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatDebugRun(gomock.Any(), arg).Return(run, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(run)
|
||||
}))
|
||||
s.Run("InsertChatDebugStep", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.InsertChatDebugStepParams{RunID: uuid.New(), ChatID: chat.ID, StepNumber: 1, Operation: "stream", Status: "in_progress"}
|
||||
step := database.ChatDebugStep{ID: uuid.New(), RunID: arg.RunID, ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatDebugStep(gomock.Any(), arg).Return(step, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(step)
|
||||
}))
|
||||
s.Run("UpdateChatDebugRun", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatDebugRunParams{ID: uuid.New(), ChatID: chat.ID}
|
||||
run := database.ChatDebugRun{ID: arg.ID, ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatDebugRun(gomock.Any(), arg).Return(run, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(run)
|
||||
}))
|
||||
s.Run("UpdateChatDebugStep", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatDebugStepParams{ID: uuid.New(), ChatID: chat.ID}
|
||||
step := database.ChatDebugStep{ID: arg.ID, ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatDebugStep(gomock.Any(), arg).Return(step, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(step)
|
||||
}))
|
||||
s.Run("UpsertChatDebugLoggingAllowUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatDebugLoggingAllowUsers(gomock.Any(), true).Return(nil).AnyTimes()
|
||||
check.Args(true).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
@@ -2577,19 +2494,6 @@ func (s *MethodTestSuite) TestUser() {
|
||||
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
s.Run("GetUserChatDebugLoggingEnabled", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), u.ID).Return(true, nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns(true)
|
||||
}))
|
||||
s.Run("UpsertUserChatDebugLoggingEnabled", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpsertUserChatDebugLoggingEnabledParams{UserID: u.ID, DebugLoggingEnabled: true}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertUserChatDebugLoggingEnabled(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal)
|
||||
}))
|
||||
s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"}
|
||||
|
||||
@@ -416,22 +416,6 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteChatDebugDataAfterMessageID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatDebugDataAfterMessageID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatDebugDataAfterMessageID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteChatDebugDataByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatDebugDataByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatDebugDataByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
|
||||
@@ -888,14 +872,6 @@ func (m queryMetricsStore) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.FinalizeStaleChatDebugRows(ctx, updatedBefore)
|
||||
m.queryLatencies.WithLabelValues("FinalizeStaleChatDebugRows").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "FinalizeStaleChatDebugRows").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.FindMatchingPresetID(ctx, arg)
|
||||
@@ -1152,38 +1128,6 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDebugLoggingAllowUsers(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatDebugLoggingAllowUsers").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugLoggingAllowUsers").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDebugRunByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatDebugRunByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugRunByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDebugRunsByChatID(ctx context.Context, chatID database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDebugRunsByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatDebugRunsByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugRunsByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDebugStepsByRunID(ctx, runID)
|
||||
m.queryLatencies.WithLabelValues("GetChatDebugStepsByRunID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugStepsByRunID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
|
||||
@@ -2672,14 +2616,6 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatDebugLoggingEnabled(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatDebugLoggingEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatDebugLoggingEnabled").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
|
||||
@@ -3376,22 +3312,6 @@ func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertCh
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatDebugRun(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatDebugRun").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatDebugRun").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatDebugStep(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatDebugStep").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatDebugStep").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatFile(ctx, arg)
|
||||
@@ -4288,22 +4208,6 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatDebugRun(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatDebugRun").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatDebugRun").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatDebugStep(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatDebugStep").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatDebugStep").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
|
||||
@@ -5144,14 +5048,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatDebugLoggingAllowUsers(ctx, allowUsers)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDebugLoggingAllowUsers").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDebugLoggingAllowUsers").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
@@ -5384,14 +5280,6 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertUserChatDebugLoggingEnabled(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertUserChatDebugLoggingEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatDebugLoggingEnabled").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
|
||||
|
||||
@@ -671,36 +671,6 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteChatDebugDataAfterMessageID mocks base method.
|
||||
func (m *MockStore) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatDebugDataAfterMessageID", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteChatDebugDataAfterMessageID indicates an expected call of DeleteChatDebugDataAfterMessageID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatDebugDataAfterMessageID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatDebugDataAfterMessageID", reflect.TypeOf((*MockStore)(nil).DeleteChatDebugDataAfterMessageID), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteChatDebugDataByChatID mocks base method.
|
||||
func (m *MockStore) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatDebugDataByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteChatDebugDataByChatID indicates an expected call of DeleteChatDebugDataByChatID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatDebugDataByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatDebugDataByChatID", reflect.TypeOf((*MockStore)(nil).DeleteChatDebugDataByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigByID mocks base method.
|
||||
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1517,21 +1487,6 @@ func (mr *MockStoreMockRecorder) FetchVolumesResourceMonitorsUpdatedAfter(ctx, u
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchVolumesResourceMonitorsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).FetchVolumesResourceMonitorsUpdatedAfter), ctx, updatedAt)
|
||||
}
|
||||
|
||||
// FinalizeStaleChatDebugRows mocks base method.
|
||||
func (m *MockStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FinalizeStaleChatDebugRows", ctx, updatedBefore)
|
||||
ret0, _ := ret[0].(database.FinalizeStaleChatDebugRowsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FinalizeStaleChatDebugRows indicates an expected call of FinalizeStaleChatDebugRows.
|
||||
func (mr *MockStoreMockRecorder) FinalizeStaleChatDebugRows(ctx, updatedBefore any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FinalizeStaleChatDebugRows", reflect.TypeOf((*MockStore)(nil).FinalizeStaleChatDebugRows), ctx, updatedBefore)
|
||||
}
|
||||
|
||||
// FindMatchingPresetID mocks base method.
|
||||
func (m *MockStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2117,66 +2072,6 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatDebugLoggingAllowUsers mocks base method.
|
||||
func (m *MockStore) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDebugLoggingAllowUsers", ctx)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDebugLoggingAllowUsers indicates an expected call of GetChatDebugLoggingAllowUsers.
|
||||
func (mr *MockStoreMockRecorder) GetChatDebugLoggingAllowUsers(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugLoggingAllowUsers", reflect.TypeOf((*MockStore)(nil).GetChatDebugLoggingAllowUsers), ctx)
|
||||
}
|
||||
|
||||
// GetChatDebugRunByID mocks base method.
|
||||
func (m *MockStore) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDebugRunByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatDebugRun)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDebugRunByID indicates an expected call of GetChatDebugRunByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatDebugRunByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugRunByID", reflect.TypeOf((*MockStore)(nil).GetChatDebugRunByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatDebugRunsByChatID mocks base method.
|
||||
func (m *MockStore) GetChatDebugRunsByChatID(ctx context.Context, arg database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDebugRunsByChatID", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ChatDebugRun)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDebugRunsByChatID indicates an expected call of GetChatDebugRunsByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatDebugRunsByChatID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugRunsByChatID", reflect.TypeOf((*MockStore)(nil).GetChatDebugRunsByChatID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatDebugStepsByRunID mocks base method.
|
||||
func (m *MockStore) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDebugStepsByRunID", ctx, runID)
|
||||
ret0, _ := ret[0].([]database.ChatDebugStep)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDebugStepsByRunID indicates an expected call of GetChatDebugStepsByRunID.
|
||||
func (mr *MockStoreMockRecorder) GetChatDebugStepsByRunID(ctx, runID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugStepsByRunID", reflect.TypeOf((*MockStore)(nil).GetChatDebugStepsByRunID), ctx, runID)
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4997,21 +4892,6 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatDebugLoggingEnabled mocks base method.
|
||||
func (m *MockStore) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatDebugLoggingEnabled", ctx, userID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatDebugLoggingEnabled indicates an expected call of GetUserChatDebugLoggingEnabled.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatDebugLoggingEnabled(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatDebugLoggingEnabled", reflect.TypeOf((*MockStore)(nil).GetUserChatDebugLoggingEnabled), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys mocks base method.
|
||||
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6331,36 +6211,6 @@ func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatDebugRun mocks base method.
|
||||
func (m *MockStore) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatDebugRun", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDebugRun)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatDebugRun indicates an expected call of InsertChatDebugRun.
|
||||
func (mr *MockStoreMockRecorder) InsertChatDebugRun(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatDebugRun", reflect.TypeOf((*MockStore)(nil).InsertChatDebugRun), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatDebugStep mocks base method.
|
||||
func (m *MockStore) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatDebugStep", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDebugStep)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatDebugStep indicates an expected call of InsertChatDebugStep.
|
||||
func (mr *MockStoreMockRecorder) InsertChatDebugStep(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatDebugStep", reflect.TypeOf((*MockStore)(nil).InsertChatDebugStep), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatFile mocks base method.
|
||||
func (m *MockStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8119,36 +7969,6 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatDebugRun mocks base method.
|
||||
func (m *MockStore) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatDebugRun", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDebugRun)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatDebugRun indicates an expected call of UpdateChatDebugRun.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatDebugRun(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatDebugRun", reflect.TypeOf((*MockStore)(nil).UpdateChatDebugRun), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatDebugStep mocks base method.
|
||||
func (m *MockStore) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatDebugStep", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDebugStep)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatDebugStep indicates an expected call of UpdateChatDebugStep.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatDebugStep(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatDebugStep", reflect.TypeOf((*MockStore)(nil).UpdateChatDebugStep), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeats mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9669,20 +9489,6 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDebugLoggingAllowUsers mocks base method.
|
||||
func (m *MockStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDebugLoggingAllowUsers", ctx, allowUsers)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatDebugLoggingAllowUsers indicates an expected call of UpsertChatDebugLoggingAllowUsers.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDebugLoggingAllowUsers(ctx, allowUsers any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDebugLoggingAllowUsers", reflect.TypeOf((*MockStore)(nil).UpsertChatDebugLoggingAllowUsers), ctx, allowUsers)
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -10100,20 +9906,6 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx)
|
||||
}
|
||||
|
||||
// UpsertUserChatDebugLoggingEnabled mocks base method.
|
||||
func (m *MockStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertUserChatDebugLoggingEnabled", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertUserChatDebugLoggingEnabled indicates an expected call of UpsertUserChatDebugLoggingEnabled.
|
||||
func (mr *MockStoreMockRecorder) UpsertUserChatDebugLoggingEnabled(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatDebugLoggingEnabled", reflect.TypeOf((*MockStore)(nil).UpsertUserChatDebugLoggingEnabled), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
-67
@@ -1255,44 +1255,6 @@ COMMENT ON COLUMN boundary_usage_stats.window_start IS 'Start of the time window
|
||||
|
||||
COMMENT ON COLUMN boundary_usage_stats.updated_at IS 'Timestamp of the last update to this row.';
|
||||
|
||||
CREATE TABLE chat_debug_runs (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
root_chat_id uuid,
|
||||
parent_chat_id uuid,
|
||||
model_config_id uuid,
|
||||
trigger_message_id bigint,
|
||||
history_tip_message_id bigint,
|
||||
kind text NOT NULL,
|
||||
status text NOT NULL,
|
||||
provider text,
|
||||
model text,
|
||||
summary jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
started_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
finished_at timestamp with time zone
|
||||
);
|
||||
|
||||
CREATE TABLE chat_debug_steps (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
run_id uuid NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
step_number integer NOT NULL,
|
||||
operation text NOT NULL,
|
||||
status text NOT NULL,
|
||||
history_tip_message_id bigint,
|
||||
assistant_message_id bigint,
|
||||
normalized_request jsonb NOT NULL,
|
||||
normalized_response jsonb,
|
||||
usage jsonb,
|
||||
attempts jsonb DEFAULT '[]'::jsonb NOT NULL,
|
||||
error jsonb,
|
||||
metadata jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
started_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
finished_at timestamp with time zone
|
||||
);
|
||||
|
||||
CREATE TABLE chat_diff_statuses (
|
||||
chat_id uuid NOT NULL,
|
||||
url text,
|
||||
@@ -3397,12 +3359,6 @@ ALTER TABLE ONLY audit_logs
|
||||
ALTER TABLE ONLY boundary_usage_stats
|
||||
ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
|
||||
ALTER TABLE ONLY chat_debug_runs
|
||||
ADD CONSTRAINT chat_debug_runs_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_debug_steps
|
||||
ADD CONSTRAINT chat_debug_steps_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
|
||||
@@ -3797,20 +3753,6 @@ CREATE INDEX idx_audit_log_user_id ON audit_logs USING btree (user_id);
|
||||
|
||||
CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
|
||||
|
||||
CREATE INDEX idx_chat_debug_runs_chat_started ON chat_debug_runs USING btree (chat_id, started_at DESC);
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs USING btree (id, chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_debug_runs_stale ON chat_debug_runs USING btree (updated_at) WHERE (finished_at IS NULL);
|
||||
|
||||
CREATE INDEX idx_chat_debug_steps_chat_assistant_msg ON chat_debug_steps USING btree (chat_id, assistant_message_id) WHERE (assistant_message_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chat_debug_steps_chat_tip ON chat_debug_steps USING btree (chat_id, history_tip_message_id);
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps USING btree (run_id, step_number);
|
||||
|
||||
CREATE INDEX idx_chat_debug_steps_stale ON chat_debug_steps USING btree (updated_at) WHERE (finished_at IS NULL);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
|
||||
|
||||
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links USING btree (chat_id);
|
||||
@@ -4114,12 +4056,6 @@ ALTER TABLE ONLY aibridge_interceptions
|
||||
ALTER TABLE ONLY api_keys
|
||||
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_debug_runs
|
||||
ADD CONSTRAINT chat_debug_runs_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_debug_steps
|
||||
ADD CONSTRAINT chat_debug_steps_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -4192,9 +4128,6 @@ ALTER TABLE ONLY connection_logs
|
||||
ALTER TABLE ONLY crypto_keys
|
||||
ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY chat_debug_steps
|
||||
ADD CONSTRAINT fk_chat_debug_steps_run_chat FOREIGN KEY (run_id, chat_id) REFERENCES chat_debug_runs(id, chat_id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY oauth2_provider_app_tokens
|
||||
ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ const (
|
||||
ForeignKeyAiSeatStateUserID ForeignKeyConstraint = "ai_seat_state_user_id_fkey" // ALTER TABLE ONLY ai_seat_state ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDebugRunsChatID ForeignKeyConstraint = "chat_debug_runs_chat_id_fkey" // ALTER TABLE ONLY chat_debug_runs ADD CONSTRAINT chat_debug_runs_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDebugStepsChatID ForeignKeyConstraint = "chat_debug_steps_chat_id_fkey" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT chat_debug_steps_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksChatID ForeignKeyConstraint = "chat_file_links_chat_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
@@ -35,7 +33,6 @@ const (
|
||||
ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
|
||||
ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyFkChatDebugStepsRunChat ForeignKeyConstraint = "fk_chat_debug_steps_run_chat" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT fk_chat_debug_steps_run_chat FOREIGN KEY (run_id, chat_id) REFERENCES chat_debug_runs(id, chat_id) ON DELETE CASCADE;
|
||||
ForeignKeyFkOauth2ProviderAppTokensUserID ForeignKeyConstraint = "fk_oauth2_provider_app_tokens_user_id" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyGitAuthLinksOauthAccessTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyGitAuthLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
DROP TABLE IF EXISTS chat_debug_steps;
|
||||
DROP TABLE IF EXISTS chat_debug_runs;
|
||||
@@ -1,59 +0,0 @@
|
||||
CREATE TABLE chat_debug_runs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
-- root_chat_id and parent_chat_id are intentionally NOT
|
||||
-- foreign-keyed to chats(id). They are snapshot values that
|
||||
-- record the subchat hierarchy at run time. The referenced
|
||||
-- chat may be archived or deleted independently, and we want
|
||||
-- to preserve the historical lineage in debug rows rather
|
||||
-- than cascade-delete them.
|
||||
root_chat_id UUID,
|
||||
parent_chat_id UUID,
|
||||
model_config_id UUID,
|
||||
trigger_message_id BIGINT,
|
||||
history_tip_message_id BIGINT,
|
||||
kind TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
provider TEXT,
|
||||
model TEXT,
|
||||
summary JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
finished_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs(id, chat_id);
|
||||
CREATE INDEX idx_chat_debug_runs_chat_started ON chat_debug_runs(chat_id, started_at DESC);
|
||||
|
||||
CREATE TABLE chat_debug_steps (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
run_id UUID NOT NULL,
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
step_number INT NOT NULL,
|
||||
operation TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
history_tip_message_id BIGINT,
|
||||
assistant_message_id BIGINT,
|
||||
normalized_request JSONB NOT NULL,
|
||||
normalized_response JSONB,
|
||||
usage JSONB,
|
||||
attempts JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
error JSONB,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
finished_at TIMESTAMPTZ,
|
||||
CONSTRAINT fk_chat_debug_steps_run_chat
|
||||
FOREIGN KEY (run_id, chat_id)
|
||||
REFERENCES chat_debug_runs(id, chat_id)
|
||||
ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps(run_id, step_number);
|
||||
CREATE INDEX idx_chat_debug_steps_chat_tip ON chat_debug_steps(chat_id, history_tip_message_id);
|
||||
-- Supports DeleteChatDebugDataAfterMessageID assistant_message_id branch.
|
||||
CREATE INDEX idx_chat_debug_steps_chat_assistant_msg ON chat_debug_steps(chat_id, assistant_message_id) WHERE assistant_message_id IS NOT NULL;
|
||||
|
||||
-- Supports FinalizeStaleChatDebugRows worker query.
|
||||
CREATE INDEX idx_chat_debug_runs_stale ON chat_debug_runs(updated_at) WHERE finished_at IS NULL;
|
||||
CREATE INDEX idx_chat_debug_steps_stale ON chat_debug_steps(updated_at) WHERE finished_at IS NULL;
|
||||
-65
@@ -1,65 +0,0 @@
|
||||
INSERT INTO chat_debug_runs (
|
||||
id,
|
||||
chat_id,
|
||||
model_config_id,
|
||||
history_tip_message_id,
|
||||
kind,
|
||||
status,
|
||||
provider,
|
||||
model,
|
||||
summary,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
) VALUES (
|
||||
'c98518f8-9fb3-458b-a642-57552af1db63',
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
'9af5f8d5-6a57-4505-8a69-3d6c787b95fd',
|
||||
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
|
||||
'chat_turn',
|
||||
'completed',
|
||||
'openai',
|
||||
'gpt-5.2',
|
||||
'{"step_count":1,"has_error":false}'::jsonb,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:01+00',
|
||||
'2024-01-01 00:00:01+00'
|
||||
);
|
||||
|
||||
INSERT INTO chat_debug_steps (
|
||||
id,
|
||||
run_id,
|
||||
chat_id,
|
||||
step_number,
|
||||
operation,
|
||||
status,
|
||||
history_tip_message_id,
|
||||
assistant_message_id,
|
||||
normalized_request,
|
||||
normalized_response,
|
||||
usage,
|
||||
attempts,
|
||||
error,
|
||||
metadata,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
) VALUES (
|
||||
'59471c60-7851-4fa6-bf05-e21dd939721f',
|
||||
'c98518f8-9fb3-458b-a642-57552af1db63',
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
1,
|
||||
'stream',
|
||||
'completed',
|
||||
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
|
||||
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
|
||||
'{"messages":[]}'::jsonb,
|
||||
'{"finish_reason":"stop"}'::jsonb,
|
||||
'{"input_tokens":1,"output_tokens":1}'::jsonb,
|
||||
'[]'::jsonb,
|
||||
NULL,
|
||||
'{"provider":"openai"}'::jsonb,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:01+00',
|
||||
'2024-01-01 00:00:01+00'
|
||||
);
|
||||
@@ -4248,44 +4248,6 @@ type Chat struct {
|
||||
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
|
||||
}
|
||||
|
||||
type ChatDebugRun struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
|
||||
TriggerMessageID sql.NullInt64 `db:"trigger_message_id" json:"trigger_message_id"`
|
||||
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
|
||||
Kind string `db:"kind" json:"kind"`
|
||||
Status string `db:"status" json:"status"`
|
||||
Provider sql.NullString `db:"provider" json:"provider"`
|
||||
Model sql.NullString `db:"model" json:"model"`
|
||||
Summary json.RawMessage `db:"summary" json:"summary"`
|
||||
StartedAt time.Time `db:"started_at" json:"started_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
|
||||
}
|
||||
|
||||
type ChatDebugStep struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
RunID uuid.UUID `db:"run_id" json:"run_id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
StepNumber int32 `db:"step_number" json:"step_number"`
|
||||
Operation string `db:"operation" json:"operation"`
|
||||
Status string `db:"status" json:"status"`
|
||||
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
|
||||
AssistantMessageID sql.NullInt64 `db:"assistant_message_id" json:"assistant_message_id"`
|
||||
NormalizedRequest json.RawMessage `db:"normalized_request" json:"normalized_request"`
|
||||
NormalizedResponse pqtype.NullRawMessage `db:"normalized_response" json:"normalized_response"`
|
||||
Usage pqtype.NullRawMessage `db:"usage" json:"usage"`
|
||||
Attempts json.RawMessage `db:"attempts" json:"attempts"`
|
||||
Error pqtype.NullRawMessage `db:"error" json:"error"`
|
||||
Metadata json.RawMessage `db:"metadata" json:"metadata"`
|
||||
StartedAt time.Time `db:"started_at" json:"started_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
Url sql.NullString `db:"url" json:"url"`
|
||||
|
||||
@@ -102,8 +102,6 @@ type sqlcQuerier interface {
|
||||
// be recreated.
|
||||
DeleteAllWebpushSubscriptions(ctx context.Context) error
|
||||
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteChatDebugDataAfterMessageID(ctx context.Context, arg DeleteChatDebugDataAfterMessageIDParams) (int64, error)
|
||||
DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error)
|
||||
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
|
||||
@@ -196,16 +194,6 @@ type sqlcQuerier interface {
|
||||
FetchNewMessageMetadata(ctx context.Context, arg FetchNewMessageMetadataParams) (FetchNewMessageMetadataRow, error)
|
||||
FetchVolumesResourceMonitorsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceAgentVolumeResourceMonitor, error)
|
||||
FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentVolumeResourceMonitor, error)
|
||||
// Marks orphaned in-progress rows as interrupted so they do not stay
|
||||
// in a non-terminal state forever. The NOT IN list must match the
|
||||
// terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
|
||||
//
|
||||
// The steps CTE also catches steps whose parent run was just finalized
|
||||
// (via run_id IN), because PostgreSQL data-modifying CTEs share the
|
||||
// same snapshot and cannot see each other's row updates. Without this,
|
||||
// a step with a recent updated_at would survive its run's finalization
|
||||
// and remain in 'in_progress' state permanently.
|
||||
FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (FinalizeStaleChatDebugRowsRow, error)
|
||||
// FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters.
|
||||
// It returns the preset ID if a match is found, or NULL if no match is found.
|
||||
// The query finds presets where all preset parameters are present in the provided parameters,
|
||||
@@ -270,15 +258,6 @@ type sqlcQuerier interface {
|
||||
// Aggregate cost summary for a single user within a date range.
|
||||
// Only counts assistant-role messages.
|
||||
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
|
||||
// GetChatDebugLoggingAllowUsers returns the runtime admin setting that
|
||||
// allows users to opt into chat debug logging when the deployment does
|
||||
// not already force debug logging on globally.
|
||||
GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error)
|
||||
GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (ChatDebugRun, error)
|
||||
// Returns the most recent debug runs for a chat, ordered newest-first.
|
||||
// Callers must supply an explicit limit to avoid unbounded result sets.
|
||||
GetChatDebugRunsByChatID(ctx context.Context, arg GetChatDebugRunsByChatIDParams) ([]ChatDebugRun, error)
|
||||
GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]ChatDebugStep, error)
|
||||
GetChatDesktopEnabled(ctx context.Context) (bool, error)
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
@@ -640,7 +619,6 @@ type sqlcQuerier interface {
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
|
||||
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error)
|
||||
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
|
||||
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
|
||||
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
@@ -760,8 +738,6 @@ type sqlcQuerier interface {
|
||||
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
|
||||
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
|
||||
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
|
||||
InsertChatDebugRun(ctx context.Context, arg InsertChatDebugRunParams) (ChatDebugRun, error)
|
||||
InsertChatDebugStep(ctx context.Context, arg InsertChatDebugStepParams) (ChatDebugStep, error)
|
||||
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
|
||||
InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error)
|
||||
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
|
||||
@@ -940,16 +916,6 @@ type sqlcQuerier interface {
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Uses COALESCE so that passing NULL from Go means "keep the
|
||||
// existing value." This is intentional: debug rows follow a
|
||||
// write-once-finalize pattern where fields are set at creation
|
||||
// or finalization and never cleared back to NULL.
|
||||
UpdateChatDebugRun(ctx context.Context, arg UpdateChatDebugRunParams) (ChatDebugRun, error)
|
||||
// Uses COALESCE so that passing NULL from Go means "keep the
|
||||
// existing value." This is intentional: debug rows follow a
|
||||
// write-once-finalize pattern where fields are set at creation
|
||||
// or finalization and never cleared back to NULL.
|
||||
UpdateChatDebugStep(ctx context.Context, arg UpdateChatDebugStepParams) (ChatDebugStep, error)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
@@ -1078,9 +1044,6 @@ type sqlcQuerier interface {
|
||||
// cumulative values for unique counts (accurate period totals). Request counts
|
||||
// are always deltas, accumulated in DB. Returns true if insert, false if update.
|
||||
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
|
||||
// UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
|
||||
// allows users to opt into chat debug logging.
|
||||
UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error
|
||||
UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
@@ -1118,7 +1081,6 @@ type sqlcQuerier interface {
|
||||
// used to store the data, and the minutes are summed for each user and template
|
||||
// combination. The result is stored in the template_usage_stats table.
|
||||
UpsertTemplateUsageStats(ctx context.Context) error
|
||||
UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error
|
||||
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
|
||||
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
|
||||
|
||||
@@ -11218,951 +11218,6 @@ func TestChatLabels(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteChatDebugDataAfterMessageIDIncludesTriggeredRuns(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
dbgen.Organization(t, store, database.Organization{})
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "chat-debug-rollback-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const cutoff int64 = 50
|
||||
|
||||
affectedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: cutoff + 10, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 5, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
Provider: sql.NullString{String: providerName, Valid: true},
|
||||
Model: sql.NullString{String: modelName, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: affectedRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "in_progress",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
affectedByStepHistoryTipRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: cutoff - 1, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 1, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
Provider: sql.NullString{String: providerName, Valid: true},
|
||||
Model: sql.NullString{String: modelName, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: affectedByStepHistoryTipRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "interrupted",
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: cutoff + 7, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// affectedByStepAssistantMsgRun: run-level fields are at/below
|
||||
// the cutoff, but its step has assistant_message_id above the
|
||||
// cutoff. This exercises the step.assistant_message_id > cutoff
|
||||
// branch of the UNION independently of history_tip_message_id.
|
||||
affectedByStepAssistantMsgRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: cutoff - 2, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: cutoff - 2, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
Provider: sql.NullString{String: providerName, Valid: true},
|
||||
Model: sql.NullString{String: modelName, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: affectedByStepAssistantMsgRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "completed",
|
||||
AssistantMessageID: sql.NullInt64{Int64: cutoff + 3, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
unaffectedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
Provider: sql.NullString{String: providerName, Valid: true},
|
||||
Model: sql.NullString{String: modelName, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
unaffectedStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: unaffectedRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "in_progress",
|
||||
AssistantMessageID: sql.NullInt64{Int64: cutoff, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
|
||||
ChatID: chat.ID,
|
||||
MessageID: cutoff,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, deletedRows)
|
||||
|
||||
_, err = store.GetChatDebugRunByID(ctx, affectedRun.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
affectedSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, affectedSteps)
|
||||
|
||||
_, err = store.GetChatDebugRunByID(ctx, affectedByStepHistoryTipRun.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
affectedByStepHistoryTipSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedByStepHistoryTipRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, affectedByStepHistoryTipSteps)
|
||||
|
||||
// Verify the run caught by step-level assistant_message_id is
|
||||
// also deleted. This would survive if the
|
||||
// step.assistant_message_id > @message_id clause were removed.
|
||||
_, err = store.GetChatDebugRunByID(ctx, affectedByStepAssistantMsgRun.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
affectedByStepAssistantMsgSteps, err := store.GetChatDebugStepsByRunID(ctx, affectedByStepAssistantMsgRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, affectedByStepAssistantMsgSteps)
|
||||
|
||||
remainingRuns, err := store.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
LimitVal: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, remainingRuns, 1)
|
||||
require.Equal(t, unaffectedRun.ID, remainingRuns[0].ID)
|
||||
|
||||
remainingRun, err := store.GetChatDebugRunByID(ctx, unaffectedRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, unaffectedRun.ID, remainingRun.ID)
|
||||
|
||||
remainingSteps, err := store.GetChatDebugStepsByRunID(ctx, unaffectedRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, remainingSteps, 1)
|
||||
require.Equal(t, unaffectedStep.ID, remainingSteps[0].ID)
|
||||
}
|
||||
|
||||
func TestFinalizeStaleChatDebugRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
dbgen.Organization(t, store, database.Organization{})
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-finalize-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "chat-finalize-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// staleTime is well before the threshold so rows stamped with it
|
||||
// are considered stale. The threshold sits between staleTime and
|
||||
// NOW(), letting us create rows that are stale-by-age and rows
|
||||
// that are fresh-by-age in the same test.
|
||||
staleTime := time.Now().Add(-2 * time.Hour)
|
||||
staleThreshold := time.Now().Add(-1 * time.Hour)
|
||||
|
||||
// --- staleRun: in_progress run with no finished_at --- should be
|
||||
// finalized.
|
||||
staleRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
Provider: sql.NullString{String: providerName, Valid: true},
|
||||
Model: sql.NullString{String: modelName, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// staleStep: in_progress step attached to staleRun.
|
||||
staleStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: staleRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "in_progress",
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- orphanStep: in_progress step whose run is already completed ---
|
||||
// its own updated_at is old, so it should be finalized directly.
|
||||
completedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: 2, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: 2, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "completed",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark the run as completed with a finished_at timestamp.
|
||||
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
||||
ID: completedRun.ID,
|
||||
ChatID: completedRun.ChatID,
|
||||
Status: sql.NullString{String: "completed", Valid: true},
|
||||
FinishedAt: sql.NullTime{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
orphanStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: completedRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "in_progress",
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- cascadeRun: stale in_progress run with a FRESH step ---
|
||||
// The run's updated_at is old so the run itself is finalized by
|
||||
// age. The step's updated_at is recent (default NOW()), so it is
|
||||
// NOT caught by the age predicate. It must be finalized solely
|
||||
// via the cascade CTE clause: run_id IN (SELECT id FROM
|
||||
// finalized_runs). Removing that clause would leave this step
|
||||
// stuck in 'in_progress'.
|
||||
cascadeRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: 10, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: 10, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
Provider: sql.NullString{String: providerName, Valid: true},
|
||||
Model: sql.NullString{String: modelName, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// cascadeStep: recent updated_at (default NOW()), so only the
|
||||
// cascade path can finalize it.
|
||||
cascadeStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: cascadeRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "in_progress",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- alreadyDone: completed run/step --- should NOT be touched.
|
||||
doneRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: 3, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: 3, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "completed",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
||||
ID: doneRun.ID,
|
||||
ChatID: doneRun.ChatID,
|
||||
Status: sql.NullString{String: "completed", Valid: true},
|
||||
FinishedAt: sql.NullTime{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
doneStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: doneRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "completed",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
||||
ID: doneStep.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: sql.NullString{String: "completed", Valid: true},
|
||||
FinishedAt: sql.NullTime{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- errorRun: error run/step --- should NOT be touched either,
|
||||
// exercising the 'error' branch of the NOT IN clause.
|
||||
errorRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: 4, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: 4, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "error",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
||||
ID: errorRun.ID,
|
||||
ChatID: errorRun.ChatID,
|
||||
Status: sql.NullString{String: "error", Valid: true},
|
||||
FinishedAt: sql.NullTime{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
errorStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: errorRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "error",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
||||
ID: errorStep.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: sql.NullString{String: "error", Valid: true},
|
||||
FinishedAt: sql.NullTime{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- Execute the finalization sweep. ---
|
||||
result, err := store.FinalizeStaleChatDebugRows(ctx, staleThreshold)
|
||||
require.NoError(t, err)
|
||||
|
||||
// staleRun + cascadeRun were finalized; completedRun and doneRun
|
||||
// were already terminal so only 2 runs are expected.
|
||||
assert.EqualValues(t, 2, result.RunsFinalized,
|
||||
"stale + cascade in_progress runs should be finalized")
|
||||
// staleStep (age), orphanStep (age), cascadeStep (cascade only)
|
||||
// should all be finalized.
|
||||
assert.EqualValues(t, 3, result.StepsFinalized,
|
||||
"stale step + orphan step + cascade step should all be finalized")
|
||||
|
||||
// Verify the stale run was set to interrupted.
|
||||
updatedStaleRun, err := store.GetChatDebugRunByID(ctx, staleRun.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "interrupted", updatedStaleRun.Status)
|
||||
assert.True(t, updatedStaleRun.FinishedAt.Valid,
|
||||
"finalized run should have a finished_at timestamp")
|
||||
|
||||
// Verify the stale step was set to interrupted.
|
||||
staleSteps, err := store.GetChatDebugStepsByRunID(ctx, staleRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, staleSteps, 1)
|
||||
assert.Equal(t, staleStep.ID, staleSteps[0].ID)
|
||||
assert.Equal(t, "interrupted", staleSteps[0].Status)
|
||||
assert.True(t, staleSteps[0].FinishedAt.Valid,
|
||||
"finalized step should have a finished_at timestamp")
|
||||
|
||||
// Verify the orphan step was also finalized.
|
||||
orphanSteps, err := store.GetChatDebugStepsByRunID(ctx, completedRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, orphanSteps, 1)
|
||||
assert.Equal(t, orphanStep.ID, orphanSteps[0].ID)
|
||||
assert.Equal(t, "interrupted", orphanSteps[0].Status)
|
||||
|
||||
// Verify the cascade run was finalized.
|
||||
updatedCascadeRun, err := store.GetChatDebugRunByID(ctx, cascadeRun.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "interrupted", updatedCascadeRun.Status)
|
||||
assert.True(t, updatedCascadeRun.FinishedAt.Valid,
|
||||
"cascade run should have a finished_at timestamp")
|
||||
|
||||
// Verify the cascade step was finalized despite its recent
|
||||
// updated_at, proving the cascade CTE clause is required.
|
||||
cascadeSteps, err := store.GetChatDebugStepsByRunID(ctx, cascadeRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cascadeSteps, 1)
|
||||
assert.Equal(t, cascadeStep.ID, cascadeSteps[0].ID)
|
||||
assert.Equal(t, "interrupted", cascadeSteps[0].Status,
|
||||
"fresh step should be finalized via cascade, not age")
|
||||
assert.True(t, cascadeSteps[0].FinishedAt.Valid,
|
||||
"cascade step should have a finished_at timestamp")
|
||||
|
||||
// Verify the completed run/step are untouched.
|
||||
unchangedRun, err := store.GetChatDebugRunByID(ctx, doneRun.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "completed", unchangedRun.Status)
|
||||
|
||||
doneSteps, err := store.GetChatDebugStepsByRunID(ctx, doneRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, doneSteps, 1)
|
||||
assert.Equal(t, "completed", doneSteps[0].Status)
|
||||
|
||||
// Verify the error run/step are untouched.
|
||||
unchangedErrorRun, err := store.GetChatDebugRunByID(ctx, errorRun.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "error", unchangedErrorRun.Status)
|
||||
|
||||
errorSteps, err := store.GetChatDebugStepsByRunID(ctx, errorRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, errorSteps, 1)
|
||||
assert.Equal(t, "error", errorSteps[0].Status)
|
||||
|
||||
// A second sweep should be a no-op.
|
||||
result2, err := store.FinalizeStaleChatDebugRows(ctx, staleThreshold)
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, 0, result2.RunsFinalized,
|
||||
"second sweep should find nothing to finalize")
|
||||
assert.EqualValues(t, 0, result2.StepsFinalized,
|
||||
"second sweep should find nothing to finalize")
|
||||
}
|
||||
|
||||
func TestChatDebugSQLGuards(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
dbgen.Organization(t, store, database.Organization{})
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-guards-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chatA, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "chat-guard-A-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chatB, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "chat-guard-B-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
runA, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chatA.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: 1, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
Provider: sql.NullString{String: providerName, Valid: true},
|
||||
Model: sql.NullString{String: modelName, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
stepA, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: runA.ID,
|
||||
ChatID: chatA.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "in_progress",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// InsertChatDebugStep: valid run_id but chat_id belongs to a
|
||||
// different chat. The INSERT...SELECT guard should produce zero
|
||||
// rows, surfacing as sql.ErrNoRows.
|
||||
t.Run("InsertChatDebugStep_MismatchedChatID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: runA.ID,
|
||||
ChatID: chatB.ID, // wrong chat
|
||||
StepNumber: 2,
|
||||
Operation: "stream",
|
||||
Status: "in_progress",
|
||||
})
|
||||
require.ErrorIs(t, err, sql.ErrNoRows,
|
||||
"InsertChatDebugStep should fail when chat_id does not match the run's chat_id")
|
||||
})
|
||||
|
||||
// UpdateChatDebugRun: valid run ID but wrong chat_id.
|
||||
t.Run("UpdateChatDebugRun_MismatchedChatID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err := store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
||||
ID: runA.ID,
|
||||
ChatID: chatB.ID, // wrong chat
|
||||
Status: sql.NullString{String: "completed", Valid: true},
|
||||
FinishedAt: sql.NullTime{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.ErrorIs(t, err, sql.ErrNoRows,
|
||||
"UpdateChatDebugRun should fail when chat_id does not match")
|
||||
})
|
||||
|
||||
// UpdateChatDebugStep: valid step ID but wrong chat_id.
|
||||
t.Run("UpdateChatDebugStep_MismatchedChatID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
_, err := store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
||||
ID: stepA.ID,
|
||||
ChatID: chatB.ID, // wrong chat
|
||||
Status: sql.NullString{String: "completed", Valid: true},
|
||||
FinishedAt: sql.NullTime{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.ErrorIs(t, err, sql.ErrNoRows,
|
||||
"UpdateChatDebugStep should fail when chat_id does not match")
|
||||
})
|
||||
}
|
||||
|
||||
// TestChatDebugRunCOALESCEPreservation verifies that the COALESCE
|
||||
// pattern in UpdateChatDebugRun preserves every field that was not
|
||||
// explicitly supplied in the update. If COALESCE were removed from
|
||||
// any column, the corresponding field would silently null out.
|
||||
func TestChatDebugRunCOALESCEPreservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
dbgen.Organization(t, store, database.Organization{})
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-coalesce-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "chat-debug-coalesce-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rootChatID := uuid.New()
|
||||
parentChatID := uuid.New()
|
||||
|
||||
// Insert a fully-populated run so every nullable field has a value.
|
||||
original, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true},
|
||||
ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
TriggerMessageID: sql.NullInt64{Int64: 42, Valid: true},
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: 41, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
Provider: sql.NullString{String: providerName, Valid: true},
|
||||
Model: sql.NullString{String: modelName, Valid: true},
|
||||
Summary: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"key":"val"}`), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update only Status and FinishedAt. Every other nullable param
|
||||
// is left as its Go zero value (Valid: false → SQL NULL), which
|
||||
// the COALESCE pattern should interpret as "keep existing."
|
||||
now := time.Now()
|
||||
updated, err := store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{
|
||||
ID: original.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: sql.NullString{String: "completed", Valid: true},
|
||||
FinishedAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Status and FinishedAt should be updated.
|
||||
require.Equal(t, "completed", updated.Status)
|
||||
require.True(t, updated.FinishedAt.Valid)
|
||||
|
||||
// UpdatedAt should advance (set to NOW() unconditionally).
|
||||
require.True(t, updated.UpdatedAt.After(original.UpdatedAt) ||
|
||||
updated.UpdatedAt.Equal(original.UpdatedAt))
|
||||
|
||||
// Every field not in the update call must be preserved exactly.
|
||||
require.Equal(t, original.RootChatID, updated.RootChatID,
|
||||
"RootChatID should survive a partial update")
|
||||
require.Equal(t, original.ParentChatID, updated.ParentChatID,
|
||||
"ParentChatID should survive a partial update")
|
||||
require.Equal(t, original.ModelConfigID, updated.ModelConfigID,
|
||||
"ModelConfigID should survive a partial update")
|
||||
require.Equal(t, original.TriggerMessageID, updated.TriggerMessageID,
|
||||
"TriggerMessageID should survive a partial update")
|
||||
require.Equal(t, original.HistoryTipMessageID, updated.HistoryTipMessageID,
|
||||
"HistoryTipMessageID should survive a partial update")
|
||||
require.Equal(t, original.Provider, updated.Provider,
|
||||
"Provider should survive a partial update")
|
||||
require.Equal(t, original.Model, updated.Model,
|
||||
"Model should survive a partial update")
|
||||
require.JSONEq(t, string(original.Summary), string(updated.Summary),
|
||||
"Summary should survive a partial update")
|
||||
require.Equal(t, original.Kind, updated.Kind,
|
||||
"Kind should survive a partial update")
|
||||
require.Equal(t, original.StartedAt.UTC(), updated.StartedAt.UTC(),
|
||||
"StartedAt should survive a partial update")
|
||||
}
|
||||
|
||||
// TestChatDebugStepCOALESCEPreservation verifies that the COALESCE
|
||||
// pattern in UpdateChatDebugStep preserves every field that was not
|
||||
// explicitly supplied in the update. If COALESCE were removed from
|
||||
// any column, the corresponding field would silently null out.
|
||||
func TestChatDebugStepCOALESCEPreservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
dbgen.Organization(t, store, database.Organization{})
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
providerName := "openai"
|
||||
modelName := "debug-step-coalesce-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "chat-step-coalesce-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
run, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert a fully-populated step so every nullable field has a value.
|
||||
original, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "llm_call",
|
||||
Status: "in_progress",
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: 10, Valid: true},
|
||||
AssistantMessageID: sql.NullInt64{Int64: 11, Valid: true},
|
||||
NormalizedRequest: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"prompt":"hello"}`), Valid: true},
|
||||
NormalizedResponse: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"text":"world"}`), Valid: true},
|
||||
Usage: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"tokens":42}`), Valid: true},
|
||||
Attempts: pqtype.NullRawMessage{RawMessage: json.RawMessage(`[{"n":1}]`), Valid: true},
|
||||
Error: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"code":"transient"}`), Valid: true},
|
||||
Metadata: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"trace_id":"abc"}`), Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update only Status and FinishedAt. Every other nullable param
|
||||
// is left as its Go zero value (Valid: false -> SQL NULL), which
|
||||
// the COALESCE pattern should interpret as "keep existing."
|
||||
now := time.Now()
|
||||
updated, err := store.UpdateChatDebugStep(ctx, database.UpdateChatDebugStepParams{
|
||||
ID: original.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: sql.NullString{String: "completed", Valid: true},
|
||||
FinishedAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Status and FinishedAt should be updated.
|
||||
require.Equal(t, "completed", updated.Status)
|
||||
require.True(t, updated.FinishedAt.Valid)
|
||||
|
||||
// UpdatedAt should advance (set to NOW() unconditionally).
|
||||
require.True(t, updated.UpdatedAt.After(original.UpdatedAt) ||
|
||||
updated.UpdatedAt.Equal(original.UpdatedAt))
|
||||
|
||||
// Every field not in the update call must be preserved exactly.
|
||||
require.Equal(t, original.HistoryTipMessageID, updated.HistoryTipMessageID,
|
||||
"HistoryTipMessageID should survive a partial update")
|
||||
require.Equal(t, original.AssistantMessageID, updated.AssistantMessageID,
|
||||
"AssistantMessageID should survive a partial update")
|
||||
require.JSONEq(t, string(original.NormalizedRequest), string(updated.NormalizedRequest),
|
||||
"NormalizedRequest should survive a partial update")
|
||||
require.JSONEq(t, string(original.NormalizedResponse.RawMessage), string(updated.NormalizedResponse.RawMessage),
|
||||
"NormalizedResponse should survive a partial update")
|
||||
require.JSONEq(t, string(original.Usage.RawMessage), string(updated.Usage.RawMessage),
|
||||
"Usage should survive a partial update")
|
||||
require.JSONEq(t, string(original.Attempts), string(updated.Attempts),
|
||||
"Attempts should survive a partial update")
|
||||
require.JSONEq(t, string(original.Error.RawMessage), string(updated.Error.RawMessage),
|
||||
"Error should survive a partial update")
|
||||
require.JSONEq(t, string(original.Metadata), string(updated.Metadata),
|
||||
"Metadata should survive a partial update")
|
||||
require.Equal(t, original.Operation, updated.Operation,
|
||||
"Operation should survive a partial update")
|
||||
require.Equal(t, original.StepNumber, updated.StepNumber,
|
||||
"StepNumber should survive a partial update")
|
||||
require.Equal(t, original.StartedAt.UTC(), updated.StartedAt.UTC(),
|
||||
"StartedAt should survive a partial update")
|
||||
}
|
||||
|
||||
// TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive verifies
|
||||
// that runs whose message ID columns are all NULL are never matched
|
||||
// by DeleteChatDebugDataAfterMessageID. SQL's three-valued logic
|
||||
// means NULL > N evaluates to NULL (not TRUE), so these rows must
|
||||
// survive. Without this test a future change could break the
|
||||
// invariant with no test failure.
|
||||
func TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
dbgen.Organization(t, store, database.Organization{})
|
||||
user := dbgen.User(t, store, database.User{})
|
||||
|
||||
providerName := "openai"
|
||||
modelName := "debug-model-null-msg-" + uuid.NewString()
|
||||
|
||||
_, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: providerName,
|
||||
DisplayName: "Debug Provider",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: modelName,
|
||||
DisplayName: "Debug Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := store.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "chat-debug-null-msg-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert a run with all message ID columns left as NULL (Valid: false).
|
||||
nullMsgRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true},
|
||||
Kind: "chat_turn",
|
||||
Status: "in_progress",
|
||||
Provider: sql.NullString{String: providerName, Valid: true},
|
||||
Model: sql.NullString{String: modelName, Valid: true},
|
||||
// TriggerMessageID and HistoryTipMessageID intentionally
|
||||
// omitted (zero-value → SQL NULL).
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attach a step with NULL message IDs too.
|
||||
nullMsgStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: nullMsgRun.ID,
|
||||
ChatID: chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "in_progress",
|
||||
// HistoryTipMessageID and AssistantMessageID intentionally
|
||||
// omitted (zero-value → SQL NULL).
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete with an arbitrary cutoff. The run and its step should
|
||||
// survive because NULL > cutoff evaluates to NULL, not TRUE.
|
||||
deletedRows, err := store.DeleteChatDebugDataAfterMessageID(ctx, database.DeleteChatDebugDataAfterMessageIDParams{
|
||||
ChatID: chat.ID,
|
||||
MessageID: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, deletedRows, "rows with NULL message IDs must not be deleted")
|
||||
|
||||
// Verify run still exists.
|
||||
remaining, err := store.GetChatDebugRunByID(ctx, nullMsgRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, nullMsgRun.ID, remaining.ID)
|
||||
|
||||
// Verify step still exists.
|
||||
remainingSteps, err := store.GetChatDebugStepsByRunID(ctx, nullMsgRun.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, remainingSteps, 1)
|
||||
require.Equal(t, nullMsgStep.ID, remainingSteps[0].ID)
|
||||
}
|
||||
|
||||
func TestChatHasUnread(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -2900,583 +2900,6 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou
|
||||
return new_period, err
|
||||
}
|
||||
|
||||
const deleteChatDebugDataAfterMessageID = `-- name: DeleteChatDebugDataAfterMessageID :execrows
|
||||
WITH affected_runs AS (
|
||||
SELECT DISTINCT run.id
|
||||
FROM chat_debug_runs run
|
||||
WHERE run.chat_id = $1::uuid
|
||||
AND (
|
||||
run.history_tip_message_id > $2::bigint
|
||||
OR run.trigger_message_id > $2::bigint
|
||||
)
|
||||
|
||||
UNION
|
||||
|
||||
SELECT DISTINCT step.run_id AS id
|
||||
FROM chat_debug_steps step
|
||||
WHERE step.chat_id = $1::uuid
|
||||
AND (
|
||||
step.assistant_message_id > $2::bigint
|
||||
OR step.history_tip_message_id > $2::bigint
|
||||
)
|
||||
)
|
||||
DELETE FROM chat_debug_runs
|
||||
WHERE chat_id = $1::uuid
|
||||
AND id IN (SELECT id FROM affected_runs)
|
||||
`
|
||||
|
||||
type DeleteChatDebugDataAfterMessageIDParams struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
MessageID int64 `db:"message_id" json:"message_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
|
||||
result, err := q.db.ExecContext(ctx, deleteChatDebugDataAfterMessageID, arg.ChatID, arg.MessageID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
const deleteChatDebugDataByChatID = `-- name: DeleteChatDebugDataByChatID :execrows
|
||||
DELETE FROM chat_debug_runs
|
||||
WHERE chat_id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
|
||||
result, err := q.db.ExecContext(ctx, deleteChatDebugDataByChatID, chatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
const finalizeStaleChatDebugRows = `-- name: FinalizeStaleChatDebugRows :one
|
||||
WITH finalized_runs AS (
|
||||
UPDATE chat_debug_runs
|
||||
SET
|
||||
status = 'interrupted',
|
||||
updated_at = NOW(),
|
||||
finished_at = NOW()
|
||||
WHERE updated_at < $1::timestamptz
|
||||
AND finished_at IS NULL
|
||||
AND status NOT IN ('completed', 'error', 'interrupted')
|
||||
RETURNING id
|
||||
), finalized_steps AS (
|
||||
UPDATE chat_debug_steps
|
||||
SET
|
||||
status = 'interrupted',
|
||||
updated_at = NOW(),
|
||||
finished_at = NOW()
|
||||
WHERE (
|
||||
updated_at < $1::timestamptz
|
||||
OR run_id IN (SELECT id FROM finalized_runs)
|
||||
)
|
||||
AND finished_at IS NULL
|
||||
AND status NOT IN ('completed', 'error', 'interrupted')
|
||||
RETURNING 1
|
||||
)
|
||||
SELECT
|
||||
(SELECT COUNT(*) FROM finalized_runs)::bigint AS runs_finalized,
|
||||
(SELECT COUNT(*) FROM finalized_steps)::bigint AS steps_finalized
|
||||
`
|
||||
|
||||
type FinalizeStaleChatDebugRowsRow struct {
|
||||
RunsFinalized int64 `db:"runs_finalized" json:"runs_finalized"`
|
||||
StepsFinalized int64 `db:"steps_finalized" json:"steps_finalized"`
|
||||
}
|
||||
|
||||
// Marks orphaned in-progress rows as interrupted so they do not stay
|
||||
// in a non-terminal state forever. The NOT IN list must match the
|
||||
// terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
|
||||
//
|
||||
// The steps CTE also catches steps whose parent run was just finalized
|
||||
// (via run_id IN), because PostgreSQL data-modifying CTEs share the
|
||||
// same snapshot and cannot see each other's row updates. Without this,
|
||||
// a step with a recent updated_at would survive its run's finalization
|
||||
// and remain in 'in_progress' state permanently.
|
||||
func (q *sqlQuerier) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (FinalizeStaleChatDebugRowsRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, finalizeStaleChatDebugRows, updatedBefore)
|
||||
var i FinalizeStaleChatDebugRowsRow
|
||||
err := row.Scan(&i.RunsFinalized, &i.StepsFinalized)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatDebugRunByID = `-- name: GetChatDebugRunByID :one
|
||||
SELECT id, chat_id, root_chat_id, parent_chat_id, model_config_id, trigger_message_id, history_tip_message_id, kind, status, provider, model, summary, started_at, updated_at, finished_at
|
||||
FROM chat_debug_runs
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (ChatDebugRun, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatDebugRunByID, id)
|
||||
var i ChatDebugRun
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ChatID,
|
||||
&i.RootChatID,
|
||||
&i.ParentChatID,
|
||||
&i.ModelConfigID,
|
||||
&i.TriggerMessageID,
|
||||
&i.HistoryTipMessageID,
|
||||
&i.Kind,
|
||||
&i.Status,
|
||||
&i.Provider,
|
||||
&i.Model,
|
||||
&i.Summary,
|
||||
&i.StartedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatDebugRunsByChatID = `-- name: GetChatDebugRunsByChatID :many
|
||||
SELECT id, chat_id, root_chat_id, parent_chat_id, model_config_id, trigger_message_id, history_tip_message_id, kind, status, provider, model, summary, started_at, updated_at, finished_at
|
||||
FROM chat_debug_runs
|
||||
WHERE chat_id = $1::uuid
|
||||
ORDER BY started_at DESC, id DESC
|
||||
LIMIT $2::int
|
||||
`
|
||||
|
||||
type GetChatDebugRunsByChatIDParams struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
LimitVal int32 `db:"limit_val" json:"limit_val"`
|
||||
}
|
||||
|
||||
// Returns the most recent debug runs for a chat, ordered newest-first.
|
||||
// Callers must supply an explicit limit to avoid unbounded result sets.
|
||||
func (q *sqlQuerier) GetChatDebugRunsByChatID(ctx context.Context, arg GetChatDebugRunsByChatIDParams) ([]ChatDebugRun, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatDebugRunsByChatID, arg.ChatID, arg.LimitVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ChatDebugRun
|
||||
for rows.Next() {
|
||||
var i ChatDebugRun
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ChatID,
|
||||
&i.RootChatID,
|
||||
&i.ParentChatID,
|
||||
&i.ModelConfigID,
|
||||
&i.TriggerMessageID,
|
||||
&i.HistoryTipMessageID,
|
||||
&i.Kind,
|
||||
&i.Status,
|
||||
&i.Provider,
|
||||
&i.Model,
|
||||
&i.Summary,
|
||||
&i.StartedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getChatDebugStepsByRunID = `-- name: GetChatDebugStepsByRunID :many
|
||||
SELECT id, run_id, chat_id, step_number, operation, status, history_tip_message_id, assistant_message_id, normalized_request, normalized_response, usage, attempts, error, metadata, started_at, updated_at, finished_at
|
||||
FROM chat_debug_steps
|
||||
WHERE run_id = $1::uuid
|
||||
ORDER BY step_number ASC, started_at ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]ChatDebugStep, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatDebugStepsByRunID, runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ChatDebugStep
|
||||
for rows.Next() {
|
||||
var i ChatDebugStep
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.RunID,
|
||||
&i.ChatID,
|
||||
&i.StepNumber,
|
||||
&i.Operation,
|
||||
&i.Status,
|
||||
&i.HistoryTipMessageID,
|
||||
&i.AssistantMessageID,
|
||||
&i.NormalizedRequest,
|
||||
&i.NormalizedResponse,
|
||||
&i.Usage,
|
||||
&i.Attempts,
|
||||
&i.Error,
|
||||
&i.Metadata,
|
||||
&i.StartedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertChatDebugRun = `-- name: InsertChatDebugRun :one
|
||||
INSERT INTO chat_debug_runs (
|
||||
chat_id,
|
||||
root_chat_id,
|
||||
parent_chat_id,
|
||||
model_config_id,
|
||||
trigger_message_id,
|
||||
history_tip_message_id,
|
||||
kind,
|
||||
status,
|
||||
provider,
|
||||
model,
|
||||
summary,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
)
|
||||
VALUES (
|
||||
$1::uuid,
|
||||
$2::uuid,
|
||||
$3::uuid,
|
||||
$4::uuid,
|
||||
$5::bigint,
|
||||
$6::bigint,
|
||||
$7::text,
|
||||
$8::text,
|
||||
$9::text,
|
||||
$10::text,
|
||||
COALESCE($11::jsonb, '{}'::jsonb),
|
||||
COALESCE($12::timestamptz, NOW()),
|
||||
COALESCE($13::timestamptz, NOW()),
|
||||
$14::timestamptz
|
||||
)
|
||||
RETURNING id, chat_id, root_chat_id, parent_chat_id, model_config_id, trigger_message_id, history_tip_message_id, kind, status, provider, model, summary, started_at, updated_at, finished_at
|
||||
`
|
||||
|
||||
type InsertChatDebugRunParams struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
|
||||
TriggerMessageID sql.NullInt64 `db:"trigger_message_id" json:"trigger_message_id"`
|
||||
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
|
||||
Kind string `db:"kind" json:"kind"`
|
||||
Status string `db:"status" json:"status"`
|
||||
Provider sql.NullString `db:"provider" json:"provider"`
|
||||
Model sql.NullString `db:"model" json:"model"`
|
||||
Summary pqtype.NullRawMessage `db:"summary" json:"summary"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"`
|
||||
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatDebugRun(ctx context.Context, arg InsertChatDebugRunParams) (ChatDebugRun, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertChatDebugRun,
|
||||
arg.ChatID,
|
||||
arg.RootChatID,
|
||||
arg.ParentChatID,
|
||||
arg.ModelConfigID,
|
||||
arg.TriggerMessageID,
|
||||
arg.HistoryTipMessageID,
|
||||
arg.Kind,
|
||||
arg.Status,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Summary,
|
||||
arg.StartedAt,
|
||||
arg.UpdatedAt,
|
||||
arg.FinishedAt,
|
||||
)
|
||||
var i ChatDebugRun
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ChatID,
|
||||
&i.RootChatID,
|
||||
&i.ParentChatID,
|
||||
&i.ModelConfigID,
|
||||
&i.TriggerMessageID,
|
||||
&i.HistoryTipMessageID,
|
||||
&i.Kind,
|
||||
&i.Status,
|
||||
&i.Provider,
|
||||
&i.Model,
|
||||
&i.Summary,
|
||||
&i.StartedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertChatDebugStep = `-- name: InsertChatDebugStep :one
|
||||
INSERT INTO chat_debug_steps (
|
||||
run_id,
|
||||
chat_id,
|
||||
step_number,
|
||||
operation,
|
||||
status,
|
||||
history_tip_message_id,
|
||||
assistant_message_id,
|
||||
normalized_request,
|
||||
normalized_response,
|
||||
usage,
|
||||
attempts,
|
||||
error,
|
||||
metadata,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
)
|
||||
SELECT
|
||||
$1::uuid,
|
||||
run.chat_id,
|
||||
$2::int,
|
||||
$3::text,
|
||||
$4::text,
|
||||
$5::bigint,
|
||||
$6::bigint,
|
||||
COALESCE($7::jsonb, '{}'::jsonb),
|
||||
$8::jsonb,
|
||||
$9::jsonb,
|
||||
COALESCE($10::jsonb, '[]'::jsonb),
|
||||
$11::jsonb,
|
||||
COALESCE($12::jsonb, '{}'::jsonb),
|
||||
COALESCE($13::timestamptz, NOW()),
|
||||
COALESCE($14::timestamptz, NOW()),
|
||||
$15::timestamptz
|
||||
FROM chat_debug_runs run
|
||||
WHERE run.id = $1::uuid
|
||||
AND run.chat_id = $16::uuid
|
||||
RETURNING id, run_id, chat_id, step_number, operation, status, history_tip_message_id, assistant_message_id, normalized_request, normalized_response, usage, attempts, error, metadata, started_at, updated_at, finished_at
|
||||
`
|
||||
|
||||
type InsertChatDebugStepParams struct {
|
||||
RunID uuid.UUID `db:"run_id" json:"run_id"`
|
||||
StepNumber int32 `db:"step_number" json:"step_number"`
|
||||
Operation string `db:"operation" json:"operation"`
|
||||
Status string `db:"status" json:"status"`
|
||||
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
|
||||
AssistantMessageID sql.NullInt64 `db:"assistant_message_id" json:"assistant_message_id"`
|
||||
NormalizedRequest pqtype.NullRawMessage `db:"normalized_request" json:"normalized_request"`
|
||||
NormalizedResponse pqtype.NullRawMessage `db:"normalized_response" json:"normalized_response"`
|
||||
Usage pqtype.NullRawMessage `db:"usage" json:"usage"`
|
||||
Attempts pqtype.NullRawMessage `db:"attempts" json:"attempts"`
|
||||
Error pqtype.NullRawMessage `db:"error" json:"error"`
|
||||
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"`
|
||||
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatDebugStep(ctx context.Context, arg InsertChatDebugStepParams) (ChatDebugStep, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertChatDebugStep,
|
||||
arg.RunID,
|
||||
arg.StepNumber,
|
||||
arg.Operation,
|
||||
arg.Status,
|
||||
arg.HistoryTipMessageID,
|
||||
arg.AssistantMessageID,
|
||||
arg.NormalizedRequest,
|
||||
arg.NormalizedResponse,
|
||||
arg.Usage,
|
||||
arg.Attempts,
|
||||
arg.Error,
|
||||
arg.Metadata,
|
||||
arg.StartedAt,
|
||||
arg.UpdatedAt,
|
||||
arg.FinishedAt,
|
||||
arg.ChatID,
|
||||
)
|
||||
var i ChatDebugStep
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.RunID,
|
||||
&i.ChatID,
|
||||
&i.StepNumber,
|
||||
&i.Operation,
|
||||
&i.Status,
|
||||
&i.HistoryTipMessageID,
|
||||
&i.AssistantMessageID,
|
||||
&i.NormalizedRequest,
|
||||
&i.NormalizedResponse,
|
||||
&i.Usage,
|
||||
&i.Attempts,
|
||||
&i.Error,
|
||||
&i.Metadata,
|
||||
&i.StartedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatDebugRun = `-- name: UpdateChatDebugRun :one
|
||||
UPDATE chat_debug_runs
|
||||
SET
|
||||
root_chat_id = COALESCE($1::uuid, root_chat_id),
|
||||
parent_chat_id = COALESCE($2::uuid, parent_chat_id),
|
||||
model_config_id = COALESCE($3::uuid, model_config_id),
|
||||
trigger_message_id = COALESCE($4::bigint, trigger_message_id),
|
||||
history_tip_message_id = COALESCE($5::bigint, history_tip_message_id),
|
||||
status = COALESCE($6::text, status),
|
||||
provider = COALESCE($7::text, provider),
|
||||
model = COALESCE($8::text, model),
|
||||
summary = COALESCE($9::jsonb, summary),
|
||||
finished_at = COALESCE($10::timestamptz, finished_at),
|
||||
updated_at = NOW()
|
||||
WHERE id = $11::uuid
|
||||
AND chat_id = $12::uuid
|
||||
RETURNING id, chat_id, root_chat_id, parent_chat_id, model_config_id, trigger_message_id, history_tip_message_id, kind, status, provider, model, summary, started_at, updated_at, finished_at
|
||||
`
|
||||
|
||||
type UpdateChatDebugRunParams struct {
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
|
||||
TriggerMessageID sql.NullInt64 `db:"trigger_message_id" json:"trigger_message_id"`
|
||||
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
|
||||
Status sql.NullString `db:"status" json:"status"`
|
||||
Provider sql.NullString `db:"provider" json:"provider"`
|
||||
Model sql.NullString `db:"model" json:"model"`
|
||||
Summary pqtype.NullRawMessage `db:"summary" json:"summary"`
|
||||
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
}
|
||||
|
||||
// Uses COALESCE so that passing NULL from Go means "keep the
|
||||
// existing value." This is intentional: debug rows follow a
|
||||
// write-once-finalize pattern where fields are set at creation
|
||||
// or finalization and never cleared back to NULL.
|
||||
func (q *sqlQuerier) UpdateChatDebugRun(ctx context.Context, arg UpdateChatDebugRunParams) (ChatDebugRun, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateChatDebugRun,
|
||||
arg.RootChatID,
|
||||
arg.ParentChatID,
|
||||
arg.ModelConfigID,
|
||||
arg.TriggerMessageID,
|
||||
arg.HistoryTipMessageID,
|
||||
arg.Status,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Summary,
|
||||
arg.FinishedAt,
|
||||
arg.ID,
|
||||
arg.ChatID,
|
||||
)
|
||||
var i ChatDebugRun
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ChatID,
|
||||
&i.RootChatID,
|
||||
&i.ParentChatID,
|
||||
&i.ModelConfigID,
|
||||
&i.TriggerMessageID,
|
||||
&i.HistoryTipMessageID,
|
||||
&i.Kind,
|
||||
&i.Status,
|
||||
&i.Provider,
|
||||
&i.Model,
|
||||
&i.Summary,
|
||||
&i.StartedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatDebugStep = `-- name: UpdateChatDebugStep :one
|
||||
UPDATE chat_debug_steps
|
||||
SET
|
||||
status = COALESCE($1::text, status),
|
||||
history_tip_message_id = COALESCE($2::bigint, history_tip_message_id),
|
||||
assistant_message_id = COALESCE($3::bigint, assistant_message_id),
|
||||
normalized_request = COALESCE($4::jsonb, normalized_request),
|
||||
normalized_response = COALESCE($5::jsonb, normalized_response),
|
||||
usage = COALESCE($6::jsonb, usage),
|
||||
attempts = COALESCE($7::jsonb, attempts),
|
||||
error = COALESCE($8::jsonb, error),
|
||||
metadata = COALESCE($9::jsonb, metadata),
|
||||
finished_at = COALESCE($10::timestamptz, finished_at),
|
||||
updated_at = NOW()
|
||||
WHERE id = $11::uuid
|
||||
AND chat_id = $12::uuid
|
||||
RETURNING id, run_id, chat_id, step_number, operation, status, history_tip_message_id, assistant_message_id, normalized_request, normalized_response, usage, attempts, error, metadata, started_at, updated_at, finished_at
|
||||
`
|
||||
|
||||
type UpdateChatDebugStepParams struct {
|
||||
Status sql.NullString `db:"status" json:"status"`
|
||||
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
|
||||
AssistantMessageID sql.NullInt64 `db:"assistant_message_id" json:"assistant_message_id"`
|
||||
NormalizedRequest pqtype.NullRawMessage `db:"normalized_request" json:"normalized_request"`
|
||||
NormalizedResponse pqtype.NullRawMessage `db:"normalized_response" json:"normalized_response"`
|
||||
Usage pqtype.NullRawMessage `db:"usage" json:"usage"`
|
||||
Attempts pqtype.NullRawMessage `db:"attempts" json:"attempts"`
|
||||
Error pqtype.NullRawMessage `db:"error" json:"error"`
|
||||
Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"`
|
||||
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
}
|
||||
|
||||
// Uses COALESCE so that passing NULL from Go means "keep the
|
||||
// existing value." This is intentional: debug rows follow a
|
||||
// write-once-finalize pattern where fields are set at creation
|
||||
// or finalization and never cleared back to NULL.
|
||||
func (q *sqlQuerier) UpdateChatDebugStep(ctx context.Context, arg UpdateChatDebugStepParams) (ChatDebugStep, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateChatDebugStep,
|
||||
arg.Status,
|
||||
arg.HistoryTipMessageID,
|
||||
arg.AssistantMessageID,
|
||||
arg.NormalizedRequest,
|
||||
arg.NormalizedResponse,
|
||||
arg.Usage,
|
||||
arg.Attempts,
|
||||
arg.Error,
|
||||
arg.Metadata,
|
||||
arg.FinishedAt,
|
||||
arg.ID,
|
||||
arg.ChatID,
|
||||
)
|
||||
var i ChatDebugStep
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.RunID,
|
||||
&i.ChatID,
|
||||
&i.StepNumber,
|
||||
&i.Operation,
|
||||
&i.Status,
|
||||
&i.HistoryTipMessageID,
|
||||
&i.AssistantMessageID,
|
||||
&i.NormalizedRequest,
|
||||
&i.NormalizedResponse,
|
||||
&i.Usage,
|
||||
&i.Attempts,
|
||||
&i.Error,
|
||||
&i.Metadata,
|
||||
&i.StartedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FinishedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteOldChatFiles = `-- name: DeleteOldChatFiles :execrows
|
||||
WITH kept_file_ids AS (
|
||||
-- NOTE: This uses updated_at as a proxy for archive time
|
||||
@@ -19722,21 +19145,6 @@ func (q *sqlQuerier) GetApplicationName(ctx context.Context) (string, error) {
|
||||
return value, err
|
||||
}
|
||||
|
||||
const getChatDebugLoggingAllowUsers = `-- name: GetChatDebugLoggingAllowUsers :one
|
||||
SELECT
|
||||
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_chat_debug_logging_allow_users'), false) :: boolean AS allow_users
|
||||
`
|
||||
|
||||
// GetChatDebugLoggingAllowUsers returns the runtime admin setting that
|
||||
// allows users to opt into chat debug logging when the deployment does
|
||||
// not already force debug logging on globally.
|
||||
func (q *sqlQuerier) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatDebugLoggingAllowUsers)
|
||||
var allow_users bool
|
||||
err := row.Scan(&allow_users)
|
||||
return allow_users, err
|
||||
}
|
||||
|
||||
const getChatDesktopEnabled = `-- name: GetChatDesktopEnabled :one
|
||||
SELECT
|
||||
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop
|
||||
@@ -20048,30 +19456,6 @@ func (q *sqlQuerier) UpsertApplicationName(ctx context.Context, value string) er
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertChatDebugLoggingAllowUsers = `-- name: UpsertChatDebugLoggingAllowUsers :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES (
|
||||
'agents_chat_debug_logging_allow_users',
|
||||
CASE
|
||||
WHEN $1::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = CASE
|
||||
WHEN $1::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
WHERE site_configs.key = 'agents_chat_debug_logging_allow_users'
|
||||
`
|
||||
|
||||
// UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
|
||||
// allows users to opt into chat debug logging.
|
||||
func (q *sqlQuerier) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
|
||||
_, err := q.db.ExecContext(ctx, upsertChatDebugLoggingAllowUsers, allowUsers)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertChatDesktopEnabled = `-- name: UpsertChatDesktopEnabled :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES (
|
||||
@@ -24326,21 +23710,6 @@ func (q *sqlQuerier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UU
|
||||
return chat_custom_prompt, err
|
||||
}
|
||||
|
||||
const getUserChatDebugLoggingEnabled = `-- name: GetUserChatDebugLoggingEnabled :one
|
||||
SELECT
|
||||
value = 'true' AS debug_logging_enabled
|
||||
FROM user_configs
|
||||
WHERE user_id = $1
|
||||
AND key = 'chat_debug_logging_enabled'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserChatDebugLoggingEnabled, userID)
|
||||
var debug_logging_enabled bool
|
||||
err := row.Scan(&debug_logging_enabled)
|
||||
return debug_logging_enabled, err
|
||||
}
|
||||
|
||||
const getUserCount = `-- name: GetUserCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
@@ -25335,35 +24704,6 @@ func (q *sqlQuerier) UpdateUserThemePreference(ctx context.Context, arg UpdateUs
|
||||
return i, err
|
||||
}
|
||||
|
||||
const upsertUserChatDebugLoggingEnabled = `-- name: UpsertUserChatDebugLoggingEnabled :exec
|
||||
INSERT INTO user_configs (user_id, key, value)
|
||||
VALUES (
|
||||
$1,
|
||||
'chat_debug_logging_enabled',
|
||||
CASE
|
||||
WHEN $2::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
)
|
||||
ON CONFLICT ON CONSTRAINT user_configs_pkey
|
||||
DO UPDATE SET value = CASE
|
||||
WHEN $2::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
WHERE user_configs.user_id = $1
|
||||
AND user_configs.key = 'chat_debug_logging_enabled'
|
||||
`
|
||||
|
||||
type UpsertUserChatDebugLoggingEnabledParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
DebugLoggingEnabled bool `db:"debug_logging_enabled" json:"debug_logging_enabled"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error {
|
||||
_, err := q.db.ExecContext(ctx, upsertUserChatDebugLoggingEnabled, arg.UserID, arg.DebugLoggingEnabled)
|
||||
return err
|
||||
}
|
||||
|
||||
const validateUserIDs = `-- name: ValidateUserIDs :one
|
||||
WITH input AS (
|
||||
SELECT
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
-- name: InsertChatDebugRun :one
|
||||
INSERT INTO chat_debug_runs (
|
||||
chat_id,
|
||||
root_chat_id,
|
||||
parent_chat_id,
|
||||
model_config_id,
|
||||
trigger_message_id,
|
||||
history_tip_message_id,
|
||||
kind,
|
||||
status,
|
||||
provider,
|
||||
model,
|
||||
summary,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
)
|
||||
VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('root_chat_id')::uuid,
|
||||
sqlc.narg('parent_chat_id')::uuid,
|
||||
sqlc.narg('model_config_id')::uuid,
|
||||
sqlc.narg('trigger_message_id')::bigint,
|
||||
sqlc.narg('history_tip_message_id')::bigint,
|
||||
@kind::text,
|
||||
@status::text,
|
||||
sqlc.narg('provider')::text,
|
||||
sqlc.narg('model')::text,
|
||||
COALESCE(sqlc.narg('summary')::jsonb, '{}'::jsonb),
|
||||
COALESCE(sqlc.narg('started_at')::timestamptz, NOW()),
|
||||
COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()),
|
||||
sqlc.narg('finished_at')::timestamptz
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatDebugRun :one
|
||||
-- Uses COALESCE so that passing NULL from Go means "keep the
|
||||
-- existing value." This is intentional: debug rows follow a
|
||||
-- write-once-finalize pattern where fields are set at creation
|
||||
-- or finalization and never cleared back to NULL.
|
||||
UPDATE chat_debug_runs
|
||||
SET
|
||||
root_chat_id = COALESCE(sqlc.narg('root_chat_id')::uuid, root_chat_id),
|
||||
parent_chat_id = COALESCE(sqlc.narg('parent_chat_id')::uuid, parent_chat_id),
|
||||
model_config_id = COALESCE(sqlc.narg('model_config_id')::uuid, model_config_id),
|
||||
trigger_message_id = COALESCE(sqlc.narg('trigger_message_id')::bigint, trigger_message_id),
|
||||
history_tip_message_id = COALESCE(sqlc.narg('history_tip_message_id')::bigint, history_tip_message_id),
|
||||
status = COALESCE(sqlc.narg('status')::text, status),
|
||||
provider = COALESCE(sqlc.narg('provider')::text, provider),
|
||||
model = COALESCE(sqlc.narg('model')::text, model),
|
||||
summary = COALESCE(sqlc.narg('summary')::jsonb, summary),
|
||||
finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at),
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid
|
||||
AND chat_id = @chat_id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: InsertChatDebugStep :one
|
||||
INSERT INTO chat_debug_steps (
|
||||
run_id,
|
||||
chat_id,
|
||||
step_number,
|
||||
operation,
|
||||
status,
|
||||
history_tip_message_id,
|
||||
assistant_message_id,
|
||||
normalized_request,
|
||||
normalized_response,
|
||||
usage,
|
||||
attempts,
|
||||
error,
|
||||
metadata,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
)
|
||||
SELECT
|
||||
@run_id::uuid,
|
||||
run.chat_id,
|
||||
@step_number::int,
|
||||
@operation::text,
|
||||
@status::text,
|
||||
sqlc.narg('history_tip_message_id')::bigint,
|
||||
sqlc.narg('assistant_message_id')::bigint,
|
||||
COALESCE(sqlc.narg('normalized_request')::jsonb, '{}'::jsonb),
|
||||
sqlc.narg('normalized_response')::jsonb,
|
||||
sqlc.narg('usage')::jsonb,
|
||||
COALESCE(sqlc.narg('attempts')::jsonb, '[]'::jsonb),
|
||||
sqlc.narg('error')::jsonb,
|
||||
COALESCE(sqlc.narg('metadata')::jsonb, '{}'::jsonb),
|
||||
COALESCE(sqlc.narg('started_at')::timestamptz, NOW()),
|
||||
COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()),
|
||||
sqlc.narg('finished_at')::timestamptz
|
||||
FROM chat_debug_runs run
|
||||
WHERE run.id = @run_id::uuid
|
||||
AND run.chat_id = @chat_id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatDebugStep :one
|
||||
-- Uses COALESCE so that passing NULL from Go means "keep the
|
||||
-- existing value." This is intentional: debug rows follow a
|
||||
-- write-once-finalize pattern where fields are set at creation
|
||||
-- or finalization and never cleared back to NULL.
|
||||
UPDATE chat_debug_steps
|
||||
SET
|
||||
status = COALESCE(sqlc.narg('status')::text, status),
|
||||
history_tip_message_id = COALESCE(sqlc.narg('history_tip_message_id')::bigint, history_tip_message_id),
|
||||
assistant_message_id = COALESCE(sqlc.narg('assistant_message_id')::bigint, assistant_message_id),
|
||||
normalized_request = COALESCE(sqlc.narg('normalized_request')::jsonb, normalized_request),
|
||||
normalized_response = COALESCE(sqlc.narg('normalized_response')::jsonb, normalized_response),
|
||||
usage = COALESCE(sqlc.narg('usage')::jsonb, usage),
|
||||
attempts = COALESCE(sqlc.narg('attempts')::jsonb, attempts),
|
||||
error = COALESCE(sqlc.narg('error')::jsonb, error),
|
||||
metadata = COALESCE(sqlc.narg('metadata')::jsonb, metadata),
|
||||
finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at),
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid
|
||||
AND chat_id = @chat_id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetChatDebugRunsByChatID :many
|
||||
-- Returns the most recent debug runs for a chat, ordered newest-first.
|
||||
-- Callers must supply an explicit limit to avoid unbounded result sets.
|
||||
SELECT *
|
||||
FROM chat_debug_runs
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
ORDER BY started_at DESC, id DESC
|
||||
LIMIT @limit_val::int;
|
||||
|
||||
-- name: GetChatDebugRunByID :one
|
||||
SELECT *
|
||||
FROM chat_debug_runs
|
||||
WHERE id = @id::uuid;
|
||||
|
||||
-- name: GetChatDebugStepsByRunID :many
|
||||
SELECT *
|
||||
FROM chat_debug_steps
|
||||
WHERE run_id = @run_id::uuid
|
||||
ORDER BY step_number ASC, started_at ASC;
|
||||
|
||||
-- name: DeleteChatDebugDataByChatID :execrows
|
||||
DELETE FROM chat_debug_runs
|
||||
WHERE chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: DeleteChatDebugDataAfterMessageID :execrows
|
||||
WITH affected_runs AS (
|
||||
SELECT DISTINCT run.id
|
||||
FROM chat_debug_runs run
|
||||
WHERE run.chat_id = @chat_id::uuid
|
||||
AND (
|
||||
run.history_tip_message_id > @message_id::bigint
|
||||
OR run.trigger_message_id > @message_id::bigint
|
||||
)
|
||||
|
||||
UNION
|
||||
|
||||
SELECT DISTINCT step.run_id AS id
|
||||
FROM chat_debug_steps step
|
||||
WHERE step.chat_id = @chat_id::uuid
|
||||
AND (
|
||||
step.assistant_message_id > @message_id::bigint
|
||||
OR step.history_tip_message_id > @message_id::bigint
|
||||
)
|
||||
)
|
||||
DELETE FROM chat_debug_runs
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND id IN (SELECT id FROM affected_runs);
|
||||
|
||||
-- name: FinalizeStaleChatDebugRows :one
|
||||
-- Marks orphaned in-progress rows as interrupted so they do not stay
|
||||
-- in a non-terminal state forever. The NOT IN list must match the
|
||||
-- terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
|
||||
--
|
||||
-- The steps CTE also catches steps whose parent run was just finalized
|
||||
-- (via run_id IN), because PostgreSQL data-modifying CTEs share the
|
||||
-- same snapshot and cannot see each other's row updates. Without this,
|
||||
-- a step with a recent updated_at would survive its run's finalization
|
||||
-- and remain in 'in_progress' state permanently.
|
||||
WITH finalized_runs AS (
|
||||
UPDATE chat_debug_runs
|
||||
SET
|
||||
status = 'interrupted',
|
||||
updated_at = NOW(),
|
||||
finished_at = NOW()
|
||||
WHERE updated_at < @updated_before::timestamptz
|
||||
AND finished_at IS NULL
|
||||
AND status NOT IN ('completed', 'error', 'interrupted')
|
||||
RETURNING id
|
||||
), finalized_steps AS (
|
||||
UPDATE chat_debug_steps
|
||||
SET
|
||||
status = 'interrupted',
|
||||
updated_at = NOW(),
|
||||
finished_at = NOW()
|
||||
WHERE (
|
||||
updated_at < @updated_before::timestamptz
|
||||
OR run_id IN (SELECT id FROM finalized_runs)
|
||||
)
|
||||
AND finished_at IS NULL
|
||||
AND status NOT IN ('completed', 'error', 'interrupted')
|
||||
RETURNING 1
|
||||
)
|
||||
SELECT
|
||||
(SELECT COUNT(*) FROM finalized_runs)::bigint AS runs_finalized,
|
||||
(SELECT COUNT(*) FROM finalized_steps)::bigint AS steps_finalized;
|
||||
@@ -179,31 +179,6 @@ SET value = CASE
|
||||
END
|
||||
WHERE site_configs.key = 'agents_desktop_enabled';
|
||||
|
||||
-- GetChatDebugLoggingAllowUsers returns the runtime admin setting that
|
||||
-- allows users to opt into chat debug logging when the deployment does
|
||||
-- not already force debug logging on globally.
|
||||
-- name: GetChatDebugLoggingAllowUsers :one
|
||||
SELECT
|
||||
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_chat_debug_logging_allow_users'), false) :: boolean AS allow_users;
|
||||
|
||||
-- UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
|
||||
-- allows users to opt into chat debug logging.
|
||||
-- name: UpsertChatDebugLoggingAllowUsers :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES (
|
||||
'agents_chat_debug_logging_allow_users',
|
||||
CASE
|
||||
WHEN sqlc.arg(allow_users)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = CASE
|
||||
WHEN sqlc.arg(allow_users)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
WHERE site_configs.key = 'agents_chat_debug_logging_allow_users';
|
||||
|
||||
-- GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
|
||||
-- Returns an empty string when no allowlist has been configured (all templates allowed).
|
||||
-- name: GetChatTemplateAllowlist :one
|
||||
|
||||
@@ -213,31 +213,6 @@ RETURNING *;
|
||||
-- name: DeleteUserChatCompactionThreshold :exec
|
||||
DELETE FROM user_configs WHERE user_id = @user_id AND key = @key;
|
||||
|
||||
-- name: GetUserChatDebugLoggingEnabled :one
|
||||
SELECT
|
||||
value = 'true' AS debug_logging_enabled
|
||||
FROM user_configs
|
||||
WHERE user_id = @user_id
|
||||
AND key = 'chat_debug_logging_enabled';
|
||||
|
||||
-- name: UpsertUserChatDebugLoggingEnabled :exec
|
||||
INSERT INTO user_configs (user_id, key, value)
|
||||
VALUES (
|
||||
@user_id,
|
||||
'chat_debug_logging_enabled',
|
||||
CASE
|
||||
WHEN sqlc.arg(debug_logging_enabled)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
)
|
||||
ON CONFLICT ON CONSTRAINT user_configs_pkey
|
||||
DO UPDATE SET value = CASE
|
||||
WHEN sqlc.arg(debug_logging_enabled)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
WHERE user_configs.user_id = @user_id
|
||||
AND user_configs.key = 'chat_debug_logging_enabled';
|
||||
|
||||
-- name: GetUserTaskNotificationAlertDismissed :one
|
||||
SELECT
|
||||
value::boolean as task_notification_alert_dismissed
|
||||
|
||||
@@ -15,8 +15,6 @@ const (
|
||||
UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id);
|
||||
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
UniqueChatDebugRunsPkey UniqueConstraint = "chat_debug_runs_pkey" // ALTER TABLE ONLY chat_debug_runs ADD CONSTRAINT chat_debug_runs_pkey PRIMARY KEY (id);
|
||||
UniqueChatDebugStepsPkey UniqueConstraint = "chat_debug_steps_pkey" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT chat_debug_steps_pkey PRIMARY KEY (id);
|
||||
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
UniqueChatFileLinksChatIDFileIDKey UniqueConstraint = "chat_file_links_chat_id_file_id_key" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
|
||||
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
@@ -130,8 +128,6 @@ const (
|
||||
UniqueWorkspaceResourcesPkey UniqueConstraint = "workspace_resources_pkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id);
|
||||
UniqueWorkspacesPkey UniqueConstraint = "workspaces_pkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
|
||||
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
|
||||
UniqueIndexChatDebugRunsIDChat UniqueConstraint = "idx_chat_debug_runs_id_chat" // CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs USING btree (id, chat_id);
|
||||
UniqueIndexChatDebugStepsRunStep UniqueConstraint = "idx_chat_debug_steps_run_step" // CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps USING btree (run_id, step_number);
|
||||
UniqueIndexChatModelConfigsSingleDefault UniqueConstraint = "idx_chat_model_configs_single_default" // CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
|
||||
UniqueIndexConnectionLogsConnectionIDWorkspaceIDAgentName UniqueConstraint = "idx_connection_logs_connection_id_workspace_id_agent_name" // CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
|
||||
UniqueIndexCustomRolesNameLowerOrganizationID UniqueConstraint = "idx_custom_roles_name_lower_organization_id" // CREATE UNIQUE INDEX idx_custom_roles_name_lower_organization_id ON custom_roles USING btree (lower(name), COALESCE(organization_id, '00000000-0000-0000-0000-000000000000'::uuid));
|
||||
|
||||
@@ -3141,140 +3141,6 @@ func (api *API) putChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) deploymentChatDebugLoggingEnabled() bool {
|
||||
return api.DeploymentValues != nil && api.DeploymentValues.AI.Chat.DebugLoggingEnabled.Value()
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatDebugLogging(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
enabled, err := api.Database.GetChatDebugLoggingEnabled(ctx)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching chat debug logging setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatDebugLoggingAdminSettings{
|
||||
DebugLoggingEnabled: err == nil && enabled,
|
||||
ForcedByDeployment: api.deploymentChatDebugLoggingEnabled(),
|
||||
})
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) putChatDebugLogging(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateChatDebugLoggingAllowUsersRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if err := api.Database.UpsertChatDebugLoggingEnabled(ctx, req.AllowUsers); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating chat debug logging setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getUserChatDebugLogging(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
forcedByDeployment := api.deploymentChatDebugLoggingEnabled()
|
||||
allowUsers := false
|
||||
if !forcedByDeployment {
|
||||
enabled, err := api.Database.GetChatDebugLoggingEnabled(ctx)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching chat debug logging setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
allowUsers = err == nil && enabled
|
||||
}
|
||||
|
||||
debugEnabled := forcedByDeployment
|
||||
if allowUsers {
|
||||
enabled, err := api.Database.GetUserChatDebugLoggingEnabled(ctx, apiKey.UserID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching user chat debug logging setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
debugEnabled = err == nil && enabled
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatDebugLoggingSettings{
|
||||
DebugLoggingEnabled: debugEnabled,
|
||||
UserToggleAllowed: !forcedByDeployment && allowUsers,
|
||||
ForcedByDeployment: forcedByDeployment,
|
||||
})
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) putUserChatDebugLogging(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
if api.deploymentChatDebugLoggingEnabled() {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Chat debug logging is already forced on by deployment configuration.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
allowUsers, err := api.Database.GetChatDebugLoggingEnabled(ctx)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching chat debug logging setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil || !allowUsers {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "An administrator has not enabled user-controlled chat debug logging.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateUserChatDebugLoggingRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if err := api.Database.UpsertUserChatDebugLoggingEnabled(ctx, database.UpsertUserChatDebugLoggingEnabledParams{
|
||||
UserID: apiKey.UserID,
|
||||
DebugLoggingEnabled: req.DebugLoggingEnabled,
|
||||
}); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating user chat debug logging setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
@@ -6001,95 +5867,3 @@ func (api *API) postChatToolResults(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// getChatDebugRuns returns a list of debug run summaries for a chat.
|
||||
// EXPERIMENTAL
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatDebugRuns(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
const maxDebugRuns = 100
|
||||
runs, err := api.Database.GetChatDebugRunsByChatID(ctx, database.GetChatDebugRunsByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
LimitVal: maxDebugRuns,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching debug runs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
summaries := make([]codersdk.ChatDebugRunSummary, 0, len(runs))
|
||||
for _, r := range runs {
|
||||
summaries = append(summaries, db2sdk.ChatDebugRunSummary(r))
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, summaries)
|
||||
}
|
||||
|
||||
// getChatDebugRun returns a single debug run with its steps.
|
||||
// EXPERIMENTAL
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatDebugRun(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
runIDStr := chi.URLParam(r, "debugRun")
|
||||
runID, err := uuid.Parse(runIDStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid debug run ID.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
run, err := api.Database.GetChatDebugRunByID(ctx, runID)
|
||||
if err != nil {
|
||||
// Treat both not-found and authorization failures as 404 to
|
||||
// avoid leaking the existence of runs the caller cannot access.
|
||||
if errors.Is(err, sql.ErrNoRows) || dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Debug run not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching debug run.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the run belongs to this chat.
|
||||
if run.ChatID != chat.ID {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Debug run not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
steps, err := api.Database.GetChatDebugStepsByRunID(ctx, run.ID)
|
||||
if err != nil {
|
||||
// The run may have been deleted or access may have changed
|
||||
// between the two queries. Treat not-found/authz errors as
|
||||
// 404 for consistency with the run lookup above.
|
||||
if errors.Is(err, sql.ErrNoRows) || dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Debug run not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching debug steps.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.ChatDebugRunDetail(run, steps))
|
||||
}
|
||||
|
||||
+10
-152
@@ -7747,148 +7747,6 @@ func TestChatDesktopEnabled(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatDebugLoggingSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("DefaultDisabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
adminResp, err := adminClient.GetChatDebugLogging(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, adminResp.AllowUsers)
|
||||
require.False(t, adminResp.ForcedByDeployment)
|
||||
|
||||
userResp, err := memberClient.GetUserChatDebugLogging(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, userResp.DebugLoggingEnabled)
|
||||
require.False(t, userResp.UserToggleAllowed)
|
||||
require.False(t, userResp.ForcedByDeployment)
|
||||
})
|
||||
|
||||
t.Run("AdminAllowsUsersToOptIn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
err := adminClient.UpdateChatDebugLogging(ctx, codersdk.UpdateChatDebugLoggingAllowUsersRequest{
|
||||
AllowUsers: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
userResp, err := memberClient.GetUserChatDebugLogging(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, userResp.DebugLoggingEnabled)
|
||||
require.True(t, userResp.UserToggleAllowed)
|
||||
require.False(t, userResp.ForcedByDeployment)
|
||||
|
||||
err = memberClient.UpdateUserChatDebugLogging(ctx, codersdk.UpdateUserChatDebugLoggingRequest{
|
||||
DebugLoggingEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
userResp, err = memberClient.GetUserChatDebugLogging(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, userResp.DebugLoggingEnabled)
|
||||
require.True(t, userResp.UserToggleAllowed)
|
||||
require.False(t, userResp.ForcedByDeployment)
|
||||
|
||||
err = adminClient.UpdateChatDebugLogging(ctx, codersdk.UpdateChatDebugLoggingAllowUsersRequest{
|
||||
AllowUsers: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
userResp, err = memberClient.GetUserChatDebugLogging(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, userResp.DebugLoggingEnabled)
|
||||
require.False(t, userResp.UserToggleAllowed)
|
||||
})
|
||||
|
||||
t.Run("UserWriteFailsWhenAdminDisabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
err := memberClient.UpdateUserChatDebugLogging(ctx, codersdk.UpdateUserChatDebugLoggingRequest{
|
||||
DebugLoggingEnabled: true,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
|
||||
t.Run("NonAdminCannotManageAdminSetting", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
_, err := memberClient.GetChatDebugLogging(ctx)
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
|
||||
err = memberClient.UpdateChatDebugLogging(ctx, codersdk.UpdateChatDebugLoggingAllowUsersRequest{
|
||||
AllowUsers: true,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
|
||||
t.Run("DeploymentForceEnablesDebugLogging", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
values := chatDeploymentValues(t)
|
||||
values.AI.Chat.DebugLoggingEnabled = serpent.Bool(true)
|
||||
adminClient := newChatClientWithDeploymentValues(t, values)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
adminResp, err := adminClient.GetChatDebugLogging(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, adminResp.AllowUsers)
|
||||
require.True(t, adminResp.ForcedByDeployment)
|
||||
|
||||
userResp, err := memberClient.GetUserChatDebugLogging(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, userResp.DebugLoggingEnabled)
|
||||
require.False(t, userResp.UserToggleAllowed)
|
||||
require.True(t, userResp.ForcedByDeployment)
|
||||
|
||||
err = memberClient.UpdateUserChatDebugLogging(ctx, codersdk.UpdateUserChatDebugLoggingRequest{
|
||||
DebugLoggingEnabled: false,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusConflict)
|
||||
})
|
||||
|
||||
t.Run("UnauthenticatedUserReadFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
|
||||
anonClient := codersdk.NewExperimentalClient(codersdk.New(adminClient.URL))
|
||||
_, err := anonClient.GetUserChatDebugLogging(ctx)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatWorkspaceTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
@@ -8034,7 +7892,7 @@ func TestChatRetentionDays(t *testing.T) {
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
//nolint:tparallel // subtests share state via client, firstUser, modelConfig
|
||||
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
|
||||
func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -8042,7 +7900,7 @@ func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
t.Run("EmptyByDefault", func(t *testing.T) { //nolint:paralleltest // subtests share parent state
|
||||
t.Run("EmptyByDefault", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
thresholds, err := client.GetUserChatCompactionThresholds(ctx)
|
||||
@@ -8050,7 +7908,7 @@ func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
require.Empty(t, thresholds.Thresholds)
|
||||
})
|
||||
|
||||
t.Run("PutAndGet", func(t *testing.T) { //nolint:paralleltest // subtests share parent state
|
||||
t.Run("PutAndGet", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
override, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{
|
||||
@@ -8067,7 +7925,7 @@ func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
require.EqualValues(t, 75, thresholds.Thresholds[0].ThresholdPercent)
|
||||
})
|
||||
|
||||
t.Run("UpsertChangesValue", func(t *testing.T) { //nolint:paralleltest // subtests share parent state
|
||||
t.Run("UpsertChangesValue", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
_, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{
|
||||
@@ -8087,7 +7945,7 @@ func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
require.EqualValues(t, 75, thresholds.Thresholds[0].ThresholdPercent)
|
||||
})
|
||||
|
||||
t.Run("BoundaryValues", func(t *testing.T) { //nolint:paralleltest // subtests share parent state
|
||||
t.Run("BoundaryValues", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
override, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{
|
||||
@@ -8113,7 +7971,7 @@ func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
require.EqualValues(t, 100, thresholds.Thresholds[0].ThresholdPercent)
|
||||
})
|
||||
|
||||
t.Run("ValidationRejectsInvalid", func(t *testing.T) { //nolint:paralleltest // subtests share parent state
|
||||
t.Run("ValidationRejectsInvalid", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
_, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{
|
||||
@@ -8127,7 +7985,7 @@ func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) { //nolint:paralleltest // subtests share parent state
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := client.DeleteUserChatCompactionThreshold(ctx, modelConfig.ID)
|
||||
@@ -8138,14 +7996,14 @@ func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
require.Empty(t, thresholds.Thresholds)
|
||||
})
|
||||
|
||||
t.Run("DeleteIdempotent", func(t *testing.T) { //nolint:paralleltest // subtests share parent state
|
||||
t.Run("DeleteIdempotent", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := client.DeleteUserChatCompactionThreshold(ctx, modelConfig.ID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("NonExistentModelConfig", func(t *testing.T) { //nolint:paralleltest // subtests share parent state
|
||||
t.Run("NonExistentModelConfig", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
fakeID := uuid.New()
|
||||
@@ -8155,7 +8013,7 @@ func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("IsolatedPerUser", func(t *testing.T) { //nolint:paralleltest // subtests share parent state
|
||||
t.Run("IsolatedPerUser", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
|
||||
+55
-515
@@ -34,7 +34,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
@@ -130,7 +129,6 @@ type Server struct {
|
||||
pubsub pubsub.Pubsub
|
||||
webpushDispatcher webpush.Dispatcher
|
||||
providerAPIKeys chatprovider.ProviderAPIKeys
|
||||
debugSvc *chatdebug.Service
|
||||
configCache *chatConfigCache
|
||||
configCacheUnsubscribe func()
|
||||
|
||||
@@ -1212,10 +1210,7 @@ func (p *Server) EditMessage(
|
||||
return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
|
||||
}
|
||||
|
||||
var (
|
||||
result EditMessageResult
|
||||
editedMsg database.ChatMessage
|
||||
)
|
||||
var result EditMessageResult
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
|
||||
if err != nil {
|
||||
@@ -1226,17 +1221,17 @@ func (p *Server) EditMessage(
|
||||
return limitErr
|
||||
}
|
||||
|
||||
editedMsg, err = tx.GetChatMessageByID(ctx, opts.EditedMessageID)
|
||||
existing, err := tx.GetChatMessageByID(ctx, opts.EditedMessageID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrEditedMessageNotFound
|
||||
}
|
||||
return xerrors.Errorf("get edited message: %w", err)
|
||||
}
|
||||
if editedMsg.ChatID != opts.ChatID {
|
||||
if existing.ChatID != opts.ChatID {
|
||||
return ErrEditedMessageNotFound
|
||||
}
|
||||
if editedMsg.Role != database.ChatMessageRoleUser {
|
||||
if existing.Role != database.ChatMessageRoleUser {
|
||||
return ErrEditedMessageNotUser
|
||||
}
|
||||
|
||||
@@ -1263,8 +1258,8 @@ func (p *Server) EditMessage(
|
||||
appendChatMessage(&msgParams, newChatMessage(
|
||||
database.ChatMessageRoleUser,
|
||||
content,
|
||||
editedMsg.Visibility,
|
||||
editedMsg.ModelConfigID.UUID,
|
||||
existing.Visibility,
|
||||
existing.ModelConfigID.UUID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
).withCreatedBy(opts.CreatedBy))
|
||||
newMessages, err := insertChatMessageWithStore(ctx, tx, msgParams)
|
||||
@@ -1307,26 +1302,6 @@ func (p *Server) EditMessage(
|
||||
})
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
|
||||
// Best-effort debug row cleanup. We do not wait for the active
|
||||
// worker to stop because activeChats is process-local and would
|
||||
// not cover multi-replica deployments. Any rows that survive
|
||||
// this pass are caught by the periodic stale-finalization sweep.
|
||||
if p.debugSvc != nil {
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
defer cleanupCancel()
|
||||
if _, err := p.debugSvc.DeleteAfterMessageID(
|
||||
cleanupCtx,
|
||||
opts.ChatID,
|
||||
editedMsg.ID-1,
|
||||
); err != nil {
|
||||
p.logger.Warn(ctx, "failed to delete chat debug rows after edit",
|
||||
slog.F("chat_id", opts.ChatID),
|
||||
slog.F("edited_message_id", editedMsg.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -1341,67 +1316,46 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
return xerrors.New("chat_id is required")
|
||||
}
|
||||
|
||||
var (
|
||||
archivedChats []database.Chat
|
||||
interruptedChats []database.Chat
|
||||
)
|
||||
statusChat := chat
|
||||
interrupted := false
|
||||
var archivedChats []database.Chat
|
||||
if err := p.db.InTx(func(tx database.Store) error {
|
||||
if _, err := tx.GetChatByIDForUpdate(ctx, chat.ID); err != nil {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat for archive: %w", err)
|
||||
}
|
||||
statusChat = lockedChat
|
||||
|
||||
var err error
|
||||
archivedChats, err = tx.ArchiveChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
}
|
||||
|
||||
for i, archivedChat := range archivedChats {
|
||||
if archivedChat.Status != database.ChatStatusPending &&
|
||||
archivedChat.Status != database.ChatStatusRunning {
|
||||
continue
|
||||
}
|
||||
|
||||
updatedChat, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: archivedChat.ID,
|
||||
// We do not call setChatWaiting here because it intentionally preserves
|
||||
// pending chats so queued-message promotion can win. Archiving is a
|
||||
// harder stop: both pending and running chats must transition to waiting.
|
||||
if lockedChat.Status == database.ChatStatusPending || lockedChat.Status == database.ChatStatusRunning {
|
||||
statusChat, err = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusWaiting,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if updateErr != nil {
|
||||
return xerrors.Errorf("set archived chat waiting before cleanup: %w", updateErr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set chat waiting before archive: %w", err)
|
||||
}
|
||||
archivedChats[i] = updatedChat
|
||||
interruptedChats = append(interruptedChats, updatedChat)
|
||||
interrupted = true
|
||||
}
|
||||
|
||||
archivedChats, err = tx.ArchiveChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, interruptedChat := range interruptedChats {
|
||||
p.publishStatus(interruptedChat.ID, interruptedChat.Status, interruptedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(interruptedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
// Best-effort debug row cleanup — no process-local wait so this
|
||||
// works correctly across replicas. If an active goroutine writes
|
||||
// new debug rows after the delete, FinalizeStale will mark them
|
||||
// as interrupted. Those orphaned rows are harmless because the
|
||||
// chat itself is archived and no longer served through the API.
|
||||
if p.debugSvc != nil {
|
||||
for _, archivedChat := range archivedChats {
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
if _, err := p.debugSvc.DeleteByChatID(cleanupCtx, archivedChat.ID); err != nil {
|
||||
p.logger.Warn(ctx, "failed to delete chat debug rows after archive",
|
||||
slog.F("chat_id", archivedChat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
cleanupCancel()
|
||||
}
|
||||
if interrupted {
|
||||
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
|
||||
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted)
|
||||
@@ -1864,8 +1818,6 @@ func (p *Server) InterruptChat(
|
||||
}
|
||||
}
|
||||
|
||||
// Debug runs are finalized in the execution path when the owning
|
||||
// goroutine observes cancellation, so we do not mutate debug state here.
|
||||
updatedChat, err := p.setChatWaiting(ctx, chat.ID)
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "failed to mark chat as waiting",
|
||||
@@ -2106,23 +2058,7 @@ func (p *Server) regenerateChatTitleWithStore(
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
debugEnabled := p.debugSvc != nil && p.debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
|
||||
titleCtx := ctx
|
||||
titleModel := model
|
||||
finishDebugRun := func(error) {}
|
||||
if debugEnabled {
|
||||
titleCtx, titleModel, finishDebugRun = p.prepareManualTitleDebugRun(
|
||||
ctx,
|
||||
chat,
|
||||
modelConfig,
|
||||
keys,
|
||||
messages,
|
||||
model,
|
||||
)
|
||||
}
|
||||
|
||||
title, usage, err := generateManualTitle(titleCtx, messages, titleModel)
|
||||
finishDebugRun(err)
|
||||
title, usage, err := generateManualTitle(ctx, messages, model)
|
||||
if err != nil {
|
||||
wrappedErr := xerrors.Errorf("generate manual title: %w", err)
|
||||
if usage == (fantasy.Usage{}) {
|
||||
@@ -2160,177 +2096,6 @@ func (p *Server) regenerateChatTitleWithStore(
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
func (p *Server) prepareManualTitleDebugRun(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
modelConfig database.ChatModelConfig,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
messages []database.ChatMessage,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
) (context.Context, fantasy.LanguageModel, func(error)) {
|
||||
titleCtx := ctx
|
||||
titleModel := fallbackModel
|
||||
finishDebugRun := func(error) {}
|
||||
|
||||
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
debugModel, debugModelErr := chatprovider.ModelFromConfig(
|
||||
modelConfig.Provider,
|
||||
modelConfig.Model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
httpClient,
|
||||
)
|
||||
switch {
|
||||
case debugModelErr != nil:
|
||||
p.logger.Warn(ctx, "failed to create debug-aware manual title model",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", modelConfig.Provider),
|
||||
slog.F("model", modelConfig.Model),
|
||||
slog.Error(debugModelErr),
|
||||
)
|
||||
case debugModel == nil:
|
||||
p.logger.Warn(ctx, "manual title debug model creation returned nil",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", modelConfig.Provider),
|
||||
slog.F("model", modelConfig.Model),
|
||||
)
|
||||
default:
|
||||
titleModel = chatdebug.WrapModel(debugModel, p.debugSvc, chatdebug.RecorderOptions{
|
||||
ChatID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
Provider: modelConfig.Provider,
|
||||
Model: modelConfig.Model,
|
||||
})
|
||||
}
|
||||
|
||||
var historyTipMessageID int64
|
||||
if len(messages) > 0 {
|
||||
historyTipMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
// Derive a first_message label from the first user message.
|
||||
var firstUserLabel string
|
||||
for _, msg := range messages {
|
||||
if msg.Role == database.ChatMessageRoleUser {
|
||||
if parts, parseErr := chatprompt.ParseContent(msg); parseErr == nil {
|
||||
firstUserLabel = contentBlocksToText(parts)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if firstUserLabel == "" {
|
||||
firstUserLabel = "Title generation"
|
||||
}
|
||||
seedSummary := chatdebug.SeedSummary(
|
||||
chatdebug.TruncateLabel(firstUserLabel, chatdebug.MaxLabelLength),
|
||||
)
|
||||
|
||||
createRunCtx, createRunCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
debugRun, createRunErr := p.debugSvc.CreateRun(createRunCtx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: modelConfig.ID,
|
||||
Provider: modelConfig.Provider,
|
||||
Model: modelConfig.Model,
|
||||
Kind: chatdebug.KindTitleGeneration,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
TriggerMessageID: 0,
|
||||
Summary: seedSummary,
|
||||
})
|
||||
createRunCancel()
|
||||
if createRunErr != nil {
|
||||
p.logger.Warn(ctx, "failed to create manual title debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", modelConfig.Provider),
|
||||
slog.F("model", modelConfig.Model),
|
||||
slog.Error(createRunErr),
|
||||
)
|
||||
return titleCtx, titleModel, finishDebugRun
|
||||
}
|
||||
|
||||
runContext := chatdebugRunContext(debugRun)
|
||||
titleCtx = chatdebug.ContextWithRun(titleCtx, &runContext)
|
||||
finishDebugRun = func(generateErr error) {
|
||||
status := chatdebug.StatusCompleted
|
||||
switch {
|
||||
case generateErr == nil:
|
||||
// keep completed
|
||||
case errors.Is(generateErr, context.Canceled):
|
||||
status = chatdebug.StatusInterrupted
|
||||
default:
|
||||
status = chatdebug.StatusError
|
||||
}
|
||||
|
||||
finalSummary := seedSummary
|
||||
aggCtx, aggCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
defer aggCancel()
|
||||
if aggregated, aggErr := p.debugSvc.AggregateRunSummary(
|
||||
aggCtx,
|
||||
debugRun.ID,
|
||||
seedSummary,
|
||||
); aggErr != nil {
|
||||
p.logger.Warn(ctx, "failed to aggregate debug run summary",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(aggErr),
|
||||
)
|
||||
} else {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
|
||||
updateRunCtx, updateRunCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
defer updateRunCancel()
|
||||
_, updateRunErr := p.debugSvc.UpdateRun(updateRunCtx, chatdebug.UpdateRunParams{
|
||||
ID: debugRun.ID,
|
||||
ChatID: debugRun.ChatID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
})
|
||||
if updateRunErr != nil {
|
||||
p.logger.Warn(ctx, "failed to finalize manual title debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(updateRunErr),
|
||||
)
|
||||
}
|
||||
chatdebug.CleanupStepCounter(debugRun.ID)
|
||||
}
|
||||
|
||||
return titleCtx, titleModel, finishDebugRun
|
||||
}
|
||||
|
||||
func chatdebugRunContext(run database.ChatDebugRun) chatdebug.RunContext {
|
||||
runContext := chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: run.ChatID,
|
||||
Kind: chatdebug.RunKind(run.Kind),
|
||||
}
|
||||
if run.RootChatID.Valid {
|
||||
runContext.RootChatID = run.RootChatID.UUID
|
||||
}
|
||||
if run.ParentChatID.Valid {
|
||||
runContext.ParentChatID = run.ParentChatID.UUID
|
||||
}
|
||||
if run.ModelConfigID.Valid {
|
||||
runContext.ModelConfigID = run.ModelConfigID.UUID
|
||||
}
|
||||
if run.TriggerMessageID.Valid {
|
||||
runContext.TriggerMessageID = run.TriggerMessageID.Int64
|
||||
}
|
||||
if run.HistoryTipMessageID.Valid {
|
||||
runContext.HistoryTipMessageID = run.HistoryTipMessageID.Int64
|
||||
}
|
||||
if run.Provider.Valid {
|
||||
runContext.Provider = run.Provider.String
|
||||
}
|
||||
if run.Model.Valid {
|
||||
runContext.Model = run.Model.String
|
||||
}
|
||||
return runContext
|
||||
}
|
||||
|
||||
func (p *Server) resolveManualTitleModel(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
@@ -2357,7 +2122,6 @@ func (p *Server) resolveManualTitleModel(
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "manual title preferred model unavailable",
|
||||
@@ -2390,7 +2154,6 @@ func (p *Server) resolveFallbackManualTitleModel(
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, xerrors.Errorf(
|
||||
@@ -2925,7 +2688,6 @@ type Config struct {
|
||||
StartWorkspace chattool.StartWorkspaceFn
|
||||
Pubsub pubsub.Pubsub
|
||||
ProviderAPIKeys chatprovider.ProviderAPIKeys
|
||||
AlwaysEnableDebugLogs bool
|
||||
WebpushDispatcher webpush.Dispatcher
|
||||
UsageTracker *workspacestats.UsageTracker
|
||||
Clock quartz.Clock
|
||||
@@ -2972,14 +2734,6 @@ func New(cfg Config) *Server {
|
||||
workerID = uuid.New()
|
||||
}
|
||||
|
||||
debugSvc := chatdebug.NewService(
|
||||
cfg.Database,
|
||||
cfg.Logger.Named("chatdebug"),
|
||||
cfg.Pubsub,
|
||||
chatdebug.WithAlwaysEnable(cfg.AlwaysEnableDebugLogs),
|
||||
)
|
||||
debugSvc.SetStaleAfter(inFlightChatStaleAfter)
|
||||
|
||||
p := &Server{
|
||||
cancel: cancel,
|
||||
closed: make(chan struct{}),
|
||||
@@ -2995,7 +2749,6 @@ func New(cfg Config) *Server {
|
||||
pubsub: cfg.Pubsub,
|
||||
webpushDispatcher: cfg.WebpushDispatcher,
|
||||
providerAPIKeys: cfg.ProviderAPIKeys,
|
||||
debugSvc: debugSvc,
|
||||
pendingChatAcquireInterval: pendingChatAcquireInterval,
|
||||
maxChatsPerAcquire: maxChatsPerAcquire,
|
||||
inFlightChatStaleAfter: inFlightChatStaleAfter,
|
||||
@@ -3044,12 +2797,6 @@ func (p *Server) start(ctx context.Context) {
|
||||
// Recover stale chats on startup and periodically thereafter
|
||||
// to handle chats orphaned by crashed or redeployed workers.
|
||||
p.recoverStaleChats(ctx)
|
||||
if p.debugSvc != nil {
|
||||
_, err := p.debugSvc.FinalizeStale(ctx)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to finalize stale chat debug rows", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Single heartbeat loop for all chats on this replica.
|
||||
go p.heartbeatLoop(ctx)
|
||||
@@ -3079,11 +2826,6 @@ func (p *Server) start(ctx context.Context) {
|
||||
p.processOnce(ctx)
|
||||
case <-staleTicker.C:
|
||||
p.recoverStaleChats(ctx)
|
||||
if p.debugSvc != nil {
|
||||
if _, err := p.debugSvc.FinalizeStale(ctx); err != nil {
|
||||
p.logger.Warn(ctx, "failed to finalize stale chat debug rows", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4595,10 +4337,6 @@ type runChatResult struct {
|
||||
PushSummaryModel fantasy.LanguageModel
|
||||
ProviderKeys chatprovider.ProviderAPIKeys
|
||||
PendingDynamicToolCalls []chatloop.PendingToolCall
|
||||
FallbackProvider string
|
||||
FallbackModel string
|
||||
TriggerMessageID int64
|
||||
HistoryTipMessageID int64
|
||||
}
|
||||
|
||||
func (p *Server) runChat(
|
||||
@@ -4609,14 +4347,11 @@ func (p *Server) runChat(
|
||||
) (runChatResult, error) {
|
||||
result := runChatResult{}
|
||||
var (
|
||||
model fantasy.LanguageModel
|
||||
modelConfig database.ChatModelConfig
|
||||
providerKeys chatprovider.ProviderAPIKeys
|
||||
callConfig codersdk.ChatModelCallConfig
|
||||
messages []database.ChatMessage
|
||||
debugEnabled bool
|
||||
debugProvider string
|
||||
debugModel string
|
||||
model fantasy.LanguageModel
|
||||
modelConfig database.ChatModelConfig
|
||||
providerKeys chatprovider.ProviderAPIKeys
|
||||
callConfig codersdk.ChatModelCallConfig
|
||||
messages []database.ChatMessage
|
||||
)
|
||||
|
||||
// Load MCP server configs and user tokens in parallel with
|
||||
@@ -4629,7 +4364,7 @@ func (p *Server) runChat(
|
||||
var g errgroup.Group
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
model, modelConfig, providerKeys, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat)
|
||||
model, modelConfig, providerKeys, err = p.resolveChatModel(ctx, chat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -4687,31 +4422,23 @@ func (p *Server) runChat(
|
||||
chainInfo := resolveChainMode(messages)
|
||||
result.PushSummaryModel = model
|
||||
result.ProviderKeys = providerKeys
|
||||
result.FallbackProvider = modelConfig.Provider
|
||||
result.FallbackModel = modelConfig.Model
|
||||
// Fire title generation asynchronously so it doesn't block the
|
||||
// chat response. It uses a detached context so it can finish
|
||||
// even after the chat processing context is canceled.
|
||||
// Snapshot the original chat model so the goroutine doesn't
|
||||
// race with the model = cuModel reassignment below.
|
||||
// Snapshot ctx before the goroutine to avoid a data race with
|
||||
// the ctx = runCtx reassignment later in the main goroutine.
|
||||
titleModel := result.PushSummaryModel
|
||||
titleCtx := context.WithoutCancel(ctx)
|
||||
p.inflight.Add(1)
|
||||
go func() {
|
||||
defer p.inflight.Done()
|
||||
p.maybeGenerateChatTitle(
|
||||
titleCtx,
|
||||
context.WithoutCancel(ctx),
|
||||
chat,
|
||||
messages,
|
||||
modelConfig.Provider,
|
||||
modelConfig.Model,
|
||||
titleModel,
|
||||
providerKeys,
|
||||
generatedTitle,
|
||||
logger,
|
||||
p.debugSvc,
|
||||
)
|
||||
}()
|
||||
|
||||
@@ -4950,13 +4677,6 @@ func (p *Server) runChat(
|
||||
var finalAssistantText string
|
||||
var pendingDynamicCalls []chatloop.PendingToolCall
|
||||
|
||||
compactionHistoryTipMessageID := int64(0)
|
||||
if len(messages) > 0 {
|
||||
compactionHistoryTipMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
var compactionOptions *chatloop.CompactionOptions
|
||||
|
||||
persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error {
|
||||
// If the chat context has been canceled, bail out before
|
||||
// inserting any messages. We distinguish the cause so that
|
||||
@@ -5169,12 +4889,6 @@ func (p *Server) runChat(
|
||||
for _, msg := range insertedMessages {
|
||||
p.publishMessage(chat.ID, msg)
|
||||
}
|
||||
if len(insertedMessages) > 0 {
|
||||
compactionHistoryTipMessageID = insertedMessages[len(insertedMessages)-1].ID
|
||||
if compactionOptions != nil {
|
||||
compactionOptions.HistoryTipMessageID = compactionHistoryTipMessageID
|
||||
}
|
||||
}
|
||||
|
||||
// Do NOT clear the stream buffer here. Cross-replica
|
||||
// relay subscribers may still need to snapshot buffered
|
||||
@@ -5204,10 +4918,9 @@ func (p *Server) runChat(
|
||||
effectiveThreshold = override
|
||||
thresholdSource = "user_override"
|
||||
}
|
||||
compactionOptions = &chatloop.CompactionOptions{
|
||||
ThresholdPercent: effectiveThreshold,
|
||||
ContextLimit: modelConfig.ContextLimit,
|
||||
HistoryTipMessageID: compactionHistoryTipMessageID,
|
||||
compactionOptions := &chatloop.CompactionOptions{
|
||||
ThresholdPercent: effectiveThreshold,
|
||||
ContextLimit: modelConfig.ContextLimit,
|
||||
Persist: func(
|
||||
persistCtx context.Context,
|
||||
result chatloop.CompactionResult,
|
||||
@@ -5243,16 +4956,7 @@ func (p *Server) runChat(
|
||||
|
||||
if isComputerUse {
|
||||
// Override model for computer use subagent.
|
||||
resolvedProvider, resolvedModel, resolveErr := chatprovider.ResolveModelWithProviderHint(
|
||||
chattool.ComputerUseModelName,
|
||||
chattool.ComputerUseModelProvider,
|
||||
)
|
||||
if resolveErr != nil {
|
||||
return result, xerrors.Errorf("resolve computer use model metadata: %w", resolveErr)
|
||||
}
|
||||
cuModel, cuDebugEnabled, cuErr := p.newDebugAwareModelFromConfig(
|
||||
ctx,
|
||||
chat,
|
||||
cuModel, cuErr := chatprovider.ModelFromConfig(
|
||||
chattool.ComputerUseModelProvider,
|
||||
chattool.ComputerUseModelName,
|
||||
providerKeys,
|
||||
@@ -5263,13 +4967,6 @@ func (p *Server) runChat(
|
||||
return result, xerrors.Errorf("resolve computer use model: %w", cuErr)
|
||||
}
|
||||
model = cuModel
|
||||
debugEnabled = cuDebugEnabled
|
||||
debugProvider = resolvedProvider
|
||||
debugModel = resolvedModel
|
||||
}
|
||||
if debugEnabled {
|
||||
compactionOptions.DebugSvc = p.debugSvc
|
||||
compactionOptions.ChatID = chat.ID
|
||||
}
|
||||
|
||||
tools := []fantasy.AgentTool{
|
||||
@@ -5486,132 +5183,7 @@ func (p *Server) runChat(
|
||||
)
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo)
|
||||
}
|
||||
|
||||
var loopErr error
|
||||
triggerMessageID := int64(0)
|
||||
var triggerLabel string
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == database.ChatMessageRoleUser {
|
||||
triggerMessageID = messages[i].ID
|
||||
if parts, parseErr := chatprompt.ParseContent(messages[i]); parseErr == nil {
|
||||
triggerLabel = contentBlocksToText(parts)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
historyTipMessageID := int64(0)
|
||||
if len(messages) > 0 {
|
||||
historyTipMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
result.TriggerMessageID = triggerMessageID
|
||||
result.HistoryTipMessageID = historyTipMessageID
|
||||
if debugEnabled {
|
||||
seedSummary := chatdebug.SeedSummary(
|
||||
chatdebug.TruncateLabel(triggerLabel, chatdebug.MaxLabelLength),
|
||||
)
|
||||
rootChatID := uuid.Nil
|
||||
if chat.RootChatID.Valid {
|
||||
rootChatID = chat.RootChatID.UUID
|
||||
}
|
||||
parentChatID := uuid.Nil
|
||||
if chat.ParentChatID.Valid {
|
||||
parentChatID = chat.ParentChatID.UUID
|
||||
}
|
||||
run, createRunErr := p.debugSvc.CreateRun(ctx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
RootChatID: rootChatID,
|
||||
ParentChatID: parentChatID,
|
||||
ModelConfigID: modelConfig.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: debugProvider,
|
||||
Model: debugModel,
|
||||
Summary: seedSummary,
|
||||
})
|
||||
if createRunErr != nil {
|
||||
logger.Warn(ctx, "failed to create chat debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(createRunErr),
|
||||
)
|
||||
} else {
|
||||
runCtx := chatdebug.ContextWithRun(ctx, &chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
RootChatID: rootChatID,
|
||||
ParentChatID: parentChatID,
|
||||
ModelConfigID: modelConfig.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Provider: debugProvider,
|
||||
Model: debugModel,
|
||||
})
|
||||
defer func() {
|
||||
panicValue := recover()
|
||||
var status chatdebug.Status
|
||||
switch {
|
||||
case panicValue != nil:
|
||||
status = chatdebug.StatusError
|
||||
case loopErr == nil:
|
||||
status = chatdebug.StatusCompleted
|
||||
case errors.Is(loopErr, chatloop.ErrInterrupted),
|
||||
errors.Is(loopErr, context.Canceled):
|
||||
status = chatdebug.StatusInterrupted
|
||||
case errors.Is(loopErr, chatloop.ErrDynamicToolCall):
|
||||
// Dynamic tool calls are a successful pause;
|
||||
// the run completed its model round-trip.
|
||||
status = chatdebug.StatusCompleted
|
||||
default:
|
||||
status = chatdebug.StatusError
|
||||
}
|
||||
|
||||
finalSummary := seedSummary
|
||||
aggCtx, aggCancel := context.WithTimeout(context.WithoutCancel(runCtx), 5*time.Second)
|
||||
defer aggCancel()
|
||||
if aggregated, aggErr := p.debugSvc.AggregateRunSummary(
|
||||
aggCtx,
|
||||
run.ID,
|
||||
seedSummary,
|
||||
); aggErr != nil {
|
||||
logger.Warn(ctx, "failed to aggregate debug run summary",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", run.ID),
|
||||
slog.Error(aggErr),
|
||||
)
|
||||
} else {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
|
||||
updateRunCtx, updateRunCancel := context.WithTimeout(context.WithoutCancel(runCtx), 5*time.Second)
|
||||
defer updateRunCancel()
|
||||
if _, updateRunErr := p.debugSvc.UpdateRun(
|
||||
updateRunCtx,
|
||||
chatdebug.UpdateRunParams{
|
||||
ID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
},
|
||||
); updateRunErr != nil {
|
||||
logger.Warn(ctx, "failed to finalize chat debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", run.ID),
|
||||
slog.Error(updateRunErr),
|
||||
)
|
||||
}
|
||||
chatdebug.CleanupStepCounter(run.ID)
|
||||
if panicValue != nil {
|
||||
panic(panicValue)
|
||||
}
|
||||
}()
|
||||
ctx = runCtx
|
||||
}
|
||||
}
|
||||
|
||||
loopErr = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
err = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
Model: model,
|
||||
Messages: prompt,
|
||||
Tools: tools, MaxSteps: maxChatSteps,
|
||||
@@ -5643,13 +5215,6 @@ func (p *Server) runChat(
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("reload chat messages: %w", err)
|
||||
}
|
||||
compactionHistoryTipMessageID = 0
|
||||
if len(reloadedMsgs) > 0 {
|
||||
compactionHistoryTipMessageID = reloadedMsgs[len(reloadedMsgs)-1].ID
|
||||
}
|
||||
if compactionOptions != nil {
|
||||
compactionOptions.HistoryTipMessageID = compactionHistoryTipMessageID
|
||||
}
|
||||
reloadedPrompt, err := chatprompt.ConvertMessagesWithFiles(reloadCtx, reloadedMsgs, p.chatFileResolver(), logger)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("convert reloaded messages: %w", err)
|
||||
@@ -5706,7 +5271,7 @@ func (p *Server) runChat(
|
||||
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
|
||||
},
|
||||
})
|
||||
if errors.Is(loopErr, chatloop.ErrDynamicToolCall) {
|
||||
if errors.Is(err, chatloop.ErrDynamicToolCall) {
|
||||
// The stream event is published in processChat's
|
||||
// defer after the DB status transitions to
|
||||
// requires_action, preventing a race where a fast
|
||||
@@ -5715,9 +5280,9 @@ func (p *Server) runChat(
|
||||
result.PendingDynamicToolCalls = pendingDynamicCalls
|
||||
return result, nil
|
||||
}
|
||||
if loopErr != nil {
|
||||
classified := chaterror.Classify(loopErr).WithProvider(model.Provider())
|
||||
return result, chaterror.WithClassification(loopErr, classified)
|
||||
if err != nil {
|
||||
classified := chaterror.Classify(err).WithProvider(model.Provider())
|
||||
return result, chaterror.WithClassification(err, classified)
|
||||
}
|
||||
result.FinalAssistantText = finalAssistantText
|
||||
return result, nil
|
||||
@@ -5881,15 +5446,10 @@ func (p *Server) persistChatContextSummary(
|
||||
func (p *Server) resolveChatModel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
) (
|
||||
model fantasy.LanguageModel,
|
||||
dbConfig database.ChatModelConfig,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
debugEnabled bool,
|
||||
resolvedProvider string,
|
||||
resolvedModel string,
|
||||
err error,
|
||||
) {
|
||||
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
|
||||
var dbConfig database.ChatModelConfig
|
||||
var keys chatprovider.ProviderAPIKeys
|
||||
|
||||
var g errgroup.Group
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
@@ -5908,34 +5468,19 @@ func (p *Server) resolveChatModel(
|
||||
return nil
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", err
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err
|
||||
}
|
||||
|
||||
resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint(
|
||||
dbConfig.Model,
|
||||
dbConfig.Provider,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
|
||||
"resolve model metadata: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
model, debugEnabled, err = p.newDebugAwareModelFromConfig(
|
||||
ctx,
|
||||
chat,
|
||||
dbConfig.Provider,
|
||||
dbConfig.Model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
dbConfig.Provider, dbConfig.Model, keys, chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
||||
"create model: %w", err,
|
||||
)
|
||||
}
|
||||
return model, dbConfig, keys, debugEnabled, resolvedProvider, resolvedModel, nil
|
||||
return model, dbConfig, keys, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveUserProviderAPIKeys(
|
||||
@@ -6617,14 +6162,9 @@ func (p *Server) maybeSendPushNotification(
|
||||
pushCtx,
|
||||
chat,
|
||||
assistantText,
|
||||
runResult.FallbackProvider,
|
||||
runResult.FallbackModel,
|
||||
runResult.PushSummaryModel,
|
||||
runResult.ProviderKeys,
|
||||
logger,
|
||||
p.debugSvc,
|
||||
runResult.TriggerMessageID,
|
||||
runResult.HistoryTipMessageID,
|
||||
); summary != "" {
|
||||
pushBody = summary
|
||||
}
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
)
|
||||
|
||||
func (p *Server) newDebugAwareModelFromConfig(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
providerHint string,
|
||||
modelName string,
|
||||
providerKeys chatprovider.ProviderAPIKeys,
|
||||
userAgent string,
|
||||
extraHeaders map[string]string,
|
||||
) (fantasy.LanguageModel, bool, error) {
|
||||
provider, resolvedModel, err := chatprovider.ResolveModelWithProviderHint(modelName, providerHint)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
debugEnabled := p.debugSvc != nil && p.debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
|
||||
|
||||
var httpClient *http.Client
|
||||
if debugEnabled {
|
||||
httpClient = &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
provider,
|
||||
resolvedModel,
|
||||
providerKeys,
|
||||
userAgent,
|
||||
extraHeaders,
|
||||
httpClient,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, debugEnabled, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, debugEnabled, xerrors.Errorf(
|
||||
"create model for %s/%s returned nil",
|
||||
provider,
|
||||
resolvedModel,
|
||||
)
|
||||
}
|
||||
if !debugEnabled {
|
||||
return model, false, nil
|
||||
}
|
||||
|
||||
return chatdebug.WrapModel(model, p.debugSvc, chatdebug.RecorderOptions{
|
||||
ChatID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
Provider: provider,
|
||||
Model: resolvedModel,
|
||||
}), true, nil
|
||||
}
|
||||
@@ -33,15 +33,6 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// TestWaitForActiveChatStop and TestWaitForActiveChatStop_WaitsForReplacementRun
|
||||
// were removed along with the process-local activeChats mechanism.
|
||||
// Debug cleanup is now best-effort; stale finalization handles orphaned rows.
|
||||
|
||||
// TestArchiveChatWaitsForActiveChatStop and
|
||||
// TestArchiveChatWaitsForEveryInterruptedChat were removed along with
|
||||
// the process-local activeChats mechanism. Archive cleanup is now
|
||||
// best-effort; stale finalization handles any orphaned rows.
|
||||
|
||||
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type (
|
||||
runContextKey struct{}
|
||||
stepContextKey struct{}
|
||||
reuseStepKey struct{}
|
||||
reuseHolder struct {
|
||||
mu sync.Mutex
|
||||
handle *stepHandle
|
||||
}
|
||||
)
|
||||
|
||||
// ContextWithRun stores rc in ctx.
|
||||
//
|
||||
// Step counter cleanup is reference-counted per RunID: each live
|
||||
// RunContext increments a counter and runtime.AddCleanup decrements
|
||||
// it when the struct is garbage collected. Shared state (step
|
||||
// counters) is only deleted when the last RunContext for a given
|
||||
// RunID becomes unreachable, preventing premature cleanup when
|
||||
// multiple RunContext instances share the same RunID.
|
||||
func ContextWithRun(ctx context.Context, rc *RunContext) context.Context {
|
||||
if rc == nil {
|
||||
panic("chatdebug: nil RunContext")
|
||||
}
|
||||
|
||||
enriched := context.WithValue(ctx, runContextKey{}, rc)
|
||||
if rc.RunID != uuid.Nil {
|
||||
trackRunRef(rc.RunID)
|
||||
runtime.AddCleanup(rc, func(id uuid.UUID) {
|
||||
releaseRunRef(id)
|
||||
}, rc.RunID)
|
||||
}
|
||||
return enriched
|
||||
}
|
||||
|
||||
// RunFromContext returns the debug run context stored in ctx.
|
||||
func RunFromContext(ctx context.Context) (*RunContext, bool) {
|
||||
rc, ok := ctx.Value(runContextKey{}).(*RunContext)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return rc, true
|
||||
}
|
||||
|
||||
// ContextWithStep stores sc in ctx.
|
||||
func ContextWithStep(ctx context.Context, sc *StepContext) context.Context {
|
||||
if sc == nil {
|
||||
panic("chatdebug: nil StepContext")
|
||||
}
|
||||
return context.WithValue(ctx, stepContextKey{}, sc)
|
||||
}
|
||||
|
||||
// StepFromContext returns the debug step context stored in ctx.
|
||||
func StepFromContext(ctx context.Context) (*StepContext, bool) {
|
||||
sc, ok := ctx.Value(stepContextKey{}).(*StepContext)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return sc, true
|
||||
}
|
||||
|
||||
// ReuseStep marks ctx so wrapped model calls under it share one debug step.
|
||||
func ReuseStep(ctx context.Context) context.Context {
|
||||
if holder, ok := reuseHolderFromContext(ctx); ok {
|
||||
return context.WithValue(ctx, reuseStepKey{}, holder)
|
||||
}
|
||||
return context.WithValue(ctx, reuseStepKey{}, &reuseHolder{})
|
||||
}
|
||||
|
||||
func reuseHolderFromContext(ctx context.Context) (*reuseHolder, bool) {
|
||||
holder, ok := ctx.Value(reuseStepKey{}).(*reuseHolder)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return holder, true
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestReuseStep_PreservesExistingHolder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := ReuseStep(context.Background())
|
||||
first, ok := reuseHolderFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
|
||||
reused := ReuseStep(ctx)
|
||||
second, ok := reuseHolderFromContext(reused)
|
||||
require.True(t, ok)
|
||||
require.Same(t, first, second)
|
||||
}
|
||||
|
||||
func TestContextWithRun_CleansUpStepCounterAfterGC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
func() {
|
||||
_ = ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
require.Equal(t, int32(1), nextStepNumber(runID))
|
||||
_, ok := stepCounters.Load(runID)
|
||||
require.True(t, ok)
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
runtime.GC()
|
||||
runtime.Gosched()
|
||||
_, ok := stepCounters.Load(runID)
|
||||
return !ok
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestContextWithRun_MultipleInstancesSameRunID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
// rc2 is the surviving instance that should keep the step counter alive.
|
||||
rc2 := &RunContext{RunID: runID, ChatID: chatID}
|
||||
ctx2 := ContextWithRun(context.Background(), rc2)
|
||||
|
||||
// Create a second RunContext with the same RunID and let it become
|
||||
// unreachable. Its GC cleanup must NOT delete the step counter
|
||||
// because rc2 is still alive.
|
||||
func() {
|
||||
rc1 := &RunContext{RunID: runID, ChatID: chatID}
|
||||
ctx1 := ContextWithRun(context.Background(), rc1)
|
||||
h, _ := beginStep(ctx1, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
|
||||
require.NotNil(t, h)
|
||||
require.Equal(t, int32(1), h.stepCtx.StepNumber)
|
||||
}()
|
||||
|
||||
// Force GC to collect rc1.
|
||||
for range 5 {
|
||||
runtime.GC()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// The step counter must still be present because rc2 is alive.
|
||||
_, ok := stepCounters.Load(runID)
|
||||
require.True(t, ok, "step counter was prematurely cleaned up while another RunContext is still alive")
|
||||
|
||||
// Subsequent steps on the surviving context must continue numbering.
|
||||
h2, _ := beginStep(ctx2, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
|
||||
require.NotNil(t, h2)
|
||||
require.Equal(t, int32(2), h2.stepCtx.StepNumber)
|
||||
}
|
||||
|
||||
func TestContextWithRun_CleansUpStepCounterOnGCAfterCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
// Run in a closure so the RunContext becomes unreachable after
|
||||
// context cancellation, allowing GC to trigger the cleanup.
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ContextWithRun(ctx, &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
require.Equal(t, int32(1), nextStepNumber(runID))
|
||||
|
||||
_, ok := stepCounters.Load(runID)
|
||||
require.True(t, ok)
|
||||
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// After the closure, the RunContext is unreachable.
|
||||
// runtime.AddCleanup fires during GC.
|
||||
require.Eventually(t, func() bool {
|
||||
runtime.GC()
|
||||
runtime.Gosched()
|
||||
_, ok := stepCounters.Load(runID)
|
||||
return !ok
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
require.Equal(t, int32(1), nextStepNumber(runID))
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
)
|
||||
|
||||
func TestContextWithRunRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rc := &chatdebug.RunContext{
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
RootChatID: uuid.New(),
|
||||
ParentChatID: uuid.New(),
|
||||
ModelConfigID: uuid.New(),
|
||||
TriggerMessageID: 11,
|
||||
HistoryTipMessageID: 22,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-sonnet",
|
||||
}
|
||||
|
||||
ctx := chatdebug.ContextWithRun(context.Background(), rc)
|
||||
got, ok := chatdebug.RunFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Same(t, rc, got)
|
||||
require.Equal(t, *rc, *got)
|
||||
}
|
||||
|
||||
func TestRunFromContextAbsent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, ok := chatdebug.RunFromContext(context.Background())
|
||||
require.False(t, ok)
|
||||
require.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestContextWithStepRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sc := &chatdebug.StepContext{
|
||||
StepID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 7,
|
||||
Operation: chatdebug.OperationStream,
|
||||
HistoryTipMessageID: 33,
|
||||
}
|
||||
|
||||
ctx := chatdebug.ContextWithStep(context.Background(), sc)
|
||||
got, ok := chatdebug.StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Same(t, sc, got)
|
||||
require.Equal(t, *sc, *got)
|
||||
}
|
||||
|
||||
func TestStepFromContextAbsent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, ok := chatdebug.StepFromContext(context.Background())
|
||||
require.False(t, ok)
|
||||
require.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestContextWithRunAndStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rc := &chatdebug.RunContext{RunID: uuid.New(), ChatID: uuid.New()}
|
||||
sc := &chatdebug.StepContext{StepID: uuid.New(), RunID: rc.RunID, ChatID: rc.ChatID}
|
||||
|
||||
ctx := chatdebug.ContextWithStep(
|
||||
chatdebug.ContextWithRun(context.Background(), rc),
|
||||
sc,
|
||||
)
|
||||
|
||||
gotRun, ok := chatdebug.RunFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Same(t, rc, gotRun)
|
||||
|
||||
gotStep, ok := chatdebug.StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Same(t, sc, gotStep)
|
||||
}
|
||||
|
||||
func TestContextWithRunPanicsOnNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Panics(t, func() {
|
||||
_ = chatdebug.ContextWithRun(context.Background(), nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextWithStepPanicsOnNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Panics(t, func() {
|
||||
_ = chatdebug.ContextWithStep(context.Background(), nil)
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,331 +0,0 @@
|
||||
package chatdebug //nolint:testpackage // Checks unexported normalized structs against fantasy source types.
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fieldDisposition documents whether a fantasy struct field is captured
|
||||
// by the corresponding normalized struct ("normalized") or
|
||||
// intentionally omitted ("skipped: <reason>"). The test fails when a
|
||||
// fantasy type gains a field that is not yet classified, forcing the
|
||||
// developer to decide whether to normalize or skip it.
|
||||
//
|
||||
// This mirrors the audit-table exhaustiveness check in
|
||||
// enterprise/audit/table.go — same idea, different domain.
|
||||
type fieldDisposition = map[string]string
|
||||
|
||||
// TestNormalizationFieldCoverage ensures every exported field on the
|
||||
// fantasy types that model.go normalizes is explicitly accounted for.
|
||||
// When the fantasy library adds a field the test fails, surfacing the
|
||||
// drift at `go test` time rather than silently dropping data.
|
||||
func TestNormalizationFieldCoverage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
typ reflect.Type
|
||||
fields fieldDisposition
|
||||
}{
|
||||
// ── struct-to-struct mappings ──────────────────────────
|
||||
|
||||
{
|
||||
name: "fantasy.Usage → normalizedUsage",
|
||||
typ: reflect.TypeFor[fantasy.Usage](),
|
||||
fields: fieldDisposition{
|
||||
"InputTokens": "normalized",
|
||||
"OutputTokens": "normalized",
|
||||
"TotalTokens": "normalized",
|
||||
"ReasoningTokens": "normalized",
|
||||
"CacheCreationTokens": "normalized",
|
||||
"CacheReadTokens": "normalized",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.Call → normalizedCallPayload",
|
||||
typ: reflect.TypeFor[fantasy.Call](),
|
||||
fields: fieldDisposition{
|
||||
"Prompt": "normalized",
|
||||
"MaxOutputTokens": "normalized",
|
||||
"Temperature": "normalized",
|
||||
"TopP": "normalized",
|
||||
"TopK": "normalized",
|
||||
"PresencePenalty": "normalized",
|
||||
"FrequencyPenalty": "normalized",
|
||||
"Tools": "normalized",
|
||||
"ToolChoice": "normalized",
|
||||
"UserAgent": "skipped: internal transport header, not useful for debug panel",
|
||||
"ProviderOptions": "skipped: opaque provider data, only count preserved",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ObjectCall → normalizedObjectCallPayload",
|
||||
typ: reflect.TypeFor[fantasy.ObjectCall](),
|
||||
fields: fieldDisposition{
|
||||
"Prompt": "normalized",
|
||||
"Schema": "skipped: full schema too large; SchemaName+SchemaDescription captured instead",
|
||||
"SchemaName": "normalized",
|
||||
"SchemaDescription": "normalized",
|
||||
"MaxOutputTokens": "normalized",
|
||||
"Temperature": "normalized",
|
||||
"TopP": "normalized",
|
||||
"TopK": "normalized",
|
||||
"PresencePenalty": "normalized",
|
||||
"FrequencyPenalty": "normalized",
|
||||
"UserAgent": "skipped: internal transport header, not useful for debug panel",
|
||||
"ProviderOptions": "skipped: opaque provider data, only count preserved",
|
||||
"RepairText": "skipped: function value, not serializable",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.Response → normalizedResponsePayload",
|
||||
typ: reflect.TypeFor[fantasy.Response](),
|
||||
fields: fieldDisposition{
|
||||
"Content": "normalized",
|
||||
"FinishReason": "normalized",
|
||||
"Usage": "normalized",
|
||||
"Warnings": "normalized",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ObjectResponse → normalizedObjectResponsePayload",
|
||||
typ: reflect.TypeFor[fantasy.ObjectResponse](),
|
||||
fields: fieldDisposition{
|
||||
"Object": "skipped: arbitrary user type, not serializable generically",
|
||||
"RawText": "normalized: as RawTextLength (length only, content unbounded)",
|
||||
"Usage": "normalized",
|
||||
"FinishReason": "normalized",
|
||||
"Warnings": "normalized",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.CallWarning → normalizedWarning",
|
||||
typ: reflect.TypeFor[fantasy.CallWarning](),
|
||||
fields: fieldDisposition{
|
||||
"Type": "normalized",
|
||||
"Setting": "normalized",
|
||||
"Tool": "skipped: interface value, warning message+type sufficient for debug panel",
|
||||
"Details": "normalized",
|
||||
"Message": "normalized",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.StreamPart → appendNormalizedStreamContent",
|
||||
typ: reflect.TypeFor[fantasy.StreamPart](),
|
||||
fields: fieldDisposition{
|
||||
"Type": "normalized",
|
||||
"ID": "normalized: as ToolCallID in content parts",
|
||||
"ToolCallName": "normalized: as ToolName in content parts",
|
||||
"ToolCallInput": "normalized: as Arguments or Result (bounded)",
|
||||
"Delta": "normalized: accumulated into text/reasoning content parts",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"Usage": "normalized: captured in stream finalize",
|
||||
"FinishReason": "normalized: captured in stream finalize",
|
||||
"Error": "normalized: captured in stream error handling",
|
||||
"Warnings": "normalized: captured in stream warning accumulation",
|
||||
"SourceType": "normalized",
|
||||
"URL": "normalized",
|
||||
"Title": "normalized",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ObjectStreamPart → wrapObjectStreamSeq",
|
||||
typ: reflect.TypeFor[fantasy.ObjectStreamPart](),
|
||||
fields: fieldDisposition{
|
||||
"Type": "normalized: drives switch in wrapObjectStreamSeq",
|
||||
"Object": "skipped: arbitrary user type, only ObjectPartCount tracked",
|
||||
"Delta": "normalized: accumulated into rawTextLength",
|
||||
"Error": "normalized: captured in stream error handling",
|
||||
"Usage": "normalized: captured in stream finalize",
|
||||
"FinishReason": "normalized: captured in stream finalize",
|
||||
"Warnings": "normalized: captured in stream warning accumulation",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
|
||||
// ── message part types (normalizeMessageParts) ────────
|
||||
|
||||
{
|
||||
name: "fantasy.TextPart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.TextPart](),
|
||||
fields: fieldDisposition{
|
||||
"Text": "normalized: bounded to MaxMessagePartTextLength",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ReasoningPart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.ReasoningPart](),
|
||||
fields: fieldDisposition{
|
||||
"Text": "normalized: bounded to MaxMessagePartTextLength",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.FilePart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.FilePart](),
|
||||
fields: fieldDisposition{
|
||||
"Filename": "normalized",
|
||||
"Data": "skipped: binary data never stored in debug records",
|
||||
"MediaType": "normalized",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ToolCallPart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.ToolCallPart](),
|
||||
fields: fieldDisposition{
|
||||
"ToolCallID": "normalized",
|
||||
"ToolName": "normalized",
|
||||
"Input": "normalized: as Arguments (bounded)",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ToolResultPart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.ToolResultPart](),
|
||||
fields: fieldDisposition{
|
||||
"ToolCallID": "normalized",
|
||||
"Output": "normalized: text extracted via normalizeToolResultOutput",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
|
||||
// ── response content types (normalizeContentParts) ────
|
||||
|
||||
{
|
||||
name: "fantasy.TextContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.TextContent](),
|
||||
fields: fieldDisposition{
|
||||
"Text": "normalized: bounded to MaxMessagePartTextLength",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ReasoningContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.ReasoningContent](),
|
||||
fields: fieldDisposition{
|
||||
"Text": "normalized: bounded to MaxMessagePartTextLength",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.FileContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.FileContent](),
|
||||
fields: fieldDisposition{
|
||||
"MediaType": "normalized",
|
||||
"Data": "skipped: binary data never stored in debug records",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.SourceContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.SourceContent](),
|
||||
fields: fieldDisposition{
|
||||
"SourceType": "normalized",
|
||||
"ID": "skipped: provider-internal identifier, not actionable in debug panel",
|
||||
"URL": "normalized",
|
||||
"Title": "normalized",
|
||||
"MediaType": "skipped: only relevant for document sources, rarely useful for debugging",
|
||||
"Filename": "skipped: only relevant for document sources, rarely useful for debugging",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ToolCallContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.ToolCallContent](),
|
||||
fields: fieldDisposition{
|
||||
"ToolCallID": "normalized",
|
||||
"ToolName": "normalized",
|
||||
"Input": "normalized: as Arguments (bounded), InputLength tracks original",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
"Invalid": "skipped: validation state not surfaced in debug panel",
|
||||
"ValidationError": "skipped: validation state not surfaced in debug panel",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ToolResultContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.ToolResultContent](),
|
||||
fields: fieldDisposition{
|
||||
"ToolCallID": "normalized",
|
||||
"ToolName": "normalized",
|
||||
"Result": "normalized: text extracted via normalizeToolResultOutput",
|
||||
"ClientMetadata": "skipped: client execution metadata not needed for debug panel",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
|
||||
// ── tool types (normalizeTools) ───────────────────────
|
||||
|
||||
{
|
||||
name: "fantasy.FunctionTool → normalizedTool",
|
||||
typ: reflect.TypeFor[fantasy.FunctionTool](),
|
||||
fields: fieldDisposition{
|
||||
"Name": "normalized",
|
||||
"Description": "normalized",
|
||||
"InputSchema": "normalized: preserved as JSON for debug panel rendering",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ProviderDefinedTool → normalizedTool",
|
||||
typ: reflect.TypeFor[fantasy.ProviderDefinedTool](),
|
||||
fields: fieldDisposition{
|
||||
"ID": "normalized",
|
||||
"Name": "normalized",
|
||||
"Args": "skipped: provider-specific configuration not needed for debug panel",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Every exported field on the fantasy type must be
|
||||
// registered as "normalized" or "skipped: <reason>".
|
||||
for i := range tt.typ.NumField() {
|
||||
field := tt.typ.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
disposition, ok := tt.fields[field.Name]
|
||||
if !ok {
|
||||
require.Failf(t, "unregistered field",
|
||||
"%s.%s is not in the coverage map — "+
|
||||
"add it as \"normalized\" or \"skipped: <reason>\"",
|
||||
tt.typ.Name(), field.Name)
|
||||
}
|
||||
require.NotEmptyf(t, disposition,
|
||||
"%s.%s has an empty disposition — "+
|
||||
"use \"normalized\" or \"skipped: <reason>\"",
|
||||
tt.typ.Name(), field.Name)
|
||||
}
|
||||
|
||||
// Catch stale entries that reference removed fields.
|
||||
for name := range tt.fields {
|
||||
found := false
|
||||
for i := range tt.typ.NumField() {
|
||||
if tt.typ.Field(i).Name == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Truef(t, found,
|
||||
"stale coverage entry %s.%s — "+
|
||||
"field no longer exists in fantasy, remove it",
|
||||
tt.typ.Name(), name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,987 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
type testError struct{ message string }
|
||||
|
||||
func (e *testError) Error() string { return e.message }
|
||||
|
||||
func expectDebugLoggingEnabled(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
ownerID uuid.UUID,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
db.EXPECT().GetChatDebugLoggingEnabled(gomock.Any()).Return(true, nil)
|
||||
db.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), ownerID).Return(true, nil)
|
||||
}
|
||||
|
||||
func expectCreateStepNumberWithRequestValidity(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
runID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
stepNumber int32,
|
||||
op Operation,
|
||||
normalizedRequestValid bool,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
stepID := uuid.New()
|
||||
db.EXPECT().
|
||||
InsertChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatDebugStepParams{})).
|
||||
DoAndReturn(func(_ context.Context, params database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
require.Equal(t, runID, params.RunID)
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
require.Equal(t, stepNumber, params.StepNumber)
|
||||
require.Equal(t, string(op), params.Operation)
|
||||
require.Equal(t, string(StatusInProgress), params.Status)
|
||||
require.Equal(t, normalizedRequestValid, params.NormalizedRequest.Valid)
|
||||
|
||||
return database.ChatDebugStep{
|
||||
ID: stepID,
|
||||
RunID: runID,
|
||||
ChatID: chatID,
|
||||
StepNumber: params.StepNumber,
|
||||
Operation: params.Operation,
|
||||
Status: params.Status,
|
||||
}, nil
|
||||
})
|
||||
|
||||
// CreateStep now touches the parent run's updated_at to prevent
|
||||
// premature stale finalization.
|
||||
db.EXPECT().
|
||||
UpdateChatDebugRun(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{})).
|
||||
DoAndReturn(func(_ context.Context, params database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
require.Equal(t, runID, params.ID)
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
return database.ChatDebugRun{ID: runID, ChatID: chatID}, nil
|
||||
})
|
||||
|
||||
return stepID
|
||||
}
|
||||
|
||||
func expectCreateStepNumber(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
runID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
stepNumber int32,
|
||||
op Operation,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
return expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
stepNumber,
|
||||
op,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
func expectCreateStep(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
runID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
op Operation,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
return expectCreateStepNumber(t, db, runID, chatID, 1, op)
|
||||
}
|
||||
|
||||
func expectUpdateStep(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
stepID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
status Status,
|
||||
assertFn func(database.UpdateChatDebugStepParams),
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
db.EXPECT().
|
||||
UpdateChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatDebugStepParams{})).
|
||||
DoAndReturn(func(_ context.Context, params database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
require.Equal(t, stepID, params.ID)
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
require.True(t, params.Status.Valid)
|
||||
require.Equal(t, string(status), params.Status.String)
|
||||
require.True(t, params.FinishedAt.Valid)
|
||||
|
||||
if assertFn != nil {
|
||||
assertFn(params)
|
||||
}
|
||||
|
||||
return database.ChatDebugStep{
|
||||
ID: stepID,
|
||||
ChatID: chatID,
|
||||
Status: params.Status.String,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestDebugModel_Provider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
|
||||
model := &debugModel{inner: inner}
|
||||
|
||||
require.Equal(t, inner.Provider(), model.Provider())
|
||||
}
|
||||
|
||||
func TestDebugModel_Model(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
|
||||
model := &debugModel{inner: inner}
|
||||
|
||||
require.Equal(t, inner.Model(), model.Model())
|
||||
}
|
||||
|
||||
func TestDebugModel_Disabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
respWant := &fantasy.Response{FinishReason: fantasy.FinishReasonStop}
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
_, ok := StepFromContext(ctx)
|
||||
require.False(t, ok)
|
||||
require.Nil(t, attemptSinkFromContext(ctx))
|
||||
return respWant, nil
|
||||
},
|
||||
}
|
||||
|
||||
model := &debugModel{
|
||||
inner: inner,
|
||||
svc: svc,
|
||||
opts: RecorderOptions{
|
||||
ChatID: chatID,
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := model.Generate(context.Background(), fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
require.Same(t, respWant, resp)
|
||||
}
|
||||
|
||||
func TestDebugModel_Generate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
call := fantasy.Call{
|
||||
Prompt: fantasy.Prompt{fantasy.NewUserMessage("hello")},
|
||||
MaxOutputTokens: int64Ptr(128),
|
||||
Temperature: float64Ptr(0.25),
|
||||
}
|
||||
respWant := &fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.TextContent{Text: "hello"},
|
||||
fantasy.ToolCallContent{ToolCallID: "tool-1", ToolName: "tool", Input: `{}`},
|
||||
fantasy.SourceContent{ID: "source-1", Title: "docs", URL: "https://example.com"},
|
||||
},
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{InputTokens: 10, OutputTokens: 4, TotalTokens: 14},
|
||||
Warnings: []fantasy.CallWarning{{Message: "warning"}},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
// Clean successes (no prior error) leave the error column
|
||||
// as SQL NULL rather than sending jsonClear.
|
||||
require.False(t, params.Error.Valid)
|
||||
require.False(t, params.Metadata.Valid)
|
||||
|
||||
// Verify actual JSON content so a broken tag or field
|
||||
// rename is caught rather than only checking .Valid.
|
||||
var usage fantasy.Usage
|
||||
require.NoError(t, json.Unmarshal(params.Usage.RawMessage, &usage))
|
||||
require.EqualValues(t, 10, usage.InputTokens)
|
||||
require.EqualValues(t, 4, usage.OutputTokens)
|
||||
require.EqualValues(t, 14, usage.TotalTokens)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(params.NormalizedResponse.RawMessage, &resp))
|
||||
require.Equal(t, "stop", resp["finish_reason"])
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateFn: func(ctx context.Context, got fantasy.Call) (*fantasy.Response, error) {
|
||||
require.Equal(t, call, got)
|
||||
stepCtx, ok := StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runID, stepCtx.RunID)
|
||||
require.Equal(t, chatID, stepCtx.ChatID)
|
||||
require.Equal(t, int32(1), stepCtx.StepNumber)
|
||||
require.Equal(t, OperationGenerate, stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
||||
require.NotNil(t, attemptSinkFromContext(ctx))
|
||||
return respWant, nil
|
||||
},
|
||||
}
|
||||
|
||||
model := &debugModel{
|
||||
inner: inner,
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.Generate(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, respWant, resp)
|
||||
}
|
||||
|
||||
func TestDebugModel_GeneratePersistsAttemptsWithoutResponseClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"message":"hello","api_key":"super-secret"}`,
|
||||
string(body))
|
||||
require.Equal(t, "Bearer top-secret", req.Header.Get("Authorization"))
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Header().Set("X-API-Key", "response-secret")
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
|
||||
var attempts []Attempt
|
||||
require.NoError(t, json.Unmarshal(params.Attempts.RawMessage, &attempts))
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Equal(t, http.StatusCreated, attempts[0].ResponseStatus)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
server.URL,
|
||||
strings.NewReader(`{"message":"hello","api_key":"super-secret"}`),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer top-secret")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
|
||||
require.NoError(t, resp.Body.Close())
|
||||
return &fantasy.Response{FinishReason: fantasy.FinishReasonStop}, nil
|
||||
},
|
||||
}
|
||||
|
||||
model := &debugModel{
|
||||
inner: inner,
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.Generate(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestDebugModel_GenerateError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
wantErr := &testError{message: "boom"}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.False(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.False(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
GenerateFn: func(context.Context, fantasy.Call) (*fantasy.Response, error) {
|
||||
return nil, wantErr
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.Generate(ctx, fantasy.Call{})
|
||||
require.Nil(t, resp)
|
||||
require.ErrorIs(t, err, wantErr)
|
||||
}
|
||||
|
||||
func TestStepStatusForError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Canceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusInterrupted, stepStatusForError(context.Canceled))
|
||||
})
|
||||
|
||||
t.Run("DeadlineExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusInterrupted, stepStatusForError(context.DeadlineExceeded))
|
||||
})
|
||||
|
||||
t.Run("OtherError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusError, stepStatusForError(xerrors.New("boom")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestDebugModel_Stream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
errPart := xerrors.New("chunk failed")
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hel"},
|
||||
{Type: fantasy.StreamPartTypeToolCall, ID: "tool-call-1", ToolCallName: "tool"},
|
||||
{Type: fantasy.StreamPartTypeSource, ID: "source-1", URL: "https://example.com", Title: "docs"},
|
||||
{Type: fantasy.StreamPartTypeWarnings, Warnings: []fantasy.CallWarning{{Message: "w1"}, {Message: "w2"}}},
|
||||
{Type: fantasy.StreamPartTypeError, Error: errPart},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 8, OutputTokens: 3, TotalTokens: 11}},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
stepCtx, ok := StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runID, stepCtx.RunID)
|
||||
require.Equal(t, chatID, stepCtx.ChatID)
|
||||
require.Equal(t, int32(1), stepCtx.StepNumber)
|
||||
require.Equal(t, OperationStream, stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
||||
require.NotNil(t, attemptSinkFromContext(ctx))
|
||||
return partsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
|
||||
got := make([]fantasy.StreamPart, 0, len(parts))
|
||||
for part := range seq {
|
||||
got = append(got, part)
|
||||
}
|
||||
|
||||
require.Equal(t, parts, got)
|
||||
}
|
||||
|
||||
func TestDebugModel_StreamObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.ObjectStreamPart{
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ob"},
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ject"},
|
||||
{Type: fantasy.ObjectStreamPartTypeObject, Object: map[string]any{"value": "object"}},
|
||||
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7}},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
// Clean successes (no prior error) leave the error column
|
||||
// as SQL NULL rather than sending jsonClear.
|
||||
require.False(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
stepCtx, ok := StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runID, stepCtx.RunID)
|
||||
require.Equal(t, chatID, stepCtx.ChatID)
|
||||
require.Equal(t, int32(1), stepCtx.StepNumber)
|
||||
require.Equal(t, OperationStream, stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
||||
require.NotNil(t, attemptSinkFromContext(ctx))
|
||||
return objectPartsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
|
||||
require.NoError(t, err)
|
||||
|
||||
got := make([]fantasy.ObjectStreamPart, 0, len(parts))
|
||||
for part := range seq {
|
||||
got = append(got, part)
|
||||
}
|
||||
|
||||
require.Equal(t, parts, got)
|
||||
}
|
||||
|
||||
// TestDebugModel_StreamCompletedAfterFinish verifies that when a consumer
|
||||
// stops iteration after receiving a finish part, the step is marked as
|
||||
// completed rather than interrupted.
|
||||
func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
|
||||
}
|
||||
|
||||
// The mock expectation for UpdateStep with StatusCompleted is the
|
||||
// assertion: if the wrapper chose StatusInterrupted instead, the
|
||||
// mock would reject the call.
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, nil)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return partsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer reads the finish part then breaks — this should still
|
||||
// be considered a completed stream, not interrupted.
|
||||
for part := range seq {
|
||||
if part.Type == fantasy.StreamPartTypeFinish {
|
||||
break
|
||||
}
|
||||
}
|
||||
// gomock verifies UpdateStep was called with StatusCompleted.
|
||||
}
|
||||
|
||||
// TestDebugModel_StreamInterruptedBeforeFinish verifies that when a consumer
|
||||
// stops iteration before receiving a finish part, the step is marked as
|
||||
// interrupted.
|
||||
func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: " world"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}
|
||||
|
||||
// The mock expectation for UpdateStep with StatusInterrupted is the
|
||||
// assertion: breaking before the finish part means interrupted.
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, nil)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return partsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer reads the first delta then breaks before finish.
|
||||
count := 0
|
||||
for range seq {
|
||||
count++
|
||||
if count == 1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, count)
|
||||
// gomock verifies UpdateStep was called with StatusInterrupted.
|
||||
}
|
||||
|
||||
func TestDebugModel_StreamRejectsNilSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.False(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.False(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
var nilStream fantasy.StreamResponse
|
||||
return nilStream, nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.Nil(t, seq)
|
||||
require.ErrorIs(t, err, ErrNilModelResult)
|
||||
}
|
||||
|
||||
func TestDebugModel_StreamObjectRejectsNilSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.False(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamObjectFn: func(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
var nilStream fantasy.ObjectStreamResponse
|
||||
return nilStream, nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
|
||||
require.Nil(t, seq)
|
||||
require.ErrorIs(t, err, ErrNilModelResult)
|
||||
}
|
||||
|
||||
func TestDebugModel_StreamEarlyStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "first"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "second"},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.False(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return partsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
for part := range seq {
|
||||
require.Equal(t, parts[0], part)
|
||||
count++
|
||||
break
|
||||
}
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestStreamErrorStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("CancellationBecomesInterrupted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.Canceled))
|
||||
})
|
||||
|
||||
t.Run("DeadlineExceededBecomesInterrupted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.DeadlineExceeded))
|
||||
})
|
||||
|
||||
t.Run("NilErrorBecomesError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusError, streamErrorStatus(StatusCompleted, nil))
|
||||
})
|
||||
|
||||
t.Run("ExistingErrorWins", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusError, streamErrorStatus(StatusError, context.Canceled))
|
||||
})
|
||||
}
|
||||
|
||||
func objectPartsToSeq(parts []fantasy.ObjectStreamPart) fantasy.ObjectStreamResponse {
|
||||
return func(yield func(fantasy.ObjectStreamPart) bool) {
|
||||
for _, part := range parts {
|
||||
if !yield(part) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func partsToSeq(parts []fantasy.StreamPart) fantasy.StreamResponse {
|
||||
return func(yield func(fantasy.StreamPart) bool) {
|
||||
for _, part := range parts {
|
||||
if !yield(part) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDebugModel_GenerateObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
call := fantasy.ObjectCall{
|
||||
Prompt: fantasy.Prompt{fantasy.NewUserMessage("summarize")},
|
||||
SchemaName: "Summary",
|
||||
MaxOutputTokens: int64Ptr(256),
|
||||
}
|
||||
respWant := &fantasy.ObjectResponse{
|
||||
RawText: `{"title":"test"}`,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
|
||||
}
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(ctx context.Context, got fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
require.Equal(t, call, got)
|
||||
stepCtx, ok := StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runID, stepCtx.RunID)
|
||||
require.Equal(t, chatID, stepCtx.ChatID)
|
||||
require.Equal(t, OperationGenerate, stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
||||
require.NotNil(t, attemptSinkFromContext(ctx))
|
||||
return respWant, nil
|
||||
},
|
||||
}
|
||||
|
||||
model := &debugModel{
|
||||
inner: inner,
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.GenerateObject(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, respWant, resp)
|
||||
}
|
||||
|
||||
func TestDebugModel_GenerateObjectError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
runID := uuid.New()
|
||||
wantErr := &testError{message: "object boom"}
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return nil, wantErr
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
|
||||
require.Nil(t, resp)
|
||||
require.ErrorIs(t, err, wantErr)
|
||||
}
|
||||
|
||||
func TestDebugModel_GenerateObjectRejectsNilResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return nil, nil //nolint:nilnil // Intentionally testing nil response handling.
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
|
||||
require.Nil(t, resp)
|
||||
require.ErrorIs(t, err, ErrNilModelResult)
|
||||
}
|
||||
|
||||
func TestWrapStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
|
||||
// Create a context that we cancel after the stream finishes.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
|
||||
}
|
||||
seq := wrapStreamSeq(ctx, handle, partsToSeq(parts))
|
||||
|
||||
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
|
||||
for range seq {
|
||||
}
|
||||
|
||||
// Cancel the context after the stream has been fully consumed
|
||||
// and finalized. The status should remain completed.
|
||||
cancel()
|
||||
|
||||
handle.mu.Lock()
|
||||
status := handle.status
|
||||
handle.mu.Unlock()
|
||||
require.Equal(t, StatusCompleted, status)
|
||||
}
|
||||
|
||||
func TestWrapObjectStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
parts := []fantasy.ObjectStreamPart{
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "obj"},
|
||||
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 3, OutputTokens: 1, TotalTokens: 4}},
|
||||
}
|
||||
seq := wrapObjectStreamSeq(ctx, handle, objectPartsToSeq(parts))
|
||||
|
||||
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
|
||||
for range seq {
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
handle.mu.Lock()
|
||||
status := handle.status
|
||||
handle.mu.Unlock()
|
||||
require.Equal(t, StatusCompleted, status)
|
||||
}
|
||||
|
||||
func TestWrapStreamSeq_DroppedStreamFinalizedOnCtxCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}
|
||||
|
||||
// Create the wrapped stream but never iterate it.
|
||||
_ = wrapStreamSeq(ctx, handle, partsToSeq(parts))
|
||||
|
||||
// Cancel the context — the AfterFunc safety net should finalize
|
||||
// the step as interrupted.
|
||||
cancel()
|
||||
|
||||
// AfterFunc fires asynchronously; give it a moment.
|
||||
require.Eventually(t, func() bool {
|
||||
handle.mu.Lock()
|
||||
defer handle.mu.Unlock()
|
||||
return handle.status == StatusInterrupted
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func int64Ptr(v int64) *int64 { return &v }
|
||||
|
||||
func float64Ptr(v float64) *float64 { return &v }
|
||||
@@ -1,379 +0,0 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported normalization helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func TestNormalizeCall_PreservesToolSchemasAndMessageToolPayloads(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := normalizeCall(fantasy.Call{
|
||||
Prompt: fantasy.Prompt{
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.ToolCallPart{
|
||||
ToolCallID: "call-search",
|
||||
ToolName: "search_docs",
|
||||
Input: `{"query":"debug panel"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.ToolResultPart{
|
||||
ToolCallID: "call-search",
|
||||
Output: fantasy.ToolResultOutputContentText{
|
||||
Text: `{"matches":["model.go","DebugStepCard.tsx"]}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Tools: []fantasy.Tool{
|
||||
fantasy.FunctionTool{
|
||||
Name: "search_docs",
|
||||
Description: "Searches documentation.",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, payload.Tools, 1)
|
||||
require.True(t, payload.Tools[0].HasInputSchema)
|
||||
require.JSONEq(t, `{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`,
|
||||
string(payload.Tools[0].InputSchema))
|
||||
|
||||
require.Len(t, payload.Messages, 2)
|
||||
require.Equal(t, "tool-call", payload.Messages[0].Parts[0].Type)
|
||||
require.Equal(t, `{"query":"debug panel"}`, payload.Messages[0].Parts[0].Arguments)
|
||||
require.Equal(t, "tool-result", payload.Messages[1].Parts[0].Type)
|
||||
require.Equal(t,
|
||||
`{"matches":["model.go","DebugStepCard.tsx"]}`,
|
||||
payload.Messages[1].Parts[0].Result,
|
||||
)
|
||||
}
|
||||
|
||||
func TestNormalizers_SkipTypedNilInterfaceValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("MessageParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilPart *fantasy.TextPart
|
||||
parts := normalizeMessageParts([]fantasy.MessagePart{
|
||||
nilPart,
|
||||
fantasy.TextPart{Text: "hello"},
|
||||
})
|
||||
require.Len(t, parts, 1)
|
||||
require.Equal(t, "text", parts[0].Type)
|
||||
require.Equal(t, "hello", parts[0].Text)
|
||||
})
|
||||
|
||||
t.Run("Tools", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilTool *fantasy.FunctionTool
|
||||
tools := normalizeTools([]fantasy.Tool{
|
||||
nilTool,
|
||||
fantasy.FunctionTool{Name: "search_docs"},
|
||||
})
|
||||
require.Len(t, tools, 1)
|
||||
require.Equal(t, "function", tools[0].Type)
|
||||
require.Equal(t, "search_docs", tools[0].Name)
|
||||
})
|
||||
|
||||
t.Run("ContentParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilContent *fantasy.TextContent
|
||||
content := normalizeContentParts(fantasy.ResponseContent{
|
||||
nilContent,
|
||||
fantasy.TextContent{Text: "hello"},
|
||||
})
|
||||
require.Len(t, content, 1)
|
||||
require.Equal(t, "text", content[0].Type)
|
||||
require.Equal(t, "hello", content[0].Text)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppendNormalizedStreamContent_PreservesOrderAndCanonicalTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var content []normalizedContentPart
|
||||
streamDebugBytes := 0
|
||||
for _, part := range []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "before "},
|
||||
{Type: fantasy.StreamPartTypeToolCall, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{"query":"debug"}`},
|
||||
{Type: fantasy.StreamPartTypeToolResult, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{"matches":1}`},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "after"},
|
||||
} {
|
||||
content = appendNormalizedStreamContent(content, part, &streamDebugBytes)
|
||||
}
|
||||
|
||||
require.Equal(t, []normalizedContentPart{
|
||||
{Type: "text", Text: "before "},
|
||||
{Type: "tool-call", ToolCallID: "call-1", ToolName: "search_docs", Arguments: `{"query":"debug"}`, InputLength: len(`{"query":"debug"}`)},
|
||||
{Type: "tool-result", ToolCallID: "call-1", ToolName: "search_docs", Result: `{"matches":1}`},
|
||||
{Type: "text", Text: "after"},
|
||||
}, content)
|
||||
}
|
||||
|
||||
func TestAppendNormalizedStreamContent_GlobalTextCap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
streamDebugBytes := 0
|
||||
long := strings.Repeat("a", maxStreamDebugTextBytes)
|
||||
var content []normalizedContentPart
|
||||
for _, part := range []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: long},
|
||||
{Type: fantasy.StreamPartTypeToolCall, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{}`},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "tail"},
|
||||
} {
|
||||
content = appendNormalizedStreamContent(content, part, &streamDebugBytes)
|
||||
}
|
||||
|
||||
require.Len(t, content, 2)
|
||||
require.Equal(t, strings.Repeat("a", maxStreamDebugTextBytes), content[0].Text)
|
||||
require.Equal(t, "tool-call", content[1].Type)
|
||||
require.Equal(t, maxStreamDebugTextBytes, streamDebugBytes)
|
||||
}
|
||||
|
||||
func TestWrapStreamSeq_SourceCountExcludesToolResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
seq := wrapStreamSeq(context.Background(), handle, partsToSeq([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeToolResult, ID: "tool-1", ToolCallName: "search_docs"},
|
||||
{Type: fantasy.StreamPartTypeSource, ID: "source-1", URL: "https://example.com", Title: "docs"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}))
|
||||
|
||||
partCount := 0
|
||||
for range seq {
|
||||
partCount++
|
||||
}
|
||||
require.Equal(t, 3, partCount)
|
||||
|
||||
metadata, ok := handle.metadata.(map[string]any)
|
||||
require.True(t, ok)
|
||||
summary, ok := metadata["stream_summary"].(streamSummary)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 1, summary.SourceCount)
|
||||
}
|
||||
|
||||
func TestWrapObjectStreamSeq_UsesStructuredOutputPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
usage := fantasy.Usage{InputTokens: 3, OutputTokens: 2, TotalTokens: 5}
|
||||
seq := wrapObjectStreamSeq(context.Background(), handle, objectPartsToSeq([]fantasy.ObjectStreamPart{
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ob"},
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ject"},
|
||||
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: usage},
|
||||
}))
|
||||
|
||||
partCount := 0
|
||||
for range seq {
|
||||
partCount++
|
||||
}
|
||||
require.Equal(t, 3, partCount)
|
||||
|
||||
resp, ok := handle.response.(normalizedObjectResponsePayload)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, normalizedObjectResponsePayload{
|
||||
RawTextLength: len("object"),
|
||||
FinishReason: string(fantasy.FinishReasonStop),
|
||||
Usage: normalizeUsage(usage),
|
||||
StructuredOutput: true,
|
||||
}, resp)
|
||||
}
|
||||
|
||||
func TestNormalizeResponse_UsesCanonicalToolTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := normalizeResponse(&fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.ToolCallContent{
|
||||
ToolCallID: "call-calc",
|
||||
ToolName: "calculator",
|
||||
Input: `{"operation":"add","operands":[2,2]}`,
|
||||
},
|
||||
fantasy.ToolResultContent{
|
||||
ToolCallID: "call-calc",
|
||||
ToolName: "calculator",
|
||||
Result: fantasy.ToolResultOutputContentText{Text: `{"sum":4}`},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, payload.Content, 2)
|
||||
require.Equal(t, "tool-call", payload.Content[0].Type)
|
||||
require.Equal(t, "tool-result", payload.Content[1].Type)
|
||||
}
|
||||
|
||||
func TestBoundText_RespectsDocumentedRuneLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runes := make([]rune, MaxMessagePartTextLength+5)
|
||||
for i := range runes {
|
||||
runes[i] = 'a'
|
||||
}
|
||||
input := string(runes)
|
||||
got := boundText(input)
|
||||
require.Equal(t, MaxMessagePartTextLength, len([]rune(got)))
|
||||
require.Equal(t, '…', []rune(got)[len([]rune(got))-1])
|
||||
}
|
||||
|
||||
func TestNormalizeToolResultOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("TextValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentText{Text: "hello"})
|
||||
require.Equal(t, "hello", got)
|
||||
})
|
||||
|
||||
t.Run("TextPointer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentText{Text: "hello"})
|
||||
require.Equal(t, "hello", got)
|
||||
})
|
||||
|
||||
t.Run("TextPointerNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentText)(nil))
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New("tool failed"),
|
||||
})
|
||||
require.Equal(t, "tool failed", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorValueNilError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentError{Error: nil})
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorPointer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New("ptr fail"),
|
||||
})
|
||||
require.Equal(t, "ptr fail", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorPointerNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentError)(nil))
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorPointerNilError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentError{Error: nil})
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("MediaWithText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{
|
||||
Text: "caption",
|
||||
MediaType: "image/png",
|
||||
})
|
||||
require.Equal(t, "caption", got)
|
||||
})
|
||||
|
||||
t.Run("MediaWithoutText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{
|
||||
MediaType: "image/png",
|
||||
})
|
||||
require.Equal(t, "[media output: image/png]", got)
|
||||
})
|
||||
|
||||
t.Run("MediaWithoutTextOrType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{})
|
||||
require.Equal(t, "[media output]", got)
|
||||
})
|
||||
|
||||
t.Run("MediaPointerNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentMedia)(nil))
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("MediaPointerWithText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentMedia{
|
||||
Text: "ptr caption",
|
||||
MediaType: "image/jpeg",
|
||||
})
|
||||
require.Equal(t, "ptr caption", got)
|
||||
})
|
||||
|
||||
t.Run("NilOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(nil)
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("DefaultJSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// An unexpected type falls through to the default JSON
|
||||
// marshal branch.
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentText{
|
||||
Text: "fallback",
|
||||
})
|
||||
require.Equal(t, "fallback", got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeResponse_PreservesToolCallArguments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := normalizeResponse(&fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.ToolCallContent{
|
||||
ToolCallID: "call-calc",
|
||||
ToolName: "calculator",
|
||||
Input: `{"operation":"add","operands":[2,2]}`,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, payload.Content, 1)
|
||||
require.Equal(t, "call-calc", payload.Content[0].ToolCallID)
|
||||
require.Equal(t, "calculator", payload.Content[0].ToolName)
|
||||
require.JSONEq(t,
|
||||
`{"operation":"add","operands":[2,2]}`,
|
||||
payload.Content[0].Arguments,
|
||||
)
|
||||
require.Equal(t, len(`{"operation":"add","operands":[2,2]}`), payload.Content[0].InputLength)
|
||||
}
|
||||
@@ -1,319 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// RecorderOptions identifies the chat/model context for debug recording.
|
||||
type RecorderOptions struct {
|
||||
ChatID uuid.UUID
|
||||
OwnerID uuid.UUID
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// WrapModel returns model unchanged when debug recording is disabled, or a
|
||||
// debug wrapper when a service is available.
|
||||
func WrapModel(
|
||||
model fantasy.LanguageModel,
|
||||
svc *Service,
|
||||
opts RecorderOptions,
|
||||
) fantasy.LanguageModel {
|
||||
if model == nil {
|
||||
panic("chatdebug: nil LanguageModel")
|
||||
}
|
||||
if svc == nil {
|
||||
return model
|
||||
}
|
||||
return &debugModel{inner: model, svc: svc, opts: opts}
|
||||
}
|
||||
|
||||
type attemptSink struct {
|
||||
mu sync.Mutex
|
||||
attempts []Attempt
|
||||
attemptCounter atomic.Int32
|
||||
}
|
||||
|
||||
func (s *attemptSink) nextAttemptNumber() int {
|
||||
if s == nil {
|
||||
panic("chatdebug: nil attemptSink")
|
||||
}
|
||||
return int(s.attemptCounter.Add(1))
|
||||
}
|
||||
|
||||
func (s *attemptSink) record(a Attempt) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.attempts = append(s.attempts, a)
|
||||
}
|
||||
|
||||
func (s *attemptSink) snapshot() []Attempt {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
attempts := make([]Attempt, len(s.attempts))
|
||||
copy(attempts, s.attempts)
|
||||
return attempts
|
||||
}
|
||||
|
||||
type attemptSinkKey struct{}
|
||||
|
||||
func withAttemptSink(ctx context.Context, sink *attemptSink) context.Context {
|
||||
if sink == nil {
|
||||
panic("chatdebug: nil attemptSink")
|
||||
}
|
||||
return context.WithValue(ctx, attemptSinkKey{}, sink)
|
||||
}
|
||||
|
||||
func attemptSinkFromContext(ctx context.Context) *attemptSink {
|
||||
sink, _ := ctx.Value(attemptSinkKey{}).(*attemptSink)
|
||||
return sink
|
||||
}
|
||||
|
||||
var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32
|
||||
|
||||
// runRefCounts tracks how many live RunContext instances reference each
|
||||
// RunID. Cleanup of shared state (step counters) is deferred until the
|
||||
// last RunContext for a given RunID is garbage collected.
|
||||
var runRefCounts sync.Map // map[uuid.UUID]*atomic.Int32
|
||||
|
||||
func trackRunRef(runID uuid.UUID) {
|
||||
val, _ := runRefCounts.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter := val.(*atomic.Int32)
|
||||
counter.Add(1)
|
||||
}
|
||||
|
||||
// releaseRunRef decrements the reference count for runID and cleans up
|
||||
// shared state when the last reference is released.
|
||||
func releaseRunRef(runID uuid.UUID) {
|
||||
val, ok := runRefCounts.Load(runID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
counter := val.(*atomic.Int32)
|
||||
if counter.Add(-1) <= 0 {
|
||||
runRefCounts.Delete(runID)
|
||||
stepCounters.Delete(runID)
|
||||
}
|
||||
}
|
||||
|
||||
func nextStepNumber(runID uuid.UUID) int32 {
|
||||
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter, ok := val.(*atomic.Int32)
|
||||
if !ok {
|
||||
panic("chatdebug: invalid step counter type")
|
||||
}
|
||||
return counter.Add(1)
|
||||
}
|
||||
|
||||
// CleanupStepCounter removes per-run step counter and reference count
|
||||
// state. This is used by tests and later stacked branches that have a
|
||||
// real run lifecycle.
|
||||
func CleanupStepCounter(runID uuid.UUID) {
|
||||
stepCounters.Delete(runID)
|
||||
runRefCounts.Delete(runID)
|
||||
}
|
||||
|
||||
const stepFinalizeTimeout = 5 * time.Second
|
||||
|
||||
func stepFinalizeContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
panic("chatdebug: nil context")
|
||||
}
|
||||
return context.WithTimeout(context.WithoutCancel(ctx), stepFinalizeTimeout)
|
||||
}
|
||||
|
||||
func syncStepCounter(runID uuid.UUID, stepNumber int32) {
|
||||
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter, ok := val.(*atomic.Int32)
|
||||
if !ok {
|
||||
panic("chatdebug: invalid step counter type")
|
||||
}
|
||||
for {
|
||||
current := counter.Load()
|
||||
if current >= stepNumber {
|
||||
return
|
||||
}
|
||||
if counter.CompareAndSwap(current, stepNumber) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type stepHandle struct {
|
||||
stepCtx *StepContext
|
||||
sink *attemptSink
|
||||
svc *Service
|
||||
opts RecorderOptions
|
||||
mu sync.Mutex
|
||||
status Status
|
||||
response any
|
||||
usage any
|
||||
err any
|
||||
metadata any
|
||||
// hadError tracks whether a prior finalization wrote an error
|
||||
// payload. Used to decide whether a successful retry needs to
|
||||
// explicitly clear the error field via jsonClear.
|
||||
hadError bool
|
||||
}
|
||||
|
||||
// beginStep validates preconditions, creates a debug step, and returns a
|
||||
// handle plus an enriched context carrying StepContext and attemptSink.
|
||||
// Returns (nil, original ctx) when debug recording should be skipped.
|
||||
func beginStep(
|
||||
ctx context.Context,
|
||||
svc *Service,
|
||||
opts RecorderOptions,
|
||||
op Operation,
|
||||
normalizedReq any,
|
||||
) (*stepHandle, context.Context) {
|
||||
if svc == nil {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
rc, ok := RunFromContext(ctx)
|
||||
if !ok || rc.RunID == uuid.Nil {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
chatID := opts.ChatID
|
||||
if chatID == uuid.Nil {
|
||||
chatID = rc.ChatID
|
||||
}
|
||||
if !svc.IsEnabled(ctx, chatID, opts.OwnerID) {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
holder, reuseStep := reuseHolderFromContext(ctx)
|
||||
if reuseStep {
|
||||
holder.mu.Lock()
|
||||
defer holder.mu.Unlock()
|
||||
// Only reuse the cached handle if it belongs to the same run.
|
||||
// A different RunContext means a new logical run, so we must
|
||||
// create a fresh step to avoid cross-run attribution.
|
||||
if holder.handle != nil && holder.handle.stepCtx.RunID == rc.RunID {
|
||||
enriched := ContextWithStep(ctx, holder.handle.stepCtx)
|
||||
enriched = withAttemptSink(enriched, holder.handle.sink)
|
||||
return holder.handle, enriched
|
||||
}
|
||||
}
|
||||
|
||||
stepNum := nextStepNumber(rc.RunID)
|
||||
step, err := svc.CreateStep(ctx, CreateStepParams{
|
||||
RunID: rc.RunID,
|
||||
ChatID: chatID,
|
||||
StepNumber: stepNum,
|
||||
Operation: op,
|
||||
Status: StatusInProgress,
|
||||
HistoryTipMessageID: rc.HistoryTipMessageID,
|
||||
NormalizedRequest: normalizedReq,
|
||||
})
|
||||
if err != nil {
|
||||
svc.log.Warn(ctx, "failed to create chat debug step",
|
||||
slog.Error(err),
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("run_id", rc.RunID),
|
||||
slog.F("operation", op),
|
||||
)
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
syncStepCounter(rc.RunID, step.StepNumber)
|
||||
actualStepNumber := step.StepNumber
|
||||
if actualStepNumber == 0 {
|
||||
actualStepNumber = stepNum
|
||||
}
|
||||
|
||||
sc := &StepContext{
|
||||
StepID: step.ID,
|
||||
RunID: rc.RunID,
|
||||
ChatID: chatID,
|
||||
StepNumber: actualStepNumber,
|
||||
Operation: op,
|
||||
HistoryTipMessageID: rc.HistoryTipMessageID,
|
||||
}
|
||||
handle := &stepHandle{stepCtx: sc, sink: &attemptSink{}, svc: svc, opts: opts}
|
||||
enriched := ContextWithStep(ctx, handle.stepCtx)
|
||||
enriched = withAttemptSink(enriched, handle.sink)
|
||||
if reuseStep {
|
||||
holder.handle = handle
|
||||
}
|
||||
|
||||
return handle, enriched
|
||||
}
|
||||
|
||||
// finish updates the debug step with final status and data. A mutex
|
||||
// guards the write so concurrent callers (e.g. retried stream wrappers
|
||||
// sharing a reuse handle) don't race. Unlike sync.Once, later retries
|
||||
// are allowed to overwrite earlier failure results so the step reflects
|
||||
// the final outcome.
|
||||
func (h *stepHandle) finish(
|
||||
ctx context.Context,
|
||||
status Status,
|
||||
response any,
|
||||
usage any,
|
||||
errPayload any,
|
||||
metadata any,
|
||||
) {
|
||||
if h == nil || h.stepCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.status = status
|
||||
h.response = response
|
||||
h.usage = usage
|
||||
h.err = errPayload
|
||||
h.metadata = metadata
|
||||
if errPayload != nil {
|
||||
h.hadError = true
|
||||
}
|
||||
if h.svc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
updateCtx, cancel := stepFinalizeContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
// When the step completes successfully after a prior failed
|
||||
// attempt, the error field must be explicitly cleared. A plain
|
||||
// nil would leave the COALESCE-based SQL untouched, so we send
|
||||
// jsonClear{} which serializes as a valid JSONB null. Only do
|
||||
// this when a prior error was actually recorded — otherwise
|
||||
// clean successes would get a spurious JSONB null that downstream
|
||||
// aggregation could misread as an error.
|
||||
errValue := errPayload
|
||||
if errValue == nil && status == StatusCompleted && h.hadError {
|
||||
errValue = jsonClear{}
|
||||
}
|
||||
|
||||
if _, updateErr := h.svc.UpdateStep(updateCtx, UpdateStepParams{
|
||||
ID: h.stepCtx.StepID,
|
||||
ChatID: h.stepCtx.ChatID,
|
||||
Status: status,
|
||||
NormalizedResponse: response,
|
||||
Usage: usage,
|
||||
Attempts: h.sink.snapshot(),
|
||||
Error: errValue,
|
||||
Metadata: metadata,
|
||||
FinishedAt: time.Now(),
|
||||
}); updateErr != nil {
|
||||
h.svc.log.Warn(updateCtx, "failed to finalize chat debug step",
|
||||
slog.Error(updateErr),
|
||||
slog.F("step_id", h.stepCtx.StepID),
|
||||
slog.F("chat_id", h.stepCtx.ChatID),
|
||||
slog.F("status", status),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestAttemptSink_ThreadSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const n = 256
|
||||
|
||||
sink := &attemptSink{}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
|
||||
for i := range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
sink.record(Attempt{Number: i + 1, ResponseStatus: 200 + i})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, n)
|
||||
|
||||
numbers := make([]int, 0, n)
|
||||
statuses := make([]int, 0, n)
|
||||
for _, attempt := range attempts {
|
||||
numbers = append(numbers, attempt.Number)
|
||||
statuses = append(statuses, attempt.ResponseStatus)
|
||||
}
|
||||
sort.Ints(numbers)
|
||||
sort.Ints(statuses)
|
||||
|
||||
for i := range n {
|
||||
require.Equal(t, i+1, numbers[i])
|
||||
require.Equal(t, 200+i, statuses[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttemptSinkContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
require.Nil(t, attemptSinkFromContext(ctx))
|
||||
|
||||
sink := &attemptSink{}
|
||||
ctx = withAttemptSink(ctx, sink)
|
||||
require.Same(t, sink, attemptSinkFromContext(ctx))
|
||||
}
|
||||
|
||||
func TestWrapModel_NilModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Panics(t, func() {
|
||||
WrapModel(nil, &Service{}, RecorderOptions{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestWrapModel_NilService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
|
||||
wrapped := WrapModel(model, nil, RecorderOptions{})
|
||||
require.Same(t, model, wrapped)
|
||||
}
|
||||
|
||||
func TestNextStepNumber_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const n = 256
|
||||
|
||||
runID := uuid.New()
|
||||
results := make([]int, n)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
|
||||
for i := range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i] = int(nextStepNumber(runID))
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
sort.Ints(results)
|
||||
for i := range n {
|
||||
require.Equal(t, i+1, results[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStepFinalizeContext_StripsCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
baseCtx, cancelBase := context.WithCancel(context.Background())
|
||||
cancelBase()
|
||||
require.ErrorIs(t, baseCtx.Err(), context.Canceled)
|
||||
|
||||
finalizeCtx, cancelFinalize := stepFinalizeContext(baseCtx)
|
||||
defer cancelFinalize()
|
||||
|
||||
require.NoError(t, finalizeCtx.Err())
|
||||
_, hasDeadline := finalizeCtx.Deadline()
|
||||
require.True(t, hasDeadline)
|
||||
}
|
||||
|
||||
func TestSyncStepCounter_AdvancesCounter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
syncStepCounter(runID, 7)
|
||||
require.Equal(t, int32(8), nextStepNumber(runID))
|
||||
}
|
||||
|
||||
func TestStepHandleFinish_NilHandle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var handle *stepHandle
|
||||
handle.finish(context.Background(), StatusCompleted, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func TestBeginStep_NilService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
handle, enriched := beginStep(ctx, nil, RecorderOptions{}, OperationGenerate, nil)
|
||||
require.Nil(t, handle)
|
||||
require.Nil(t, attemptSinkFromContext(enriched))
|
||||
_, ok := StepFromContext(enriched)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestBeginStep_FallsBackToRunChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
runID := uuid.New()
|
||||
runChatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(t, db, runID, runChatID, 1, OperationGenerate, false)
|
||||
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID})
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
|
||||
handle, enriched := beginStep(ctx, svc, RecorderOptions{OwnerID: ownerID}, OperationGenerate, nil)
|
||||
require.NotNil(t, handle)
|
||||
require.Equal(t, runChatID, handle.stepCtx.ChatID)
|
||||
|
||||
stepCtx, ok := StepFromContext(enriched)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runChatID, stepCtx.ChatID)
|
||||
}
|
||||
|
||||
func TestWrapModel_ReturnsDebugModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
|
||||
wrapped := WrapModel(model, &Service{}, RecorderOptions{})
|
||||
|
||||
require.NotSame(t, model, wrapped)
|
||||
require.IsType(t, &debugModel{}, wrapped)
|
||||
require.Implements(t, (*fantasy.LanguageModel)(nil), wrapped)
|
||||
require.Equal(t, model.Provider(), wrapped.Provider())
|
||||
require.Equal(t, model.Model(), wrapped.Model())
|
||||
}
|
||||
@@ -1,227 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// RedactedValue replaces sensitive values in debug payloads.
|
||||
const RedactedValue = "[REDACTED]"
|
||||
|
||||
var sensitiveHeaderNames = map[string]struct{}{
|
||||
"authorization": {},
|
||||
"x-api-key": {},
|
||||
"api-key": {},
|
||||
"proxy-authorization": {},
|
||||
"cookie": {},
|
||||
"set-cookie": {},
|
||||
}
|
||||
|
||||
// sensitiveJSONKeyFragments triggers redaction for JSON keys containing
|
||||
// these substrings. Notably, "token" is intentionally absent because it
|
||||
// false-positively redacts LLM token-usage fields (input_tokens,
|
||||
// output_tokens, prompt_tokens, completion_tokens, reasoning_tokens,
|
||||
// cache_creation_input_tokens, cache_read_input_tokens, etc.). Auth-
|
||||
// related token fields are caught by the exact-match set below.
|
||||
var sensitiveJSONKeyFragments = []string{
|
||||
"secret",
|
||||
"password",
|
||||
"authorization",
|
||||
"credential",
|
||||
}
|
||||
|
||||
// sensitiveJSONKeyExact matches auth-related token/key field names
|
||||
// without false-positiving on LLM usage counters. Includes both
|
||||
// snake_case originals and their camelCase-lowered equivalents
|
||||
// (e.g. "accessToken" → "accesstoken") so that providers using
|
||||
// either convention are caught.
|
||||
var sensitiveJSONKeyExact = map[string]struct{}{
|
||||
"token": {},
|
||||
"access_token": {},
|
||||
"accesstoken": {},
|
||||
"refresh_token": {},
|
||||
"refreshtoken": {},
|
||||
"id_token": {},
|
||||
"idtoken": {},
|
||||
"api_token": {},
|
||||
"apitoken": {},
|
||||
"api_key": {},
|
||||
"apikey": {},
|
||||
"api-key": {},
|
||||
"x-api-key": {},
|
||||
"auth_token": {},
|
||||
"authtoken": {},
|
||||
"bearer_token": {},
|
||||
"bearertoken": {},
|
||||
"session_token": {},
|
||||
"sessiontoken": {},
|
||||
"security_token": {},
|
||||
"securitytoken": {},
|
||||
"private_key": {},
|
||||
"privatekey": {},
|
||||
"signing_key": {},
|
||||
"signingkey": {},
|
||||
"secret_key": {},
|
||||
"secretkey": {},
|
||||
}
|
||||
|
||||
// RedactHeaders returns a flattened copy of h with sensitive values redacted.
|
||||
func RedactHeaders(h http.Header) map[string]string {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
redacted := make(map[string]string, len(h))
|
||||
for name, values := range h {
|
||||
if isSensitiveName(name) {
|
||||
redacted[name] = RedactedValue
|
||||
continue
|
||||
}
|
||||
redacted[name] = strings.Join(values, ", ")
|
||||
}
|
||||
return redacted
|
||||
}
|
||||
|
||||
// RedactJSONSecrets redacts sensitive JSON values by key name. When
|
||||
// the input is not valid JSON (truncated body, HTML error page, etc.)
|
||||
// the raw bytes are replaced entirely with a diagnostic placeholder
|
||||
// to avoid leaking credentials from malformed payloads.
|
||||
func RedactJSONSecrets(data []byte) []byte {
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader(data))
|
||||
decoder.UseNumber()
|
||||
|
||||
var value any
|
||||
if err := decoder.Decode(&value); err != nil {
|
||||
// Cannot parse: replace entirely to prevent credential leaks
|
||||
// from non-JSON error responses (HTML pages, partial bodies).
|
||||
return []byte(`{"error":"chatdebug: body is not valid JSON, redacted for safety"}`)
|
||||
}
|
||||
if err := consumeJSONEOF(decoder); err != nil {
|
||||
return []byte(`{"error":"chatdebug: body contains extra JSON values, redacted for safety"}`)
|
||||
}
|
||||
|
||||
redacted, changed := redactJSONValue(value)
|
||||
if !changed {
|
||||
return data
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(redacted)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return encoded
|
||||
}
|
||||
|
||||
func consumeJSONEOF(decoder *json.Decoder) error {
|
||||
var extra any
|
||||
err := decoder.Decode(&extra)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if err == nil {
|
||||
return xerrors.New("chatdebug: extra JSON values")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var safeRateLimitHeaderNames = map[string]struct{}{
|
||||
"anthropic-ratelimit-requests-limit": {},
|
||||
"anthropic-ratelimit-requests-remaining": {},
|
||||
"anthropic-ratelimit-requests-reset": {},
|
||||
"anthropic-ratelimit-tokens-limit": {},
|
||||
"anthropic-ratelimit-tokens-remaining": {},
|
||||
"anthropic-ratelimit-tokens-reset": {},
|
||||
"x-ratelimit-limit-requests": {},
|
||||
"x-ratelimit-limit-tokens": {},
|
||||
"x-ratelimit-remaining-requests": {},
|
||||
"x-ratelimit-remaining-tokens": {},
|
||||
"x-ratelimit-reset-requests": {},
|
||||
"x-ratelimit-reset-tokens": {},
|
||||
}
|
||||
|
||||
// isSensitiveName reports whether a name (header or query parameter)
|
||||
// looks like a credential-carrying key. Exact-match headers are
|
||||
// checked first, then the rate-limit allowlist, then substring
|
||||
// patterns for API keys and auth tokens.
|
||||
func isSensitiveName(name string) bool {
|
||||
lowerName := strings.ToLower(name)
|
||||
if _, ok := sensitiveHeaderNames[lowerName]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := safeRateLimitHeaderNames[lowerName]; ok {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(lowerName, "api-key") ||
|
||||
strings.Contains(lowerName, "api_key") ||
|
||||
strings.Contains(lowerName, "apikey") {
|
||||
return true
|
||||
}
|
||||
// Catch any header containing "token" (e.g. Token, X-Token,
|
||||
// X-Auth-Token). Safe rate-limit headers like
|
||||
// x-ratelimit-remaining-tokens are already allowlisted above
|
||||
// and will not reach this point.
|
||||
if strings.Contains(lowerName, "token") {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(lowerName, "secret") ||
|
||||
strings.Contains(lowerName, "bearer")
|
||||
}
|
||||
|
||||
func isSensitiveJSONKey(key string) bool {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if _, ok := sensitiveJSONKeyExact[lowerKey]; ok {
|
||||
return true
|
||||
}
|
||||
for _, fragment := range sensitiveJSONKeyFragments {
|
||||
if strings.Contains(lowerKey, fragment) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func redactJSONValue(value any) (any, bool) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
changed := false
|
||||
for key, child := range typed {
|
||||
if isSensitiveJSONKey(key) {
|
||||
if current, ok := child.(string); ok && current == RedactedValue {
|
||||
continue
|
||||
}
|
||||
typed[key] = RedactedValue
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
|
||||
redactedChild, childChanged := redactJSONValue(child)
|
||||
if childChanged {
|
||||
typed[key] = redactedChild
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return typed, changed
|
||||
case []any:
|
||||
changed := false
|
||||
for i, child := range typed {
|
||||
redactedChild, childChanged := redactJSONValue(child)
|
||||
if childChanged {
|
||||
typed[i] = redactedChild
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return typed, changed
|
||||
default:
|
||||
return value, false
|
||||
}
|
||||
}
|
||||
@@ -1,277 +0,0 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
)
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Nil(t, chatdebug.RedactHeaders(nil))
|
||||
})
|
||||
|
||||
t.Run("empty header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
redacted := chatdebug.RedactHeaders(http.Header{})
|
||||
require.NotNil(t, redacted)
|
||||
require.Empty(t, redacted)
|
||||
})
|
||||
|
||||
t.Run("authorization redacted and others preserved", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Authorization": {"Bearer secret-token"},
|
||||
"Accept": {"application/json"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
|
||||
require.Equal(t, "application/json", redacted["Accept"])
|
||||
})
|
||||
|
||||
t.Run("multi-value headers are flattened", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Accept": {"application/json", "text/plain"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, "application/json, text/plain", redacted["Accept"])
|
||||
})
|
||||
|
||||
t.Run("header name matching is case insensitive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lowerAuthorization := "authorization"
|
||||
upperAuthorization := "AUTHORIZATION"
|
||||
headers := http.Header{
|
||||
lowerAuthorization: {"lower"},
|
||||
upperAuthorization: {"upper"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted[lowerAuthorization])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted[upperAuthorization])
|
||||
})
|
||||
|
||||
t.Run("token and secret substrings are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
traceHeader := "X-Trace-ID"
|
||||
headers := http.Header{
|
||||
"X-Auth-Token": {"abc"},
|
||||
"X-Custom-Secret": {"def"},
|
||||
"X-Bearer": {"ghi"},
|
||||
traceHeader: {"trace"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Auth-Token"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Bearer"])
|
||||
require.Equal(t, "trace", redacted[traceHeader])
|
||||
})
|
||||
|
||||
t.Run("known safe rate limit headers containing token are not redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Anthropic-Ratelimit-Tokens-Limit": {"1000000"},
|
||||
"Anthropic-Ratelimit-Tokens-Remaining": {"999000"},
|
||||
"Anthropic-Ratelimit-Tokens-Reset": {"2026-03-31T08:55:26Z"},
|
||||
"X-RateLimit-Limit-Tokens": {"120000"},
|
||||
"X-RateLimit-Remaining-Tokens": {"119500"},
|
||||
"X-RateLimit-Reset-Tokens": {"12ms"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, "1000000", redacted["Anthropic-Ratelimit-Tokens-Limit"])
|
||||
require.Equal(t, "999000", redacted["Anthropic-Ratelimit-Tokens-Remaining"])
|
||||
require.Equal(t, "2026-03-31T08:55:26Z", redacted["Anthropic-Ratelimit-Tokens-Reset"])
|
||||
require.Equal(t, "120000", redacted["X-RateLimit-Limit-Tokens"])
|
||||
require.Equal(t, "119500", redacted["X-RateLimit-Remaining-Tokens"])
|
||||
require.Equal(t, "12ms", redacted["X-RateLimit-Reset-Tokens"])
|
||||
})
|
||||
|
||||
t.Run("non-standard headers with api-key pattern are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"X-Custom-Api-Key": {"secret-key"},
|
||||
"X-Custom-Secret": {"secret-val"},
|
||||
"X-Custom-Session-Token": {"session-id"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Api-Key"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Session-Token"])
|
||||
})
|
||||
|
||||
t.Run("rate limit headers with token in name are preserved", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Rate-limit headers containing "token" should NOT be redacted
|
||||
// because they carry usage/limit counts, not credentials.
|
||||
headers := http.Header{
|
||||
"X-Ratelimit-Limit-Tokens": {"1000000"},
|
||||
"X-Ratelimit-Remaining-Tokens": {"999000"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, "1000000", redacted["X-Ratelimit-Limit-Tokens"])
|
||||
require.Equal(t, "999000", redacted["X-Ratelimit-Remaining-Tokens"])
|
||||
})
|
||||
|
||||
t.Run("original header is not modified", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Authorization": {"Bearer keep-me"},
|
||||
"X-Test": {"value"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
redacted["X-Test"] = "changed"
|
||||
|
||||
require.Equal(t, []string{"Bearer keep-me"}, headers["Authorization"])
|
||||
require.Equal(t, []string{"value"}, headers["X-Test"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
|
||||
})
|
||||
t.Run("api-key header variants are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"X-Goog-Api-Key": {"secret"},
|
||||
"X-Api_Key": {"other-secret"},
|
||||
"X-Safe": {"ok"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Goog-Api-Key"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Api_Key"])
|
||||
require.Equal(t, "ok", redacted["X-Safe"])
|
||||
})
|
||||
|
||||
t.Run("plain token headers are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Headers like "Token" or "X-Token" should be redacted
|
||||
// even without auth/session/access qualifiers.
|
||||
headers := http.Header{
|
||||
"Token": {"my-secret-token"},
|
||||
"X-Token": {"another-secret"},
|
||||
"X-Safe": {"ok"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["Token"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Token"])
|
||||
require.Equal(t, "ok", redacted["X-Safe"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedactJSONSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("redacts top level secret fields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"api_key":"abc","token":"def","password":"ghi","safe":"ok"}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"api_key":"[REDACTED]","token":"[REDACTED]","password":"[REDACTED]","safe":"ok"}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("redacts security_token exact key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"security_token":"s3cret","securityToken":"tok","safe":"ok"}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"security_token":"[REDACTED]","securityToken":"[REDACTED]","safe":"ok"}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("preserves LLM token usage fields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"input_tokens":100,"output_tokens":50,"prompt_tokens":80,"completion_tokens":20,"reasoning_tokens":10,"cache_creation_input_tokens":5,"cache_read_input_tokens":3,"total_tokens":150,"max_tokens":4096,"max_output_tokens":2048}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
// All usage/limit fields should be preserved, not redacted.
|
||||
require.Equal(t, input, redacted)
|
||||
})
|
||||
|
||||
t.Run("redacts nested objects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"outer":{"nested_secret":"abc","safe":1},"keep":true}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"outer":{"nested_secret":"[REDACTED]","safe":1},"keep":true}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("redacts arrays of objects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`[{"token":"abc"},{"value":1,"credentials":{"access_key":"def"}}]`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `[{"token":"[REDACTED]"},{"value":1,"credentials":"[REDACTED]"}]`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("concatenated JSON is replaced with diagnostic", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"token":"abc"}{"safe":"ok"}`)
|
||||
result := chatdebug.RedactJSONSecrets(input)
|
||||
require.Contains(t, string(result), "extra JSON values")
|
||||
})
|
||||
|
||||
t.Run("non JSON input is replaced with diagnostic", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte("not json")
|
||||
result := chatdebug.RedactJSONSecrets(input)
|
||||
require.Contains(t, string(result), "not valid JSON")
|
||||
})
|
||||
|
||||
t.Run("empty input is unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte{}
|
||||
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
|
||||
})
|
||||
|
||||
t.Run("JSON without sensitive keys is unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"safe":"ok","nested":{"value":1}}`)
|
||||
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
|
||||
})
|
||||
|
||||
t.Run("key matching is case insensitive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"API_KEY":"abc","Token":"def","PASSWORD":"ghi"}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"API_KEY":"[REDACTED]","Token":"[REDACTED]","PASSWORD":"[REDACTED]"}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("camelCase token field names are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Providers may use camelCase (e.g. accessToken, refreshToken).
|
||||
// These should be redacted even though they don't match the
|
||||
// snake_case originals exactly.
|
||||
input := []byte(`{"accessToken":"abc","refreshToken":"def","authToken":"ghi","input_tokens":100,"output_tokens":50}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"accessToken":"[REDACTED]","refreshToken":"[REDACTED]","authToken":"[REDACTED]","input_tokens":100,"output_tokens":50}`, string(redacted))
|
||||
})
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestBeginStepReuseStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("reuses handle under ReuseStep", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
1,
|
||||
OperationStream,
|
||||
false,
|
||||
)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
ctx = ReuseStep(ctx)
|
||||
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
|
||||
|
||||
firstHandle, firstEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
secondHandle, secondEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
|
||||
require.NotNil(t, firstHandle)
|
||||
require.Same(t, firstHandle, secondHandle)
|
||||
require.Same(t, firstHandle.stepCtx, secondHandle.stepCtx)
|
||||
require.Same(t, firstHandle.sink, secondHandle.sink)
|
||||
require.Equal(t, runID, firstHandle.stepCtx.RunID)
|
||||
require.Equal(t, chatID, firstHandle.stepCtx.ChatID)
|
||||
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
|
||||
require.Equal(t, OperationStream, firstHandle.stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, firstHandle.stepCtx.StepID)
|
||||
|
||||
firstStepCtx, ok := StepFromContext(firstEnriched)
|
||||
require.True(t, ok)
|
||||
secondStepCtx, ok := StepFromContext(secondEnriched)
|
||||
require.True(t, ok)
|
||||
require.Same(t, firstStepCtx, secondStepCtx)
|
||||
require.Same(t, firstHandle.stepCtx, firstStepCtx)
|
||||
require.Same(t, attemptSinkFromContext(firstEnriched), attemptSinkFromContext(secondEnriched))
|
||||
})
|
||||
|
||||
t.Run("creates new handles without ReuseStep", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
1,
|
||||
OperationStream,
|
||||
false,
|
||||
)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
2,
|
||||
OperationStream,
|
||||
false,
|
||||
)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
|
||||
|
||||
firstHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
secondHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
|
||||
require.NotNil(t, firstHandle)
|
||||
require.NotNil(t, secondHandle)
|
||||
require.NotSame(t, firstHandle, secondHandle)
|
||||
require.NotSame(t, firstHandle.sink, secondHandle.sink)
|
||||
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
|
||||
require.Equal(t, int32(2), secondHandle.stepCtx.StepNumber)
|
||||
require.NotEqual(t, firstHandle.stepCtx.StepID, secondHandle.stepCtx.StepID)
|
||||
})
|
||||
}
|
||||
@@ -1,539 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
// DefaultStaleThreshold is the fallback stale timeout for debug rows
|
||||
// when no caller-provided value is supplied.
|
||||
const DefaultStaleThreshold = 5 * time.Minute
|
||||
|
||||
// Service persists chat debug rows and fans out lightweight change events.
|
||||
type Service struct {
|
||||
db database.Store
|
||||
log slog.Logger
|
||||
pubsub pubsub.Pubsub
|
||||
alwaysEnable bool
|
||||
// staleAfterNanos stores the stale threshold as nanoseconds in an
|
||||
// atomic.Int64 so SetStaleAfter and FinalizeStale can be called
|
||||
// from concurrent goroutines without a data race.
|
||||
staleAfterNanos atomic.Int64
|
||||
}
|
||||
|
||||
// ServiceOption configures optional Service behavior.
|
||||
type ServiceOption func(*Service)
|
||||
|
||||
// WithStaleThreshold overrides the default stale-row finalization
|
||||
// threshold. Callers that already have a configurable in-flight chat
|
||||
// timeout (e.g. chatd's InFlightChatStaleAfter) should pass it here
|
||||
// so the two sweeps stay in sync.
|
||||
func WithStaleThreshold(d time.Duration) ServiceOption {
|
||||
return func(s *Service) {
|
||||
if d > 0 {
|
||||
s.staleAfterNanos.Store(d.Nanoseconds())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithAlwaysEnable forces debug logging on for every chat regardless
|
||||
// of the runtime admin and user opt-in settings. This is used for the
|
||||
// deployment-level serpent flag.
|
||||
func WithAlwaysEnable(always bool) ServiceOption {
|
||||
return func(s *Service) {
|
||||
s.alwaysEnable = always
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRunParams contains friendly inputs for creating a debug run.
|
||||
type CreateRunParams struct {
|
||||
ChatID uuid.UUID
|
||||
RootChatID uuid.UUID
|
||||
ParentChatID uuid.UUID
|
||||
ModelConfigID uuid.UUID
|
||||
TriggerMessageID int64
|
||||
HistoryTipMessageID int64
|
||||
Kind RunKind
|
||||
Status Status
|
||||
Provider string
|
||||
Model string
|
||||
Summary any
|
||||
}
|
||||
|
||||
// UpdateRunParams contains inputs for updating a debug run.
|
||||
// Zero-valued fields are treated as "keep the existing value" by the
|
||||
// COALESCE-based SQL query. Once a field is set it cannot be cleared
|
||||
// back to NULL — this is intentional for the write-once-finalize
|
||||
// lifecycle of debug rows.
|
||||
type UpdateRunParams struct {
|
||||
ID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
Status Status
|
||||
Summary any
|
||||
FinishedAt time.Time
|
||||
}
|
||||
|
||||
// CreateStepParams contains friendly inputs for creating a debug step.
|
||||
type CreateStepParams struct {
|
||||
RunID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
StepNumber int32
|
||||
Operation Operation
|
||||
Status Status
|
||||
HistoryTipMessageID int64
|
||||
NormalizedRequest any
|
||||
}
|
||||
|
||||
// UpdateStepParams contains optional inputs for updating a debug step.
|
||||
// Most payload fields are typed as any and serialized through nullJSON
|
||||
// because their shape varies by provider. The Attempts field uses a
|
||||
// concrete slice for compile-time safety where the schema is stable.
|
||||
// Zero-valued fields are treated as "keep the existing value" by the
|
||||
// COALESCE-based SQL query — once set, fields cannot be cleared back
|
||||
// to NULL. This is intentional for the write-once-finalize lifecycle
|
||||
// of debug rows.
|
||||
type UpdateStepParams struct {
|
||||
ID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
Status Status
|
||||
AssistantMessageID int64
|
||||
NormalizedResponse any
|
||||
Usage any
|
||||
Attempts []Attempt
|
||||
Error any
|
||||
Metadata any
|
||||
FinishedAt time.Time
|
||||
}
|
||||
|
||||
// NewService constructs a chat debug persistence service.
|
||||
func NewService(db database.Store, log slog.Logger, ps pubsub.Pubsub, opts ...ServiceOption) *Service {
|
||||
if db == nil {
|
||||
panic("chatdebug: nil database.Store")
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
db: db,
|
||||
log: log,
|
||||
pubsub: ps,
|
||||
}
|
||||
s.staleAfterNanos.Store(DefaultStaleThreshold.Nanoseconds())
|
||||
for _, opt := range opts {
|
||||
opt(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// SetStaleAfter overrides the in-flight stale threshold used when
|
||||
// finalizing abandoned debug rows. Zero or negative durations keep the
|
||||
// default threshold.
|
||||
func (s *Service) SetStaleAfter(staleAfter time.Duration) {
|
||||
if s == nil || staleAfter <= 0 {
|
||||
return
|
||||
}
|
||||
s.staleAfterNanos.Store(staleAfter.Nanoseconds())
|
||||
}
|
||||
|
||||
func chatdContext(ctx context.Context) context.Context {
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
|
||||
// chat debug persistence reads and writes.
|
||||
return dbauthz.AsChatd(ctx)
|
||||
}
|
||||
|
||||
// IsEnabled returns whether debug logging is enabled for the given chat.
|
||||
func (s *Service) IsEnabled(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
ownerID uuid.UUID,
|
||||
) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.alwaysEnable {
|
||||
return true
|
||||
}
|
||||
if s.db == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
authCtx := chatdContext(ctx)
|
||||
|
||||
allowUsers, err := s.db.GetChatDebugLoggingEnabled(authCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false
|
||||
}
|
||||
s.log.Warn(ctx, "failed to load runtime admin chat debug logging setting",
|
||||
slog.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
if !allowUsers {
|
||||
return false
|
||||
}
|
||||
|
||||
if ownerID == uuid.Nil {
|
||||
s.log.Warn(ctx, "missing chat owner for debug logging enablement check",
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
enabled, err := s.db.GetUserChatDebugLoggingEnabled(authCtx, ownerID)
|
||||
if err == nil {
|
||||
return enabled
|
||||
}
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false
|
||||
}
|
||||
|
||||
s.log.Warn(ctx, "failed to load user chat debug logging setting",
|
||||
slog.Error(err),
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("owner_id", ownerID),
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateRun inserts a new debug run and emits a run update event.
|
||||
func (s *Service) CreateRun(
|
||||
ctx context.Context,
|
||||
params CreateRunParams,
|
||||
) (database.ChatDebugRun, error) {
|
||||
run, err := s.db.InsertChatDebugRun(chatdContext(ctx),
|
||||
database.InsertChatDebugRunParams{
|
||||
ChatID: params.ChatID,
|
||||
RootChatID: nullUUID(params.RootChatID),
|
||||
ParentChatID: nullUUID(params.ParentChatID),
|
||||
ModelConfigID: nullUUID(params.ModelConfigID),
|
||||
TriggerMessageID: nullInt64(params.TriggerMessageID),
|
||||
HistoryTipMessageID: nullInt64(params.HistoryTipMessageID),
|
||||
Kind: string(params.Kind),
|
||||
Status: string(params.Status),
|
||||
Provider: nullString(params.Provider),
|
||||
Model: nullString(params.Model),
|
||||
Summary: s.nullJSON(ctx, params.Summary),
|
||||
StartedAt: sql.NullTime{},
|
||||
UpdatedAt: sql.NullTime{},
|
||||
FinishedAt: sql.NullTime{},
|
||||
})
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil)
|
||||
return run, nil
|
||||
}
|
||||
|
||||
// UpdateRun updates an existing debug run and emits a run update event.
|
||||
func (s *Service) UpdateRun(
|
||||
ctx context.Context,
|
||||
params UpdateRunParams,
|
||||
) (database.ChatDebugRun, error) {
|
||||
run, err := s.db.UpdateChatDebugRun(chatdContext(ctx),
|
||||
database.UpdateChatDebugRunParams{
|
||||
RootChatID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
ModelConfigID: uuid.NullUUID{},
|
||||
TriggerMessageID: sql.NullInt64{},
|
||||
HistoryTipMessageID: sql.NullInt64{},
|
||||
Status: nullString(string(params.Status)),
|
||||
Provider: sql.NullString{},
|
||||
Model: sql.NullString{},
|
||||
Summary: s.nullJSON(ctx, params.Summary),
|
||||
FinishedAt: nullTime(params.FinishedAt),
|
||||
ID: params.ID,
|
||||
ChatID: params.ChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil)
|
||||
return run, nil
|
||||
}
|
||||
|
||||
// CreateStep inserts a new debug step and emits a step update event.
|
||||
func (s *Service) CreateStep(
|
||||
ctx context.Context,
|
||||
params CreateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
insert := database.InsertChatDebugStepParams{
|
||||
RunID: params.RunID,
|
||||
StepNumber: params.StepNumber,
|
||||
Operation: string(params.Operation),
|
||||
Status: string(params.Status),
|
||||
HistoryTipMessageID: nullInt64(params.HistoryTipMessageID),
|
||||
AssistantMessageID: sql.NullInt64{},
|
||||
NormalizedRequest: s.nullJSON(ctx, params.NormalizedRequest),
|
||||
NormalizedResponse: pqtype.NullRawMessage{},
|
||||
Usage: pqtype.NullRawMessage{},
|
||||
Attempts: pqtype.NullRawMessage{},
|
||||
Error: pqtype.NullRawMessage{},
|
||||
Metadata: pqtype.NullRawMessage{},
|
||||
StartedAt: sql.NullTime{},
|
||||
UpdatedAt: sql.NullTime{},
|
||||
FinishedAt: sql.NullTime{},
|
||||
ChatID: params.ChatID,
|
||||
}
|
||||
|
||||
// Cap retry attempts to prevent infinite loops under
|
||||
// pathological concurrency. Each iteration performs two DB
|
||||
// round-trips (insert + list), so 10 retries is generous.
|
||||
const maxCreateStepRetries = 10
|
||||
|
||||
for attempt := 0; attempt < maxCreateStepRetries; attempt++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
|
||||
step, err := s.db.InsertChatDebugStep(chatdContext(ctx), insert)
|
||||
if err == nil {
|
||||
// Touch the parent run's updated_at so the stale-
|
||||
// finalization sweep does not prematurely interrupt
|
||||
// long-running runs that are still producing steps.
|
||||
if _, touchErr := s.db.UpdateChatDebugRun(chatdContext(ctx), database.UpdateChatDebugRunParams{
|
||||
RootChatID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
ModelConfigID: uuid.NullUUID{},
|
||||
TriggerMessageID: sql.NullInt64{},
|
||||
HistoryTipMessageID: sql.NullInt64{},
|
||||
Status: sql.NullString{},
|
||||
Provider: sql.NullString{},
|
||||
Model: sql.NullString{},
|
||||
Summary: pqtype.NullRawMessage{},
|
||||
FinishedAt: sql.NullTime{},
|
||||
ID: params.RunID,
|
||||
ChatID: params.ChatID,
|
||||
}); touchErr != nil {
|
||||
s.log.Warn(ctx, "failed to touch parent run updated_at",
|
||||
slog.F("run_id", params.RunID),
|
||||
slog.Error(touchErr),
|
||||
)
|
||||
}
|
||||
s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID)
|
||||
return step, nil
|
||||
}
|
||||
if !database.IsUniqueViolation(err, database.UniqueIndexChatDebugStepsRunStep) {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
|
||||
steps, listErr := s.db.GetChatDebugStepsByRunID(chatdContext(ctx), params.RunID)
|
||||
if listErr != nil {
|
||||
return database.ChatDebugStep{}, listErr
|
||||
}
|
||||
nextStepNumber := insert.StepNumber + 1
|
||||
for _, existing := range steps {
|
||||
if existing.StepNumber >= nextStepNumber {
|
||||
nextStepNumber = existing.StepNumber + 1
|
||||
}
|
||||
}
|
||||
insert.StepNumber = nextStepNumber
|
||||
}
|
||||
|
||||
return database.ChatDebugStep{}, xerrors.Errorf(
|
||||
"chatdebug: failed to create step after %d retries (run %s)",
|
||||
maxCreateStepRetries, params.RunID,
|
||||
)
|
||||
}
|
||||
|
||||
// UpdateStep updates an existing debug step and emits a step update event.
|
||||
func (s *Service) UpdateStep(
|
||||
ctx context.Context,
|
||||
params UpdateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
step, err := s.db.UpdateChatDebugStep(chatdContext(ctx),
|
||||
database.UpdateChatDebugStepParams{
|
||||
Status: nullString(string(params.Status)),
|
||||
HistoryTipMessageID: sql.NullInt64{},
|
||||
AssistantMessageID: nullInt64(params.AssistantMessageID),
|
||||
NormalizedRequest: pqtype.NullRawMessage{},
|
||||
NormalizedResponse: s.nullJSON(ctx, params.NormalizedResponse),
|
||||
Usage: s.nullJSON(ctx, params.Usage),
|
||||
Attempts: s.nullJSON(ctx, params.Attempts),
|
||||
Error: s.nullJSON(ctx, params.Error),
|
||||
Metadata: s.nullJSON(ctx, params.Metadata),
|
||||
FinishedAt: nullTime(params.FinishedAt),
|
||||
ID: params.ID,
|
||||
ChatID: params.ChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID)
|
||||
return step, nil
|
||||
}
|
||||
|
||||
// DeleteByChatID deletes all debug data for a chat and emits a delete event.
|
||||
func (s *Service) DeleteByChatID(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
) (int64, error) {
|
||||
deleted, err := s.db.DeleteChatDebugDataByChatID(chatdContext(ctx), chatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil)
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// DeleteAfterMessageID deletes debug data newer than the given message.
|
||||
func (s *Service) DeleteAfterMessageID(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
messageID int64,
|
||||
) (int64, error) {
|
||||
deleted, err := s.db.DeleteChatDebugDataAfterMessageID(
|
||||
chatdContext(ctx),
|
||||
database.DeleteChatDebugDataAfterMessageIDParams{
|
||||
ChatID: chatID,
|
||||
MessageID: messageID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil)
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// FinalizeStale finalizes stale in-flight debug rows and emits a broadcast.
|
||||
func (s *Service) FinalizeStale(
|
||||
ctx context.Context,
|
||||
) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
ns := s.staleAfterNanos.Load()
|
||||
staleAfter := time.Duration(ns)
|
||||
if staleAfter <= 0 {
|
||||
staleAfter = DefaultStaleThreshold
|
||||
}
|
||||
|
||||
result, err := s.db.FinalizeStaleChatDebugRows(
|
||||
chatdContext(ctx),
|
||||
time.Now().Add(-staleAfter),
|
||||
)
|
||||
if err != nil {
|
||||
return database.FinalizeStaleChatDebugRowsRow{}, err
|
||||
}
|
||||
|
||||
if result.RunsFinalized > 0 || result.StepsFinalized > 0 {
|
||||
s.publishEvent(ctx, uuid.Nil, EventKindFinalize, uuid.Nil, uuid.Nil)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func nullUUID(id uuid.UUID) uuid.NullUUID {
|
||||
return uuid.NullUUID{UUID: id, Valid: id != uuid.Nil}
|
||||
}
|
||||
|
||||
func nullInt64(v int64) sql.NullInt64 {
|
||||
return sql.NullInt64{Int64: v, Valid: v != 0}
|
||||
}
|
||||
|
||||
func nullString(value string) sql.NullString {
|
||||
return sql.NullString{String: value, Valid: value != ""}
|
||||
}
|
||||
|
||||
func nullTime(value time.Time) sql.NullTime {
|
||||
return sql.NullTime{Time: value, Valid: !value.IsZero()}
|
||||
}
|
||||
|
||||
// nullJSON marshals value to a NullRawMessage. When value is nil or
|
||||
// marshals to JSON "null", the result is {Valid: false}. Combined with
|
||||
// the COALESCE-based UPDATE queries, this means a caller cannot clear a
|
||||
// previously-set JSON column back to NULL — passing nil preserves the
|
||||
// existing value. This is acceptable for debug logs because fields
|
||||
// accumulate monotonically (request → response → usage → error) and
|
||||
// never need to be cleared during normal operation.
|
||||
// jsonClear is a sentinel value that tells nullJSON to emit a valid
|
||||
// JSON null (JSONB 'null') instead of SQL NULL. COALESCE treats SQL
|
||||
// NULL as "keep existing" but replaces with a non-NULL JSONB value,
|
||||
// so passing jsonClear explicitly overwrites a previously set field.
|
||||
type jsonClear struct{}
|
||||
|
||||
func (s *Service) nullJSON(ctx context.Context, value any) pqtype.NullRawMessage {
|
||||
if value == nil {
|
||||
return pqtype.NullRawMessage{}
|
||||
}
|
||||
// Sentinel: emit a valid JSONB null so COALESCE replaces
|
||||
// any previously stored value.
|
||||
if _, ok := value.(jsonClear); ok {
|
||||
return pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage("null"),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
s.log.Warn(ctx, "failed to marshal chat debug JSON",
|
||||
slog.Error(err),
|
||||
slog.F("value_type", fmt.Sprintf("%T", value)),
|
||||
)
|
||||
return pqtype.NullRawMessage{}
|
||||
}
|
||||
if bytes.Equal(data, []byte("null")) {
|
||||
return pqtype.NullRawMessage{}
|
||||
}
|
||||
|
||||
return pqtype.NullRawMessage{RawMessage: data, Valid: true}
|
||||
}
|
||||
|
||||
func (s *Service) publishEvent(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
kind EventKind,
|
||||
runID uuid.UUID,
|
||||
stepID uuid.UUID,
|
||||
) {
|
||||
if s.pubsub == nil {
|
||||
s.log.Debug(ctx,
|
||||
"chat debug pubsub unavailable; skipping event",
|
||||
slog.F("kind", kind),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
event := DebugEvent{
|
||||
Kind: kind,
|
||||
ChatID: chatID,
|
||||
RunID: runID,
|
||||
StepID: stepID,
|
||||
}
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.log.Warn(ctx, "failed to marshal chat debug event",
|
||||
slog.Error(err),
|
||||
slog.F("kind", kind),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
channel := PubsubChannel(chatID)
|
||||
if err := s.pubsub.Publish(channel, data); err != nil {
|
||||
s.log.Warn(ctx, "failed to publish chat debug event",
|
||||
slog.Error(err),
|
||||
slog.F("channel", channel),
|
||||
slog.F("kind", kind),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,853 +0,0 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
type testFixture struct {
|
||||
ctx context.Context
|
||||
db database.Store
|
||||
svc *chatdebug.Service
|
||||
owner database.User
|
||||
chat database.Chat
|
||||
model database.ChatModelConfig
|
||||
}
|
||||
|
||||
func TestService_IsEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, _ := dbtestutil.NewDBWithSQLDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, model.ID)
|
||||
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
|
||||
// Default is off until an admin allows user opt-in.
|
||||
require.False(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
|
||||
err := db.UpsertChatDebugLoggingEnabled(ctx, true)
|
||||
require.NoError(t, err)
|
||||
// Allowing user opt-in is not enough on its own; the user must opt in.
|
||||
require.False(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
require.False(t, svc.IsEnabled(ctx, chat.ID, uuid.Nil))
|
||||
|
||||
err = db.UpsertUserChatDebugLoggingEnabled(ctx,
|
||||
database.UpsertUserChatDebugLoggingEnabledParams{
|
||||
UserID: owner.ID,
|
||||
DebugLoggingEnabled: true,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
|
||||
err = db.UpsertUserChatDebugLoggingEnabled(ctx,
|
||||
database.UpsertUserChatDebugLoggingEnabledParams{
|
||||
UserID: owner.ID,
|
||||
DebugLoggingEnabled: false,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.False(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
}
|
||||
|
||||
func TestService_IsEnabled_AlwaysEnable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, _ := dbtestutil.NewDBWithSQLDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, model.ID)
|
||||
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil, chatdebug.WithAlwaysEnable(true))
|
||||
require.True(t, svc.IsEnabled(ctx, chat.ID, owner.ID))
|
||||
require.True(t, svc.IsEnabled(ctx, chat.ID, uuid.Nil))
|
||||
}
|
||||
|
||||
func TestService_IsEnabled_ZeroValueService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var svc *chatdebug.Service
|
||||
require.False(t, svc.IsEnabled(context.Background(), uuid.Nil, uuid.Nil))
|
||||
|
||||
require.False(t, (&chatdebug.Service{}).IsEnabled(context.Background(), uuid.Nil, uuid.Nil))
|
||||
}
|
||||
|
||||
func TestService_CreateRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
rootChat := insertChat(fixture.ctx, t, fixture.db, fixture.owner.ID, fixture.model.ID)
|
||||
parentChat := insertChat(fixture.ctx, t, fixture.db, fixture.owner.ID, fixture.model.ID)
|
||||
triggerMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleUser, "trigger")
|
||||
historyTipMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant,
|
||||
"history-tip")
|
||||
|
||||
run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
RootChatID: rootChat.ID,
|
||||
ParentChatID: parentChat.ID,
|
||||
ModelConfigID: fixture.model.ID,
|
||||
TriggerMessageID: triggerMsg.ID,
|
||||
HistoryTipMessageID: historyTipMsg.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: fixture.model.Provider,
|
||||
Model: fixture.model.Model,
|
||||
Summary: map[string]any{
|
||||
"phase": "create",
|
||||
"count": 1,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assertRunMatches(t, run, fixture.chat.ID, rootChat.ID, parentChat.ID,
|
||||
fixture.model.ID, triggerMsg.ID, historyTipMsg.ID,
|
||||
chatdebug.KindChatTurn, chatdebug.StatusInProgress,
|
||||
fixture.model.Provider, fixture.model.Model,
|
||||
`{"count":1,"phase":"create"}`)
|
||||
|
||||
stored, err := fixture.db.GetChatDebugRunByID(fixture.ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, run.ID, stored.ID)
|
||||
require.JSONEq(t, string(run.Summary), string(stored.Summary))
|
||||
}
|
||||
|
||||
func TestService_CreateRun_TypedNilSummaryUsesDefaultObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
var summary map[string]any
|
||||
|
||||
run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Summary: summary,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{}`, string(run.Summary))
|
||||
}
|
||||
|
||||
func TestService_UpdateRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Summary: map[string]any{
|
||||
"before": true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
finishedAt := time.Now().UTC().Round(time.Microsecond)
|
||||
updated, err := fixture.svc.UpdateRun(fixture.ctx, chatdebug.UpdateRunParams{
|
||||
ID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
Summary: map[string]any{"after": "done"},
|
||||
FinishedAt: finishedAt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(chatdebug.StatusCompleted), updated.Status)
|
||||
require.True(t, updated.FinishedAt.Valid)
|
||||
require.WithinDuration(t, finishedAt, updated.FinishedAt.Time, time.Second)
|
||||
require.JSONEq(t, `{"after":"done"}`, string(updated.Summary))
|
||||
|
||||
stored, err := fixture.db.GetChatDebugRunByID(fixture.ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(chatdebug.StatusCompleted), stored.Status)
|
||||
require.JSONEq(t, `{"after":"done"}`, string(stored.Summary))
|
||||
require.True(t, stored.FinishedAt.Valid)
|
||||
}
|
||||
|
||||
func TestService_CreateStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
historyTipMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant,
|
||||
"history-tip")
|
||||
|
||||
step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
HistoryTipMessageID: historyTipMsg.ID,
|
||||
NormalizedRequest: map[string]any{
|
||||
"messages": []string{"hello"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fixture.chat.ID, step.ChatID)
|
||||
require.Equal(t, run.ID, step.RunID)
|
||||
require.EqualValues(t, 1, step.StepNumber)
|
||||
require.Equal(t, string(chatdebug.OperationStream), step.Operation)
|
||||
require.Equal(t, string(chatdebug.StatusInProgress), step.Status)
|
||||
require.True(t, step.HistoryTipMessageID.Valid)
|
||||
require.Equal(t, historyTipMsg.ID, step.HistoryTipMessageID.Int64)
|
||||
require.JSONEq(t, `{"messages":["hello"]}`, string(step.NormalizedRequest))
|
||||
|
||||
steps, err := fixture.db.GetChatDebugStepsByRunID(fixture.ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, steps, 1)
|
||||
require.Equal(t, step.ID, steps[0].ID)
|
||||
}
|
||||
|
||||
func TestService_CreateStep_RetriesDuplicateStepNumbers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
first, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
second, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, first.StepNumber)
|
||||
require.EqualValues(t, 2, second.StepNumber)
|
||||
}
|
||||
|
||||
func TestService_CreateStep_ListRetryErrorWins(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
runID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
listErr := xerrors.New("list chat debug steps")
|
||||
|
||||
db.EXPECT().InsertChatDebugStep(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.InsertChatDebugStepParams{}),
|
||||
).Return(database.ChatDebugStep{}, &pq.Error{
|
||||
Code: pq.ErrorCode("23505"),
|
||||
Constraint: string(database.UniqueIndexChatDebugStepsRunStep),
|
||||
})
|
||||
db.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), runID).Return(nil, listErr)
|
||||
|
||||
_, err := svc.CreateStep(context.Background(), chatdebug.CreateStepParams{
|
||||
RunID: runID,
|
||||
ChatID: chatID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.ErrorIs(t, err, listErr)
|
||||
}
|
||||
|
||||
func TestService_UpdateStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assistantMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant,
|
||||
"assistant")
|
||||
finishedAt := time.Now().UTC().Round(time.Microsecond)
|
||||
updated, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
AssistantMessageID: assistantMsg.ID,
|
||||
NormalizedResponse: map[string]any{"text": "done"},
|
||||
Usage: map[string]any{"input_tokens": 10, "output_tokens": 5},
|
||||
Attempts: []chatdebug.Attempt{{
|
||||
Number: 1,
|
||||
ResponseStatus: 200,
|
||||
DurationMs: 25,
|
||||
}},
|
||||
Metadata: map[string]any{"provider": fixture.model.Provider},
|
||||
FinishedAt: finishedAt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(chatdebug.StatusCompleted), updated.Status)
|
||||
require.True(t, updated.AssistantMessageID.Valid)
|
||||
require.Equal(t, assistantMsg.ID, updated.AssistantMessageID.Int64)
|
||||
require.True(t, updated.NormalizedResponse.Valid)
|
||||
require.JSONEq(t, `{"text":"done"}`,
|
||||
string(updated.NormalizedResponse.RawMessage))
|
||||
require.True(t, updated.Usage.Valid)
|
||||
require.JSONEq(t, `{"input_tokens":10,"output_tokens":5}`,
|
||||
string(updated.Usage.RawMessage))
|
||||
require.JSONEq(t,
|
||||
`[{"number":1,"response_status":200,"duration_ms":25}]`,
|
||||
string(updated.Attempts),
|
||||
)
|
||||
require.JSONEq(t, `{"provider":"`+fixture.model.Provider+`"}`,
|
||||
string(updated.Metadata))
|
||||
require.True(t, updated.FinishedAt.Valid)
|
||||
storedSteps, err := fixture.db.GetChatDebugStepsByRunID(fixture.ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, storedSteps, 1)
|
||||
require.Equal(t, updated.ID, storedSteps[0].ID)
|
||||
}
|
||||
|
||||
func TestService_UpdateStep_TypedNilAttemptsPreserveExistingValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationStream,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
Attempts: []chatdebug.Attempt{{
|
||||
Number: 1,
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var typedNilAttempts []chatdebug.Attempt
|
||||
updated, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Attempts: typedNilAttempts,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var attempts []map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated.Attempts, &attempts))
|
||||
require.Len(t, attempts, 1)
|
||||
require.EqualValues(t, 1, attempts[0]["number"])
|
||||
}
|
||||
|
||||
func TestService_DeleteByChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
_, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: run.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := fixture.svc.DeleteByChatID(fixture.ctx, fixture.chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, deleted)
|
||||
|
||||
runs, err := fixture.db.GetChatDebugRunsByChatID(fixture.ctx, database.GetChatDebugRunsByChatIDParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
LimitVal: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, runs)
|
||||
}
|
||||
|
||||
func TestService_DeleteAfterMessageID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fixture := newFixture(t)
|
||||
low := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, fixture.owner.ID,
|
||||
fixture.model.ID, database.ChatMessageRoleAssistant, "low")
|
||||
threshold := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID,
|
||||
fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant,
|
||||
"threshold")
|
||||
high := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, fixture.owner.ID,
|
||||
fixture.model.ID, database.ChatMessageRoleAssistant, "high")
|
||||
require.Less(t, low.ID, threshold.ID)
|
||||
require.Less(t, threshold.ID, high.ID)
|
||||
|
||||
runKeep := createRun(t, fixture)
|
||||
stepKeep, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: runKeep.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: stepKeep.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
AssistantMessageID: low.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
runDelete := createRun(t, fixture)
|
||||
stepDelete, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: runDelete.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: 1,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: stepDelete.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
AssistantMessageID: high.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := fixture.svc.DeleteAfterMessageID(fixture.ctx, fixture.chat.ID,
|
||||
threshold.ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, deleted)
|
||||
|
||||
runs, err := fixture.db.GetChatDebugRunsByChatID(fixture.ctx, database.GetChatDebugRunsByChatIDParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
LimitVal: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, runs, 1)
|
||||
require.Equal(t, runKeep.ID, runs[0].ID)
|
||||
|
||||
steps, err := fixture.db.GetChatDebugStepsByRunID(fixture.ctx, runKeep.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, steps, 1)
|
||||
require.Equal(t, stepKeep.ID, steps[0].ID)
|
||||
}
|
||||
|
||||
func TestService_FinalizeStale_UsesConfiguredThreshold(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
svc.SetStaleAfter(42 * time.Second)
|
||||
|
||||
db.EXPECT().FinalizeStaleChatDebugRows(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, staleBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
require.WithinDuration(t, time.Now().Add(-42*time.Second), staleBefore, 2*time.Second)
|
||||
return database.FinalizeStaleChatDebugRowsRow{}, nil
|
||||
},
|
||||
)
|
||||
|
||||
result, err := svc.FinalizeStale(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, result.RunsFinalized)
|
||||
require.Zero(t, result.StepsFinalized)
|
||||
}
|
||||
|
||||
func TestService_FinalizeStale(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, owner.ID)
|
||||
|
||||
staleTime := time.Now().Add(-10 * time.Minute).UTC().Round(time.Microsecond)
|
||||
run, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Kind: string(chatdebug.KindChatTurn),
|
||||
Status: string(chatdebug.StatusInProgress),
|
||||
StartedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
step, err := db.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: run.ID,
|
||||
StepNumber: 1,
|
||||
Operation: string(chatdebug.OperationStream),
|
||||
Status: string(chatdebug.StatusInProgress),
|
||||
StartedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
ChatID: chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
result, err := svc.FinalizeStale(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, result.RunsFinalized)
|
||||
require.EqualValues(t, 1, result.StepsFinalized)
|
||||
|
||||
storedRun, err := db.GetChatDebugRunByID(ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(chatdebug.StatusInterrupted), storedRun.Status)
|
||||
require.True(t, storedRun.FinishedAt.Valid)
|
||||
|
||||
storedSteps, err := db.GetChatDebugStepsByRunID(ctx, run.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, storedSteps, 1)
|
||||
require.Equal(t, step.ID, storedSteps[0].ID)
|
||||
require.Equal(t, string(chatdebug.StatusInterrupted), storedSteps[0].Status)
|
||||
require.True(t, storedSteps[0].FinishedAt.Valid)
|
||||
}
|
||||
|
||||
func TestService_FinalizeStale_BroadcastsFinalizeEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, owner.ID)
|
||||
|
||||
staleTime := time.Now().Add(-10 * time.Minute).UTC().Round(time.Microsecond)
|
||||
run, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Kind: string(chatdebug.KindChatTurn),
|
||||
Status: string(chatdebug.StatusInProgress),
|
||||
StartedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{
|
||||
RunID: run.ID,
|
||||
StepNumber: 1,
|
||||
Operation: string(chatdebug.OperationStream),
|
||||
Status: string(chatdebug.StatusInProgress),
|
||||
StartedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
UpdatedAt: sql.NullTime{Time: staleTime, Valid: true},
|
||||
ChatID: chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memoryPubsub := dbpubsub.NewInMemory()
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), memoryPubsub)
|
||||
type eventResult struct {
|
||||
event chatdebug.DebugEvent
|
||||
err error
|
||||
}
|
||||
events := make(chan eventResult, 1)
|
||||
cancel, err := memoryPubsub.Subscribe(chatdebug.PubsubChannel(uuid.Nil),
|
||||
func(_ context.Context, message []byte) {
|
||||
var event chatdebug.DebugEvent
|
||||
unmarshalErr := json.Unmarshal(message, &event)
|
||||
events <- eventResult{event: event, err: unmarshalErr}
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
result, err := svc.FinalizeStale(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, result.RunsFinalized)
|
||||
require.EqualValues(t, 1, result.StepsFinalized)
|
||||
|
||||
select {
|
||||
case received := <-events:
|
||||
require.NoError(t, received.err)
|
||||
require.Equal(t, chatdebug.EventKindFinalize, received.event.Kind)
|
||||
require.Equal(t, uuid.Nil, received.event.ChatID)
|
||||
require.Equal(t, uuid.Nil, received.event.RunID)
|
||||
require.Equal(t, uuid.Nil, received.event.StepID)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for finalize event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_FinalizeStale_NoChangesDoesNotBroadcast(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, _ := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, owner.ID)
|
||||
|
||||
memoryPubsub := dbpubsub.NewInMemory()
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), memoryPubsub)
|
||||
events := make(chan chatdebug.DebugEvent, 1)
|
||||
cancel, err := memoryPubsub.Subscribe(chatdebug.PubsubChannel(uuid.Nil),
|
||||
func(_ context.Context, message []byte) {
|
||||
var event chatdebug.DebugEvent
|
||||
if err := json.Unmarshal(message, &event); err == nil {
|
||||
events <- event
|
||||
}
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
result, err := svc.FinalizeStale(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, result.RunsFinalized)
|
||||
require.EqualValues(t, 0, result.StepsFinalized)
|
||||
|
||||
select {
|
||||
case event := <-events:
|
||||
t.Fatalf("unexpected finalize event: %+v", event)
|
||||
default:
|
||||
}
|
||||
|
||||
_ = chat // keep seeded chat usage explicit for test readability.
|
||||
}
|
||||
|
||||
func TestService_PublishesEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
require.NotEqual(t, uuid.Nil, owner.ID)
|
||||
|
||||
memoryPubsub := dbpubsub.NewInMemory()
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), memoryPubsub)
|
||||
type eventResult struct {
|
||||
event chatdebug.DebugEvent
|
||||
err error
|
||||
}
|
||||
events := make(chan eventResult, 1)
|
||||
cancel, err := memoryPubsub.Subscribe(chatdebug.PubsubChannel(chat.ID),
|
||||
func(_ context.Context, message []byte) {
|
||||
var event chatdebug.DebugEvent
|
||||
unmarshalErr := json.Unmarshal(message, &event)
|
||||
events <- eventResult{event: event, err: unmarshalErr}
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
run, err := svc.CreateRun(ctx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: model.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case received := <-events:
|
||||
require.NoError(t, received.err)
|
||||
require.Equal(t, chatdebug.EventKindRunUpdate, received.event.Kind)
|
||||
require.Equal(t, chat.ID, received.event.ChatID)
|
||||
require.Equal(t, run.ID, received.event.RunID)
|
||||
require.Equal(t, uuid.Nil, received.event.StepID)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for debug event")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-events:
|
||||
t.Fatalf("unexpected extra event: %+v", received.event)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func newFixture(t *testing.T) testFixture {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
owner, chat, model := seedChat(ctx, t, db)
|
||||
return testFixture{
|
||||
ctx: ctx,
|
||||
db: db,
|
||||
svc: chatdebug.NewService(db, testutil.Logger(t), nil),
|
||||
owner: owner,
|
||||
chat: chat,
|
||||
model: model,
|
||||
}
|
||||
}
|
||||
|
||||
func seedChat(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
) (database.User, database.Chat, database.ChatModelConfig) {
|
||||
t.Helper()
|
||||
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
providerName := "openai"
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: providerName,
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(ctx,
|
||||
database.InsertChatModelConfigParams{
|
||||
Provider: providerName,
|
||||
Model: "model-" + uuid.NewString(),
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat := insertChat(ctx, t, db, owner.ID, model.ID)
|
||||
return owner, chat, model
|
||||
}
|
||||
|
||||
func insertChat(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ownerID uuid.UUID,
|
||||
modelID uuid.UUID,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelID,
|
||||
Title: "chat-" + uuid.NewString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
func insertMessage(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
createdBy uuid.UUID,
|
||||
modelID uuid.UUID,
|
||||
role database.ChatMessageRole,
|
||||
text string,
|
||||
) database.ChatMessage {
|
||||
t.Helper()
|
||||
|
||||
parts, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(text),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messages, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: chatID,
|
||||
CreatedBy: []uuid.UUID{createdBy},
|
||||
ModelConfigID: []uuid.UUID{modelID},
|
||||
Role: []database.ChatMessageRole{role},
|
||||
Content: []string{string(parts.RawMessage)},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
return messages[0]
|
||||
}
|
||||
|
||||
func createRun(t *testing.T, fixture testFixture) database.ChatDebugRun {
|
||||
t.Helper()
|
||||
|
||||
run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{
|
||||
ChatID: fixture.chat.ID,
|
||||
ModelConfigID: fixture.model.ID,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: fixture.model.Provider,
|
||||
Model: fixture.model.Model,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return run
|
||||
}
|
||||
|
||||
func assertRunMatches(
|
||||
t *testing.T,
|
||||
run database.ChatDebugRun,
|
||||
chatID uuid.UUID,
|
||||
rootChatID uuid.UUID,
|
||||
parentChatID uuid.UUID,
|
||||
modelID uuid.UUID,
|
||||
triggerMessageID int64,
|
||||
historyTipMessageID int64,
|
||||
kind chatdebug.RunKind,
|
||||
status chatdebug.Status,
|
||||
provider string,
|
||||
model string,
|
||||
summary string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
require.Equal(t, chatID, run.ChatID)
|
||||
require.True(t, run.RootChatID.Valid)
|
||||
require.Equal(t, rootChatID, run.RootChatID.UUID)
|
||||
require.True(t, run.ParentChatID.Valid)
|
||||
require.Equal(t, parentChatID, run.ParentChatID.UUID)
|
||||
require.True(t, run.ModelConfigID.Valid)
|
||||
require.Equal(t, modelID, run.ModelConfigID.UUID)
|
||||
require.True(t, run.TriggerMessageID.Valid)
|
||||
require.Equal(t, triggerMessageID, run.TriggerMessageID.Int64)
|
||||
require.True(t, run.HistoryTipMessageID.Valid)
|
||||
require.Equal(t, historyTipMessageID, run.HistoryTipMessageID.Int64)
|
||||
require.Equal(t, string(kind), run.Kind)
|
||||
require.Equal(t, string(status), run.Status)
|
||||
require.True(t, run.Provider.Valid)
|
||||
require.Equal(t, provider, run.Provider.String)
|
||||
require.True(t, run.Model.Valid)
|
||||
require.Equal(t, model, run.Model.String)
|
||||
require.JSONEq(t, summary, string(run.Summary))
|
||||
require.False(t, run.StartedAt.IsZero())
|
||||
require.False(t, run.UpdatedAt.IsZero())
|
||||
require.False(t, run.FinishedAt.Valid)
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBeginStep_SkipsNilRunID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{ChatID: uuid.New()})
|
||||
handle, enriched := beginStep(ctx, &Service{}, RecorderOptions{ChatID: uuid.New()}, OperationGenerate, nil)
|
||||
require.Nil(t, handle)
|
||||
require.Equal(t, ctx, enriched)
|
||||
}
|
||||
|
||||
func TestTruncateLabel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{name: "Empty", input: "", maxLen: 10, want: ""},
|
||||
{name: "WhitespaceOnly", input: " \t\n ", maxLen: 10, want: ""},
|
||||
{name: "ShortText", input: "hello world", maxLen: 20, want: "hello world"},
|
||||
{name: "ExactLength", input: "abcde", maxLen: 5, want: "abcde"},
|
||||
{name: "LongTextTruncated", input: "abcdefghij", maxLen: 5, want: "abcd…"},
|
||||
{name: "NegativeMaxLen", input: "hello", maxLen: -1, want: ""},
|
||||
{name: "ZeroMaxLen", input: "hello", maxLen: 0, want: ""},
|
||||
{name: "SingleRuneLimit", input: "hello", maxLen: 1, want: "…"},
|
||||
{name: "MultipleWhitespaceRuns", input: " hello world \t again ", maxLen: 100, want: "hello world again"},
|
||||
{name: "UnicodeRunes", input: "こんにちは世界", maxLen: 3, want: "こん…"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := TruncateLabel(tc.input, tc.maxLen)
|
||||
require.Equal(t, tc.want, got)
|
||||
require.LessOrEqual(t, utf8.RuneCountInString(got), maxInt(tc.maxLen, 0))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
// MaxLabelLength is the default rune limit for truncated labels.
|
||||
const MaxLabelLength = 100
|
||||
|
||||
// whitespaceRun matches one or more consecutive whitespace characters.
|
||||
var whitespaceRun = regexp.MustCompile(`\s+`)
|
||||
|
||||
// TruncateLabel whitespace-normalizes and truncates text to maxLen runes.
|
||||
// Returns "" if input is empty or whitespace-only.
|
||||
func TruncateLabel(text string, maxLen int) string {
|
||||
if maxLen < 0 {
|
||||
maxLen = 0
|
||||
}
|
||||
|
||||
normalized := strings.TrimSpace(whitespaceRun.ReplaceAllString(text, " "))
|
||||
if normalized == "" || maxLen == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(normalized) <= maxLen {
|
||||
return normalized
|
||||
}
|
||||
if maxLen == 1 {
|
||||
return "…"
|
||||
}
|
||||
|
||||
// Truncate to leave room for the trailing ellipsis within maxLen.
|
||||
runes := []rune(normalized)
|
||||
return string(runes[:maxLen-1]) + "…"
|
||||
}
|
||||
|
||||
// SeedSummary builds a base summary map with a first_message label.
|
||||
// Returns nil if label is empty.
|
||||
func SeedSummary(label string) map[string]any {
|
||||
if label == "" {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{"first_message": label}
|
||||
}
|
||||
|
||||
// ExtractFirstUserText extracts the plain text content from a
|
||||
// fantasy.Prompt for the first user message. Used to derive
|
||||
// first_message labels at run creation time.
|
||||
func ExtractFirstUserText(prompt fantasy.Prompt) string {
|
||||
for _, msg := range prompt {
|
||||
if msg.Role != fantasy.MessageRoleUser {
|
||||
continue
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for _, part := range msg.Content {
|
||||
tp, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
_, _ = sb.WriteString(tp.Text)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// AggregateRunSummary reads all steps for the given run, computes token
|
||||
// totals, and merges them with the run's existing summary (preserving any
|
||||
// seeded first_message label). The baseSummary parameter should be the
|
||||
// current run summary (may be nil).
|
||||
func (s *Service) AggregateRunSummary(
|
||||
ctx context.Context,
|
||||
runID uuid.UUID,
|
||||
baseSummary map[string]any,
|
||||
) (map[string]any, error) {
|
||||
if runID == uuid.Nil {
|
||||
return baseSummary, nil
|
||||
}
|
||||
|
||||
steps, err := s.db.GetChatDebugStepsByRunID(chatdContext(ctx), runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start from a shallow copy of baseSummary to avoid mutating the
|
||||
// caller's map.
|
||||
// Capacity hint: baseSummary entries plus 8 derived keys
|
||||
// (step_count, total_input_tokens, total_output_tokens,
|
||||
// total_reasoning_tokens, total_cache_creation_tokens,
|
||||
// total_cache_read_tokens, has_error, endpoint_label).
|
||||
result := make(map[string]any, len(baseSummary)+8)
|
||||
for k, v := range baseSummary {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
// Clear derived fields before recomputing them so stale values from a
|
||||
// previous aggregation do not survive when the new totals are zero or
|
||||
// the endpoint label is unavailable.
|
||||
for _, key := range []string{
|
||||
"step_count",
|
||||
"total_input_tokens",
|
||||
"total_output_tokens",
|
||||
"total_reasoning_tokens",
|
||||
"total_cache_creation_tokens",
|
||||
"total_cache_read_tokens",
|
||||
"endpoint_label",
|
||||
"has_error",
|
||||
} {
|
||||
delete(result, key)
|
||||
}
|
||||
var (
|
||||
totalInput int64
|
||||
totalOutput int64
|
||||
totalReasoning int64
|
||||
totalCacheCreation int64
|
||||
totalCacheRead int64
|
||||
hasError bool
|
||||
)
|
||||
|
||||
for _, step := range steps {
|
||||
// Flag runs that hit a real error. Interrupted steps represent
|
||||
// user-initiated cancellation (e.g. clicking Stop) and should
|
||||
// not trigger the error indicator in the debug panel.
|
||||
// A JSONB null (used by jsonClear to erase a prior error) is
|
||||
// Valid but carries no meaningful content, so exclude it.
|
||||
errorIsReal := step.Error.Valid &&
|
||||
len(step.Error.RawMessage) > 0 &&
|
||||
!bytes.Equal(step.Error.RawMessage, []byte("null"))
|
||||
if step.Status == string(StatusError) ||
|
||||
(errorIsReal && step.Status != string(StatusInterrupted)) {
|
||||
hasError = true
|
||||
}
|
||||
if !step.Usage.Valid || len(step.Usage.RawMessage) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var usage fantasy.Usage
|
||||
if err := json.Unmarshal(step.Usage.RawMessage, &usage); err != nil {
|
||||
s.log.Warn(ctx, "skipping malformed step usage JSON",
|
||||
slog.Error(err),
|
||||
slog.F("run_id", runID),
|
||||
slog.F("step_id", step.ID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
totalInput += usage.InputTokens
|
||||
totalOutput += usage.OutputTokens
|
||||
totalReasoning += usage.ReasoningTokens
|
||||
totalCacheCreation += usage.CacheCreationTokens
|
||||
totalCacheRead += usage.CacheReadTokens
|
||||
}
|
||||
|
||||
result["step_count"] = len(steps)
|
||||
result["total_input_tokens"] = totalInput
|
||||
result["total_output_tokens"] = totalOutput
|
||||
|
||||
// Only include reasoning/cache fields when non-zero to keep the
|
||||
// summary compact for the common case.
|
||||
if totalReasoning > 0 {
|
||||
result["total_reasoning_tokens"] = totalReasoning
|
||||
}
|
||||
if totalCacheCreation > 0 {
|
||||
result["total_cache_creation_tokens"] = totalCacheCreation
|
||||
}
|
||||
if totalCacheRead > 0 {
|
||||
result["total_cache_read_tokens"] = totalCacheRead
|
||||
}
|
||||
|
||||
if hasError {
|
||||
result["has_error"] = true
|
||||
}
|
||||
|
||||
// Derive endpoint_label from the first completed attempt's path
|
||||
// across all steps. This gives the debug panel a meaningful
|
||||
// identifier like "POST /v1/messages" for the run row.
|
||||
if label := extractEndpointLabel(steps); label != "" {
|
||||
result["endpoint_label"] = label
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// extractEndpointLabel scans steps for the first completed attempt with a
|
||||
// non-empty path and returns "METHOD /path" (or just "/path").
|
||||
func extractEndpointLabel(steps []database.ChatDebugStep) string {
|
||||
for _, step := range steps {
|
||||
if len(step.Attempts) == 0 {
|
||||
continue
|
||||
}
|
||||
var attempts []Attempt
|
||||
if err := json.Unmarshal(step.Attempts, &attempts); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, a := range attempts {
|
||||
if a.Status != attemptStatusCompleted || a.Path == "" {
|
||||
continue
|
||||
}
|
||||
if a.Method != "" {
|
||||
return a.Method + " " + a.Path
|
||||
}
|
||||
return a.Path
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,416 +0,0 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
)
|
||||
|
||||
func TestTruncateLabel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{name: "Empty", input: "", maxLen: 10, want: ""},
|
||||
{name: "WhitespaceOnly", input: " \t\n ", maxLen: 10, want: ""},
|
||||
{name: "ShortText", input: "hello world", maxLen: 20, want: "hello world"},
|
||||
{name: "ExactLength", input: "abcde", maxLen: 5, want: "abcde"},
|
||||
{name: "LongTextTruncated", input: "abcdefghij", maxLen: 5, want: "abcd…"},
|
||||
{name: "NegativeMaxLen", input: "hello", maxLen: -1, want: ""},
|
||||
{name: "ZeroMaxLen", input: "hello", maxLen: 0, want: ""},
|
||||
{name: "SingleRuneLimit", input: "hello", maxLen: 1, want: "…"},
|
||||
{name: "MultipleWhitespaceRuns", input: " hello world \t again ", maxLen: 100, want: "hello world again"},
|
||||
{name: "UnicodeRunes", input: "こんにちは世界", maxLen: 3, want: "こん…"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatdebug.TruncateLabel(tc.input, tc.maxLen)
|
||||
require.Equal(t, tc.want, got)
|
||||
require.LessOrEqual(t, utf8.RuneCountInString(got), maxInt(tc.maxLen, 0))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func TestSeedSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NonEmptyLabel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatdebug.SeedSummary("hello world")
|
||||
require.Equal(t, map[string]any{"first_message": "hello world"}, got)
|
||||
})
|
||||
|
||||
t.Run("EmptyLabel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatdebug.SeedSummary("")
|
||||
require.Nil(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractFirstUserText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyPrompt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := chatdebug.ExtractFirstUserText(fantasy.Prompt{})
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("NoUserMessages", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
prompt := fantasy.Prompt{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "system"}},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "assistant"}},
|
||||
},
|
||||
}
|
||||
got := chatdebug.ExtractFirstUserText(prompt)
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("FirstUserMessageMixedParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
prompt := fantasy.Prompt{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "hello "},
|
||||
fantasy.FilePart{Filename: "test.png"},
|
||||
fantasy.TextPart{Text: "world"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := chatdebug.ExtractFirstUserText(prompt)
|
||||
require.Equal(t, "hello world", got)
|
||||
})
|
||||
|
||||
t.Run("MultipleUserMessagesReturnsFirst", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
prompt := fantasy.Prompt{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "system"}},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "first"}},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "second"}},
|
||||
},
|
||||
}
|
||||
got := chatdebug.ExtractFirstUserText(prompt)
|
||||
require.Equal(t, "first", got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_AggregateRunSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NilRunID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, uuid.Nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, got)
|
||||
})
|
||||
|
||||
t.Run("NilBaseSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
// Create a step with usage.
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step.ID, 10, 5, 0, 0)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.EqualValues(t, 1, got["step_count"])
|
||||
require.EqualValues(t, int64(10), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(5), got["total_output_tokens"])
|
||||
})
|
||||
|
||||
t.Run("PreservesFirstMessage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step.ID, 20, 10, 0, 0)
|
||||
|
||||
base := map[string]any{"first_message": "hello world"}
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, base)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello world", got["first_message"])
|
||||
require.EqualValues(t, 1, got["step_count"])
|
||||
require.EqualValues(t, int64(20), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(10), got["total_output_tokens"])
|
||||
})
|
||||
|
||||
t.Run("ClearsStaleDerivedFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step.ID, 10, 5, 0, 0)
|
||||
|
||||
base := map[string]any{
|
||||
"first_message": "hello world",
|
||||
"step_count": 9,
|
||||
"total_input_tokens": 999,
|
||||
"total_output_tokens": 888,
|
||||
"total_reasoning_tokens": 777,
|
||||
"total_cache_creation_tokens": 100,
|
||||
"total_cache_read_tokens": 200,
|
||||
"has_error": true,
|
||||
"endpoint_label": "POST /stale",
|
||||
}
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, base)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello world", got["first_message"])
|
||||
require.EqualValues(t, 1, got["step_count"])
|
||||
require.EqualValues(t, int64(10), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(5), got["total_output_tokens"])
|
||||
// Stale reasoning tokens must be cleared because the step
|
||||
// has zero reasoning tokens.
|
||||
require.NotContains(t, got, "total_reasoning_tokens")
|
||||
require.NotContains(t, got, "total_cache_creation_tokens")
|
||||
require.NotContains(t, got, "total_cache_read_tokens")
|
||||
// has_error must be cleared because the step is not in error
|
||||
// status and has no error payload.
|
||||
require.NotContains(t, got, "has_error")
|
||||
require.NotContains(t, got, "endpoint_label")
|
||||
})
|
||||
|
||||
t.Run("RecomputesHasErrorAndCompletedEndpointLabel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step1 := createTestStep(t, fixture, run.ID)
|
||||
_, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step1.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusError,
|
||||
Attempts: []chatdebug.Attempt{{
|
||||
Number: 1,
|
||||
Status: "failed",
|
||||
Method: "POST",
|
||||
Path: "/failed",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
step2 := createTestStepN(t, fixture, run.ID, 2)
|
||||
_, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step2.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
Attempts: []chatdebug.Attempt{{
|
||||
Number: 1,
|
||||
Status: "completed",
|
||||
Method: "POST",
|
||||
Path: "/v1/messages",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, got["has_error"])
|
||||
require.Equal(t, "POST /v1/messages", got["endpoint_label"])
|
||||
})
|
||||
|
||||
t.Run("MultipleStepsSumTokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step1 := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step1.ID, 10, 5, 2, 3)
|
||||
|
||||
step2 := createTestStepN(t, fixture, run.ID, 2)
|
||||
updateTestStepWithUsage(t, fixture, step2.ID, 15, 7, 1, 4)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, got["step_count"])
|
||||
require.EqualValues(t, int64(25), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(12), got["total_output_tokens"])
|
||||
require.EqualValues(t, int64(3), got["total_cache_creation_tokens"])
|
||||
require.EqualValues(t, int64(7), got["total_cache_read_tokens"])
|
||||
})
|
||||
|
||||
t.Run("StepWithNilUsageContributesZeroTokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
// Step with usage.
|
||||
step1 := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step1.ID, 10, 5, 0, 0)
|
||||
|
||||
// Step without usage (just complete it, no usage).
|
||||
step2 := createTestStepN(t, fixture, run.ID, 2)
|
||||
_, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: step2.ID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
// Both steps are counted even though one has no usage.
|
||||
require.EqualValues(t, 2, got["step_count"])
|
||||
require.EqualValues(t, int64(10), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(5), got["total_output_tokens"])
|
||||
})
|
||||
|
||||
t.Run("ZeroCacheTotalsOmitCacheFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithUsage(t, fixture, step.ID, 10, 5, 0, 0)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
_, hasCacheCreation := got["total_cache_creation_tokens"]
|
||||
_, hasCacheRead := got["total_cache_read_tokens"]
|
||||
require.False(t, hasCacheCreation,
|
||||
"cache creation tokens should be omitted when zero")
|
||||
require.False(t, hasCacheRead,
|
||||
"cache read tokens should be omitted when zero")
|
||||
})
|
||||
|
||||
t.Run("ReasoningTokensSummedAcrossSteps", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step1 := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithFullUsage(t, fixture, step1.ID, 10, 5, 20, 0, 0)
|
||||
|
||||
step2 := createTestStepN(t, fixture, run.ID, 2)
|
||||
updateTestStepWithFullUsage(t, fixture, step2.ID, 15, 7, 30, 0, 0)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, got["step_count"])
|
||||
require.EqualValues(t, int64(25), got["total_input_tokens"])
|
||||
require.EqualValues(t, int64(12), got["total_output_tokens"])
|
||||
require.EqualValues(t, int64(50), got["total_reasoning_tokens"],
|
||||
"reasoning tokens should be summed across steps")
|
||||
})
|
||||
|
||||
t.Run("ZeroReasoningTokensOmitsField", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fixture := newFixture(t)
|
||||
run := createRun(t, fixture)
|
||||
|
||||
step := createTestStep(t, fixture, run.ID)
|
||||
updateTestStepWithFullUsage(t, fixture, step.ID, 10, 5, 0, 0, 0)
|
||||
|
||||
got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil)
|
||||
require.NoError(t, err)
|
||||
_, hasReasoning := got["total_reasoning_tokens"]
|
||||
require.False(t, hasReasoning,
|
||||
"reasoning tokens should be omitted when zero")
|
||||
})
|
||||
}
|
||||
|
||||
// createTestStep is a thin helper that creates a debug step with
|
||||
// step number 1 for the given run.
|
||||
func createTestStep(
|
||||
t *testing.T,
|
||||
fixture testFixture,
|
||||
runID uuid.UUID,
|
||||
) database.ChatDebugStep {
|
||||
t.Helper()
|
||||
return createTestStepN(t, fixture, runID, 1)
|
||||
}
|
||||
|
||||
// createTestStepN creates a debug step with the given step number.
|
||||
func createTestStepN(
|
||||
t *testing.T,
|
||||
fixture testFixture,
|
||||
runID uuid.UUID,
|
||||
stepNumber int32,
|
||||
) database.ChatDebugStep {
|
||||
t.Helper()
|
||||
step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{
|
||||
RunID: runID,
|
||||
ChatID: fixture.chat.ID,
|
||||
StepNumber: stepNumber,
|
||||
Operation: chatdebug.OperationGenerate,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return step
|
||||
}
|
||||
|
||||
// updateTestStepWithUsage completes a step and sets token usage fields.
|
||||
func updateTestStepWithUsage(
|
||||
t *testing.T,
|
||||
fixture testFixture,
|
||||
stepID uuid.UUID,
|
||||
input, output, cacheCreation, cacheRead int64,
|
||||
) {
|
||||
t.Helper()
|
||||
updateTestStepWithFullUsage(t, fixture, stepID, input, output, 0, cacheCreation, cacheRead)
|
||||
}
|
||||
|
||||
// updateTestStepWithFullUsage completes a step with all token usage
|
||||
// fields, including reasoning tokens.
|
||||
func updateTestStepWithFullUsage(
|
||||
t *testing.T,
|
||||
fixture testFixture,
|
||||
stepID uuid.UUID,
|
||||
input, output, reasoning, cacheCreation, cacheRead int64,
|
||||
) {
|
||||
t.Helper()
|
||||
_, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{
|
||||
ID: stepID,
|
||||
ChatID: fixture.chat.ID,
|
||||
Status: chatdebug.StatusCompleted,
|
||||
Usage: map[string]any{
|
||||
"input_tokens": input,
|
||||
"output_tokens": output,
|
||||
"reasoning_tokens": reasoning,
|
||||
"cache_creation_tokens": cacheCreation,
|
||||
"cache_read_tokens": cacheRead,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -1,382 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// attemptStatusCompleted is the status recorded when a response body
|
||||
// is fully read without transport-level errors.
|
||||
const attemptStatusCompleted = "completed"
|
||||
|
||||
// attemptStatusFailed is the status recorded when a transport error
|
||||
// or body read error occurs.
|
||||
const attemptStatusFailed = "failed"
|
||||
|
||||
// maxRecordedRequestBodyBytes caps in-memory request capture when GetBody
|
||||
// is available.
|
||||
const maxRecordedRequestBodyBytes = 50_000
|
||||
|
||||
// maxRecordedResponseBodyBytes caps in-memory response capture.
|
||||
const maxRecordedResponseBodyBytes = 50_000
|
||||
|
||||
// RecordingTransport captures HTTP request/response data for debug steps.
|
||||
// When the request context carries an attemptSink, it records each round
|
||||
// trip. Otherwise it delegates directly.
|
||||
type RecordingTransport struct {
|
||||
// Base is the underlying transport. nil defaults to http.DefaultTransport.
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = (*RecordingTransport)(nil)
|
||||
|
||||
func (t *RecordingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
panic("chatdebug: nil request")
|
||||
}
|
||||
|
||||
base := t.Base
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
|
||||
sink := attemptSinkFromContext(req.Context())
|
||||
if sink == nil {
|
||||
return base.RoundTrip(req)
|
||||
}
|
||||
|
||||
requestHeaders := RedactHeaders(req.Header)
|
||||
|
||||
// Capture method and URL/path from the request.
|
||||
method := req.Method
|
||||
reqURL := ""
|
||||
reqPath := ""
|
||||
if req.URL != nil {
|
||||
reqURL = redactURL(req.URL)
|
||||
reqPath = req.URL.Path
|
||||
}
|
||||
|
||||
requestBody, err := captureRequestBody(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attemptNumber := sink.nextAttemptNumber()
|
||||
|
||||
startedAt := time.Now()
|
||||
resp, err := base.RoundTrip(req)
|
||||
finishedAt := time.Now()
|
||||
durationMs := finishedAt.Sub(startedAt).Milliseconds()
|
||||
if err != nil {
|
||||
sink.record(Attempt{
|
||||
Number: attemptNumber,
|
||||
Status: attemptStatusFailed,
|
||||
Method: method,
|
||||
URL: reqURL,
|
||||
Path: reqPath,
|
||||
StartedAt: startedAt.UTC().Format(time.RFC3339Nano),
|
||||
FinishedAt: finishedAt.UTC().Format(time.RFC3339Nano),
|
||||
RequestHeaders: requestHeaders,
|
||||
RequestBody: requestBody,
|
||||
Error: err.Error(),
|
||||
DurationMs: durationMs,
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
respHeaders := RedactHeaders(resp.Header)
|
||||
resp.Body = &recordingBody{
|
||||
inner: resp.Body,
|
||||
sink: sink,
|
||||
startedAt: startedAt,
|
||||
contentLength: resp.ContentLength,
|
||||
base: Attempt{
|
||||
Number: attemptNumber,
|
||||
Method: method,
|
||||
URL: reqURL,
|
||||
Path: reqPath,
|
||||
RequestHeaders: requestHeaders,
|
||||
RequestBody: requestBody,
|
||||
ResponseStatus: resp.StatusCode,
|
||||
ResponseHeaders: respHeaders,
|
||||
DurationMs: durationMs,
|
||||
},
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func redactURL(u *url.URL) string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
clone := *u
|
||||
clone.User = nil
|
||||
q := clone.Query()
|
||||
for key, values := range q {
|
||||
if isSensitiveName(key) || isSensitiveJSONKey(key) {
|
||||
for i := range values {
|
||||
values[i] = RedactedValue
|
||||
}
|
||||
q[key] = values
|
||||
}
|
||||
}
|
||||
clone.RawQuery = q.Encode()
|
||||
return clone.String()
|
||||
}
|
||||
|
||||
func captureRequestBody(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if req.GetBody != nil {
|
||||
clone, err := req.GetBody()
|
||||
if err == nil {
|
||||
defer clone.Close()
|
||||
limited, err := io.ReadAll(io.LimitReader(clone, maxRecordedRequestBodyBytes+1))
|
||||
if err == nil {
|
||||
if len(limited) > maxRecordedRequestBodyBytes {
|
||||
return []byte("[TRUNCATED]"), nil
|
||||
}
|
||||
return RedactJSONSecrets(limited), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Without GetBody we cannot safely capture the request body without
|
||||
// fully consuming a potentially large or streaming body before the
|
||||
// request is sent. Skip capture in that case to keep debug logging
|
||||
// lightweight and non-invasive.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type recordingBody struct {
|
||||
inner io.ReadCloser
|
||||
contentLength int64
|
||||
sink *attemptSink
|
||||
base Attempt
|
||||
startedAt time.Time
|
||||
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
truncated bool
|
||||
sawEOF bool
|
||||
bytesRead int64
|
||||
|
||||
recordOnce sync.Once
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (r *recordingBody) Read(p []byte) (int, error) {
|
||||
n, err := r.inner.Read(p)
|
||||
|
||||
r.mu.Lock()
|
||||
r.bytesRead += int64(n)
|
||||
if n > 0 && !r.truncated {
|
||||
remaining := maxRecordedResponseBodyBytes - r.buf.Len()
|
||||
if remaining > 0 {
|
||||
toWrite := n
|
||||
if toWrite > remaining {
|
||||
toWrite = remaining
|
||||
r.truncated = true
|
||||
}
|
||||
_, _ = r.buf.Write(p[:toWrite])
|
||||
} else {
|
||||
r.truncated = true
|
||||
}
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
r.sawEOF = true
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
r.record(err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *recordingBody) Close() error {
|
||||
r.mu.Lock()
|
||||
sawEOF := r.sawEOF
|
||||
bytesRead := r.bytesRead
|
||||
contentLength := r.contentLength
|
||||
truncated := r.truncated
|
||||
responseBody := append([]byte(nil), r.buf.Bytes()...)
|
||||
r.mu.Unlock()
|
||||
|
||||
contentType := r.base.ResponseHeaders["Content-Type"]
|
||||
shouldDrainUnknownLengthJSON := contentLength < 0 &&
|
||||
!sawEOF &&
|
||||
bytesRead > 0 &&
|
||||
!truncated &&
|
||||
isCompleteUnknownLengthJSONBody(contentType, responseBody)
|
||||
|
||||
// Always close the inner reader first so that stalled chunked
|
||||
// bodies cannot block drainToEOF indefinitely. Once inner is
|
||||
// closed, reads return immediately with an error or EOF.
|
||||
var closeErr error
|
||||
r.closeOnce.Do(func() {
|
||||
closeErr = r.inner.Close()
|
||||
})
|
||||
if closeErr != nil {
|
||||
r.record(closeErr)
|
||||
return closeErr
|
||||
}
|
||||
|
||||
// Drain remaining bytes that may already be buffered inside the
|
||||
// HTTP transport after close. Because inner is closed, this
|
||||
// finishes immediately rather than blocking on the network.
|
||||
if shouldDrainUnknownLengthJSON {
|
||||
// Best-effort drain; ignore errors since inner is closed.
|
||||
_ = r.drainToEOF()
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
sawEOF = r.sawEOF
|
||||
bytesRead = r.bytesRead
|
||||
contentLength = r.contentLength
|
||||
truncated = r.truncated
|
||||
responseBody = append([]byte(nil), r.buf.Bytes()...)
|
||||
r.mu.Unlock()
|
||||
|
||||
switch {
|
||||
// Only check JSON completeness when the recording buffer is
|
||||
// not truncated. A truncated buffer is an incomplete prefix
|
||||
// of the body, so the completeness check would false-positive.
|
||||
case sawEOF && !truncated && contentLength < 0 && isJSONLikeContentType(contentType) && !isCompleteUnknownLengthJSONBody(contentType, responseBody):
|
||||
r.record(io.ErrUnexpectedEOF)
|
||||
case sawEOF:
|
||||
r.record(io.EOF)
|
||||
case responseHasNoBody(r.base.Method, r.base.ResponseStatus):
|
||||
r.record(nil)
|
||||
case contentLength >= 0 && bytesRead >= contentLength:
|
||||
r.record(nil)
|
||||
case contentLength < 0 && !truncated && isCompleteUnknownLengthJSONBody(contentType, responseBody):
|
||||
r.record(nil)
|
||||
default:
|
||||
r.record(io.ErrUnexpectedEOF)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseHasNoBody(method string, statusCode int) bool {
|
||||
if method == http.MethodHead {
|
||||
return true
|
||||
}
|
||||
return statusCode == http.StatusNoContent ||
|
||||
statusCode == http.StatusNotModified ||
|
||||
(statusCode >= 100 && statusCode < 200)
|
||||
}
|
||||
|
||||
func isJSONLikeContentType(contentType string) bool {
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
mediaType = strings.TrimSpace(strings.Split(contentType, ";")[0])
|
||||
}
|
||||
return mediaType == "application/json" || strings.HasSuffix(mediaType, "+json")
|
||||
}
|
||||
|
||||
// maxDrainBytes caps how many trailing bytes drainToEOF will consume.
|
||||
// This prevents Close() from blocking indefinitely on a misbehaving
|
||||
// or extremely large chunked body.
|
||||
const maxDrainBytes = 64 * 1024 // 64 KB
|
||||
|
||||
func (r *recordingBody) drainToEOF() error {
|
||||
buf := make([]byte, 4*1024)
|
||||
var drained int64
|
||||
for {
|
||||
n, err := r.inner.Read(buf)
|
||||
|
||||
r.mu.Lock()
|
||||
r.bytesRead += int64(n)
|
||||
drained += int64(n)
|
||||
if n > 0 && !r.truncated {
|
||||
remaining := maxRecordedResponseBodyBytes - r.buf.Len()
|
||||
if remaining > 0 {
|
||||
toWrite := n
|
||||
if toWrite > remaining {
|
||||
toWrite = remaining
|
||||
r.truncated = true
|
||||
}
|
||||
_, _ = r.buf.Write(buf[:toWrite])
|
||||
} else {
|
||||
r.truncated = true
|
||||
}
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
r.sawEOF = true
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Safety valve: stop draining after maxDrainBytes to prevent
|
||||
// Close() from blocking indefinitely on a chunked body.
|
||||
if drained >= maxDrainBytes {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isCompleteUnknownLengthJSONBody(contentType string, body []byte) bool {
|
||||
if !isJSONLikeContentType(contentType) {
|
||||
return false
|
||||
}
|
||||
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader(trimmed))
|
||||
var value any
|
||||
if err := decoder.Decode(&value); err != nil {
|
||||
return false
|
||||
}
|
||||
var extra any
|
||||
return errors.Is(decoder.Decode(&extra), io.EOF)
|
||||
}
|
||||
|
||||
func (r *recordingBody) record(err error) {
|
||||
r.recordOnce.Do(func() {
|
||||
finishedAt := time.Now()
|
||||
|
||||
r.mu.Lock()
|
||||
truncated := r.truncated
|
||||
responseBody := append([]byte(nil), r.buf.Bytes()...)
|
||||
base := r.base
|
||||
startedAt := r.startedAt
|
||||
r.mu.Unlock()
|
||||
|
||||
if truncated {
|
||||
base.ResponseBody = []byte("[TRUNCATED]")
|
||||
} else {
|
||||
base.ResponseBody = RedactJSONSecrets(responseBody)
|
||||
}
|
||||
base.StartedAt = startedAt.UTC().Format(time.RFC3339Nano)
|
||||
base.FinishedAt = finishedAt.UTC().Format(time.RFC3339Nano)
|
||||
// Recompute duration to include body read time.
|
||||
base.DurationMs = finishedAt.Sub(startedAt).Milliseconds()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
base.Error = err.Error()
|
||||
base.Status = attemptStatusFailed
|
||||
} else {
|
||||
base.Status = attemptStatusCompleted
|
||||
}
|
||||
r.sink.record(base)
|
||||
})
|
||||
}
|
||||
@@ -1,737 +0,0 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func newTestSinkContext(t *testing.T) (context.Context, *attemptSink) {
|
||||
t.Helper()
|
||||
|
||||
sink := &attemptSink{}
|
||||
return withAttemptSink(context.Background(), sink), sink
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
type scriptedReadCloser struct {
|
||||
chunks [][]byte
|
||||
index int
|
||||
offset int // byte offset within current chunk
|
||||
}
|
||||
|
||||
func (r *scriptedReadCloser) Read(p []byte) (int, error) {
|
||||
if r.index >= len(r.chunks) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
chunk := r.chunks[r.index]
|
||||
remaining := chunk[r.offset:]
|
||||
n := copy(p, remaining)
|
||||
r.offset += n
|
||||
if r.offset >= len(chunk) {
|
||||
r.index++
|
||||
r.offset = 0
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (*scriptedReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRecordingTransport_NoSink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotMethod := make(chan string, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
gotMethod <- req.Method
|
||||
_, _ = rw.Write([]byte("ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, "ok", string(body))
|
||||
require.Equal(t, http.MethodGet, <-gotMethod)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CaptureRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const requestBody = `{"message":"hello","api_key":"super-secret"}`
|
||||
|
||||
type receivedRequest struct {
|
||||
authorization string
|
||||
body []byte
|
||||
}
|
||||
gotRequest := make(chan receivedRequest, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
gotRequest <- receivedRequest{
|
||||
authorization: req.Header.Get("Authorization"),
|
||||
body: body,
|
||||
}
|
||||
_, _ = rw.Write([]byte(`{"ok":true}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
server.URL,
|
||||
strings.NewReader(requestBody),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer top-secret")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, 1, attempts[0].Number)
|
||||
require.Equal(t, RedactedValue, attempts[0].RequestHeaders["Authorization"])
|
||||
require.Equal(t, "application/json", attempts[0].RequestHeaders["Content-Type"])
|
||||
require.JSONEq(t, `{"message":"hello","api_key":"[REDACTED]"}`, string(attempts[0].RequestBody))
|
||||
|
||||
received := <-gotRequest
|
||||
require.JSONEq(t, requestBody, string(received.body))
|
||||
require.Equal(t, "Bearer top-secret", received.authorization)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_RedactsSensitiveQueryParameters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL+`?api_key=secret&safe=ok`, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Contains(t, attempts[0].URL, "api_key=%5BREDACTED%5D")
|
||||
require.Contains(t, attempts[0].URL, "safe=ok")
|
||||
}
|
||||
|
||||
func TestRecordingTransport_TruncatesLargeRequestBodies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = io.Copy(io.Discard, req.Body)
|
||||
_, _ = rw.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
large := strings.Repeat("x", maxRecordedRequestBodyBytes+1024)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(large))
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, []byte("[TRUNCATED]"), attempts[0].RequestBody)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_StripsURLUserinfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.Replace(server.URL, "http://", "http://user:secret@", 1)+`?api_key=secret`, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.NotContains(t, attempts[0].URL, "user:secret")
|
||||
require.Contains(t, attempts[0].URL, "api_key=%5BREDACTED%5D")
|
||||
}
|
||||
|
||||
func TestRecordingTransport_SkipsNonReplayableRequestBodyCapture(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const requestBody = `{"message":"hello"}`
|
||||
gotRequest := make(chan []byte, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
gotRequest <- body
|
||||
_, _ = rw.Write([]byte(`ok`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, io.NopCloser(strings.NewReader(requestBody)))
|
||||
require.NoError(t, err)
|
||||
req.GetBody = nil
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
require.JSONEq(t, requestBody, string(<-gotRequest))
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Nil(t, attempts[0].RequestBody)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CaptureResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("X-API-Key", "response-secret")
|
||||
rw.Header().Set("X-Trace-ID", "trace-123")
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, http.StatusCreated, attempts[0].ResponseStatus)
|
||||
require.Equal(t, RedactedValue, attempts[0].ResponseHeaders["X-Api-Key"])
|
||||
require.Equal(t, "trace-123", attempts[0].ResponseHeaders["X-Trace-Id"])
|
||||
require.JSONEq(t, `{"token":"[REDACTED]","safe":"ok"}`, string(attempts[0].ResponseBody))
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CaptureResponseOnEOFWithoutClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Header().Set("X-API-Key", "response-secret")
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, http.StatusAccepted, attempts[0].ResponseStatus)
|
||||
require.Equal(t, "application/json", attempts[0].ResponseHeaders["Content-Type"])
|
||||
require.Equal(t, RedactedValue, attempts[0].ResponseHeaders["X-Api-Key"])
|
||||
require.JSONEq(t, `{"token":"[REDACTED]","safe":"ok"}`, string(attempts[0].ResponseBody))
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
|
||||
func TestRecordingTransport_StreamingBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
flusher, ok := rw.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
_, _ = rw.Write([]byte(`{"safe":"stream",`))
|
||||
flusher.Flush()
|
||||
_, _ = rw.Write([]byte(`"token":"chunk-secret"}`))
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{Base: server.Client().Transport},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
var body strings.Builder
|
||||
for {
|
||||
n, readErr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
_, writeErr := body.Write(buf[:n])
|
||||
require.NoError(t, writeErr)
|
||||
}
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
require.NoError(t, readErr)
|
||||
}
|
||||
require.NoError(t, resp.Body.Close())
|
||||
require.JSONEq(t, `{"safe":"stream","token":"chunk-secret"}`, body.String())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.JSONEq(t, `{"safe":"stream","token":"[REDACTED]"}`, string(attempts[0].ResponseBody))
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderConsumesContentLengthSucceeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Empty(t, attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderConsumesUnknownLengthJSONSucceeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Empty(t, attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderConsumesUnknownLengthJSONWithTrailingDocumentMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte("{\"token\":\"response-secret\",\"safe\":\"ok\"}{\"token\":\"second\"}")}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderConsumesUnknownLengthNDJSONMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/x-ndjson"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte("{\"token\":\"response-secret\",\"safe\":\"ok\"}\n{\"token\":\"second\"}\n")}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseAfterDecoderDrainsUnknownLengthSucceeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded map[string]string
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&decoded))
|
||||
require.Equal(t, "ok", decoded["safe"])
|
||||
_, err = io.Copy(io.Discard, resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Empty(t, attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseWithoutReadingHeadResponseSucceeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises no-body close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"ignored":true}`)}},
|
||||
ContentLength: 13,
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Empty(t, attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_CloseWithoutReadingUnknownLengthMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_PrematureCloseUnknownLengthMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{ //nolint:exhaustruct // Test response exercises unknown-length close semantics.
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &scriptedReadCloser{chunks: [][]byte{[]byte(`{"token":"response-secret","safe":"ok"}`)}},
|
||||
ContentLength: -1,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.invalid", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
_, err = resp.Body.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
require.Equal(t, io.ErrUnexpectedEOF.Error(), attempts[0].Error)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_PrematureCloseMarksFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
_, err = resp.Body.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusFailed, attempts[0].Status)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_TruncatesLargeResponses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte(strings.Repeat("x", maxRecordedResponseBodyBytes+1024)))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, []byte("[TRUNCATED]"), attempts[0].ResponseBody)
|
||||
}
|
||||
|
||||
func TestRecordingTransport_TransportError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, sink := newTestSinkContext(t)
|
||||
client := &http.Client{
|
||||
Transport: &RecordingTransport{
|
||||
Base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return nil, xerrors.New("transport exploded")
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
"http://example.invalid",
|
||||
strings.NewReader(`{"password":"secret","safe":"ok"}`),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer top-secret")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
require.Nil(t, resp)
|
||||
require.EqualError(t, err, "Post \"http://example.invalid\": transport exploded")
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, 1, attempts[0].Number)
|
||||
require.Equal(t, RedactedValue, attempts[0].RequestHeaders["Authorization"])
|
||||
require.JSONEq(t, `{"password":"[REDACTED]","safe":"ok"}`, string(attempts[0].RequestBody))
|
||||
require.Zero(t, attempts[0].ResponseStatus)
|
||||
require.Equal(t, "transport exploded", attempts[0].Error)
|
||||
require.GreaterOrEqual(t, attempts[0].DurationMs, int64(0))
|
||||
}
|
||||
|
||||
func TestRecordingTransport_NilBase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte("ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &http.Client{Transport: &RecordingTransport{}}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ok", string(body))
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
// RunKind identifies the kind of debug run being recorded.
|
||||
type RunKind string
|
||||
|
||||
const (
|
||||
// KindChatTurn records a standard chat turn.
|
||||
KindChatTurn RunKind = "chat_turn"
|
||||
// KindTitleGeneration records title generation for a chat.
|
||||
KindTitleGeneration RunKind = "title_generation"
|
||||
// KindQuickgen records quick-generation workflows.
|
||||
KindQuickgen RunKind = "quickgen"
|
||||
// KindCompaction records history compaction workflows.
|
||||
KindCompaction RunKind = "compaction"
|
||||
)
|
||||
|
||||
// AllRunKinds contains every RunKind value. Update this when
|
||||
// adding new constants above.
|
||||
var AllRunKinds = []RunKind{
|
||||
KindChatTurn,
|
||||
KindTitleGeneration,
|
||||
KindQuickgen,
|
||||
KindCompaction,
|
||||
}
|
||||
|
||||
// Status identifies lifecycle state shared by runs and steps.
|
||||
type Status string
|
||||
|
||||
const (
|
||||
// StatusInProgress indicates work is still running.
|
||||
StatusInProgress Status = "in_progress"
|
||||
// StatusCompleted indicates work finished successfully.
|
||||
StatusCompleted Status = "completed"
|
||||
// StatusError indicates work finished with an error.
|
||||
StatusError Status = "error"
|
||||
// StatusInterrupted indicates work was canceled or interrupted.
|
||||
StatusInterrupted Status = "interrupted"
|
||||
)
|
||||
|
||||
// AllStatuses contains every Status value. Update this when
|
||||
// adding new constants above.
|
||||
var AllStatuses = []Status{
|
||||
StatusInProgress,
|
||||
StatusCompleted,
|
||||
StatusError,
|
||||
StatusInterrupted,
|
||||
}
|
||||
|
||||
// Operation identifies the model operation a step performed.
|
||||
type Operation string
|
||||
|
||||
const (
|
||||
// OperationStream records a streaming model operation.
|
||||
OperationStream Operation = "stream"
|
||||
// OperationGenerate records a non-streaming generation operation.
|
||||
OperationGenerate Operation = "generate"
|
||||
)
|
||||
|
||||
// AllOperations contains every Operation value. Update this when
|
||||
// adding new constants above.
|
||||
var AllOperations = []Operation{
|
||||
OperationStream,
|
||||
OperationGenerate,
|
||||
}
|
||||
|
||||
// RunContext carries identity and metadata for a debug run.
|
||||
type RunContext struct {
|
||||
RunID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
RootChatID uuid.UUID // Zero means not set.
|
||||
ParentChatID uuid.UUID // Zero means not set.
|
||||
ModelConfigID uuid.UUID // Zero means not set.
|
||||
TriggerMessageID int64 // Zero means not set.
|
||||
HistoryTipMessageID int64 // Zero means not set.
|
||||
Kind RunKind
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// StepContext carries identity and metadata for a debug step.
|
||||
type StepContext struct {
|
||||
StepID uuid.UUID
|
||||
RunID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
StepNumber int32
|
||||
Operation Operation
|
||||
HistoryTipMessageID int64 // Zero means not set.
|
||||
}
|
||||
|
||||
// Attempt captures a single HTTP round trip made during a step.
|
||||
type Attempt struct {
|
||||
Number int `json:"number"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Path string `json:"path,omitempty"`
|
||||
StartedAt string `json:"started_at,omitempty"`
|
||||
FinishedAt string `json:"finished_at,omitempty"`
|
||||
RequestHeaders map[string]string `json:"request_headers,omitempty"`
|
||||
RequestBody []byte `json:"request_body,omitempty"`
|
||||
ResponseStatus int `json:"response_status,omitempty"`
|
||||
ResponseHeaders map[string]string `json:"response_headers,omitempty"`
|
||||
ResponseBody []byte `json:"response_body,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
RetryClassification string `json:"retry_classification,omitempty"`
|
||||
RetryDelayMs int64 `json:"retry_delay_ms,omitempty"`
|
||||
}
|
||||
|
||||
// EventKind identifies the type of pubsub debug event.
|
||||
type EventKind string
|
||||
|
||||
const (
|
||||
// EventKindRunUpdate publishes a run mutation.
|
||||
EventKindRunUpdate EventKind = "run_update"
|
||||
// EventKindStepUpdate publishes a step mutation.
|
||||
EventKindStepUpdate EventKind = "step_update"
|
||||
// EventKindFinalize publishes a finalization signal.
|
||||
EventKindFinalize EventKind = "finalize"
|
||||
// EventKindDelete publishes a deletion signal.
|
||||
EventKindDelete EventKind = "delete"
|
||||
)
|
||||
|
||||
// DebugEvent is the lightweight pubsub envelope for chat debug updates.
|
||||
type DebugEvent struct {
|
||||
Kind EventKind `json:"kind"`
|
||||
ChatID uuid.UUID `json:"chat_id"`
|
||||
RunID uuid.UUID `json:"run_id"`
|
||||
StepID uuid.UUID `json:"step_id"`
|
||||
}
|
||||
|
||||
// BroadcastPubsubChannel is the shared pubsub channel for chat-debug events
|
||||
// that are not scoped to a single chat, such as stale finalization sweeps.
|
||||
const BroadcastPubsubChannel = "chat_debug:broadcast"
|
||||
|
||||
// PubsubChannel returns the chat-scoped pubsub channel for debug events.
|
||||
// Nil chat IDs use the shared broadcast channel so publishers and subscribers
|
||||
// can coordinate through one discoverable helper.
|
||||
func PubsubChannel(chatID uuid.UUID) string {
|
||||
if chatID == uuid.Nil {
|
||||
return BroadcastPubsubChannel
|
||||
}
|
||||
return "chat_debug:" + chatID.String()
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// toStrings converts a typed string slice to []string for comparison.
|
||||
func toStrings[T ~string](values []T) []string {
|
||||
out := make([]string, len(values))
|
||||
for i, v := range values {
|
||||
out[i] = string(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TestTypesMatchSDK verifies that every chatdebug constant has a
|
||||
// corresponding codersdk constant with the same string value.
|
||||
// If this test fails you probably added a constant to one package
|
||||
// but forgot to update the other.
|
||||
func TestTypesMatchSDK(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RunKind", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.ElementsMatch(t,
|
||||
toStrings(chatdebug.AllRunKinds),
|
||||
toStrings(codersdk.AllChatDebugRunKinds),
|
||||
"chatdebug.AllRunKinds and codersdk.AllChatDebugRunKinds have diverged",
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("Status", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.ElementsMatch(t,
|
||||
toStrings(chatdebug.AllStatuses),
|
||||
toStrings(codersdk.AllChatDebugStatuses),
|
||||
"chatdebug.AllStatuses and codersdk.AllChatDebugStatuses have diverged",
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("Operation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.ElementsMatch(t,
|
||||
toStrings(chatdebug.AllOperations),
|
||||
toStrings(codersdk.AllChatDebugStepOperations),
|
||||
"chatdebug.AllOperations and codersdk.AllChatDebugStepOperations have diverged",
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
@@ -369,8 +368,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
}
|
||||
|
||||
var result stepResult
|
||||
stepCtx := chatdebug.ReuseStep(ctx)
|
||||
err := chatretry.Retry(stepCtx, func(retryCtx context.Context) error {
|
||||
err := chatretry.Retry(ctx, func(retryCtx context.Context) error {
|
||||
attempt, streamErr := guardedStream(
|
||||
retryCtx,
|
||||
opts.Model.Provider(),
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
@@ -42,9 +41,9 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var capturedCall fantasy.Call
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: fantasyanthropic.Name,
|
||||
StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: fantasyanthropic.Name,
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
capturedCall = call
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
@@ -104,9 +103,9 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
func TestProcessStepStream_AnthropicUsageMatchesFinalDelta(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: fantasyanthropic.Name,
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: fantasyanthropic.Name,
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "cached response"},
|
||||
@@ -161,9 +160,9 @@ func TestRun_OnRetryEnrichesProvider(t *testing.T) {
|
||||
|
||||
var records []retryRecord
|
||||
calls := 0
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "openai",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
calls++
|
||||
if calls == 1 {
|
||||
return nil, xerrors.New("received status 429 from upstream")
|
||||
@@ -287,9 +286,9 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
attempts := 0
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "openai",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
<-ctx.Done()
|
||||
@@ -365,9 +364,9 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
attempts := 0
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "openai",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
@@ -448,9 +447,9 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
retried := false
|
||||
firstPartYielded := make(chan struct{}, 1)
|
||||
continueStream := make(chan struct{})
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "openai",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
attempts++
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) {
|
||||
@@ -527,9 +526,9 @@ func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
attemptReleased := make(chan struct{})
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "openai",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
close(attemptReleased)
|
||||
@@ -584,9 +583,9 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
attempts := 0
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "openai",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "openai",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
@@ -649,9 +648,9 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
started := make(chan struct{})
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
parts := []fantasy.StreamPart{
|
||||
{
|
||||
@@ -763,6 +762,52 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
|
||||
"interrupted tool should have no call timestamp (never reached StreamPartTypeToolCall)")
|
||||
}
|
||||
|
||||
type loopTestModel struct {
|
||||
provider string
|
||||
model string
|
||||
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
|
||||
streamFn func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Provider() string {
|
||||
if m.provider != "" {
|
||||
return m.provider
|
||||
}
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Model() string {
|
||||
if m.model != "" {
|
||||
return m.model
|
||||
}
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
if m.generateFn != nil {
|
||||
return m.generateFn(ctx, call)
|
||||
}
|
||||
return &fantasy.Response{}, nil
|
||||
}
|
||||
|
||||
func (m *loopTestModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
if m.streamFn != nil {
|
||||
return m.streamFn(ctx, call)
|
||||
}
|
||||
return streamFromParts([]fantasy.StreamPart{{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
}}), nil
|
||||
}
|
||||
|
||||
func (*loopTestModel) GenerateObject(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func (*loopTestModel) StreamObject(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse {
|
||||
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
||||
for _, part := range parts {
|
||||
@@ -815,9 +860,9 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
|
||||
var streamCalls int
|
||||
var secondCallPrompt []fantasy.Message
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCalls
|
||||
streamCalls++
|
||||
@@ -927,9 +972,9 @@ func TestRun_ParallelToolExecutionTimestamps(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var streamCalls int
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCalls
|
||||
streamCalls++
|
||||
@@ -1019,9 +1064,9 @@ func TestRun_ParallelToolExecutionTimestamps(t *testing.T) {
|
||||
func TestRun_PersistStepErrorPropagates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello"},
|
||||
@@ -1058,9 +1103,9 @@ func TestRun_ShutdownDuringToolExecutionReturnsContextCanceled(t *testing.T) {
|
||||
toolStarted := make(chan struct{})
|
||||
|
||||
// Model returns a single tool call, then finishes.
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-block", ToolCallName: "blocking_tool"},
|
||||
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-block", Delta: `{}`},
|
||||
@@ -1316,9 +1361,9 @@ func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
|
||||
toolStarted := make(chan struct{})
|
||||
|
||||
// Model returns a completed tool call in the stream.
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "calling tool"},
|
||||
@@ -1426,9 +1471,9 @@ func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
|
||||
func TestRun_ProviderExecutedToolResultTimestamps(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
// Simulate a provider-executed tool call and result
|
||||
// (e.g. Anthropic web search) followed by a text
|
||||
// response — all in a single stream.
|
||||
@@ -1496,9 +1541,9 @@ func TestRun_ProviderExecutedToolResultTimestamps(t *testing.T) {
|
||||
func TestRun_PersistStepInterruptedFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello world"},
|
||||
|
||||
@@ -7,10 +7,8 @@ import (
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -48,9 +46,6 @@ type CompactionOptions struct {
|
||||
SystemSummaryPrefix string
|
||||
Timeout time.Duration
|
||||
Persist func(context.Context, CompactionResult) error
|
||||
DebugSvc *chatdebug.Service
|
||||
ChatID uuid.UUID
|
||||
HistoryTipMessageID int64
|
||||
|
||||
// ToolCallID and ToolName identify the synthetic tool call
|
||||
// used to represent compaction in the message stream.
|
||||
@@ -274,92 +269,6 @@ func shouldCompact(contextTokens, contextLimit int64, thresholdPercent int32) (f
|
||||
return usagePercent, usagePercent >= float64(thresholdPercent)
|
||||
}
|
||||
|
||||
func startCompactionDebugRun(
|
||||
ctx context.Context,
|
||||
options CompactionOptions,
|
||||
) (context.Context, func(error)) {
|
||||
if options.DebugSvc == nil || options.ChatID == uuid.Nil {
|
||||
return ctx, func(error) {}
|
||||
}
|
||||
|
||||
parentRun, ok := chatdebug.RunFromContext(ctx)
|
||||
if !ok {
|
||||
return ctx, func(error) {}
|
||||
}
|
||||
|
||||
historyTipMessageID := options.HistoryTipMessageID
|
||||
if historyTipMessageID == 0 {
|
||||
historyTipMessageID = parentRun.HistoryTipMessageID
|
||||
}
|
||||
|
||||
run, err := options.DebugSvc.CreateRun(ctx, chatdebug.CreateRunParams{
|
||||
ChatID: options.ChatID,
|
||||
RootChatID: parentRun.RootChatID,
|
||||
ParentChatID: parentRun.ParentChatID,
|
||||
ModelConfigID: parentRun.ModelConfigID,
|
||||
TriggerMessageID: parentRun.TriggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindCompaction,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: parentRun.Provider,
|
||||
Model: parentRun.Model,
|
||||
})
|
||||
if err != nil {
|
||||
// Debug instrumentation must not surface as a compaction failure.
|
||||
return ctx, func(error) {}
|
||||
}
|
||||
|
||||
compactionCtx := chatdebug.ContextWithRun(ctx, &chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: options.ChatID,
|
||||
RootChatID: parentRun.RootChatID,
|
||||
ParentChatID: parentRun.ParentChatID,
|
||||
ModelConfigID: parentRun.ModelConfigID,
|
||||
TriggerMessageID: parentRun.TriggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindCompaction,
|
||||
Provider: parentRun.Provider,
|
||||
Model: parentRun.Model,
|
||||
})
|
||||
|
||||
return compactionCtx, func(runErr error) {
|
||||
status := chatdebug.StatusCompleted
|
||||
if runErr != nil {
|
||||
status = chatdebug.StatusError
|
||||
if xerrors.Is(runErr, ErrInterrupted) || xerrors.Is(runErr, context.Canceled) {
|
||||
status = chatdebug.StatusInterrupted
|
||||
}
|
||||
}
|
||||
finalizeCtx, finalizeCancel := context.WithTimeout(
|
||||
context.WithoutCancel(compactionCtx),
|
||||
5*time.Second,
|
||||
)
|
||||
defer finalizeCancel()
|
||||
|
||||
finalSummary := map[string]any(nil)
|
||||
if aggregated, aggErr := options.DebugSvc.AggregateRunSummary(
|
||||
finalizeCtx,
|
||||
run.ID,
|
||||
nil,
|
||||
); aggErr == nil {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
|
||||
// Debug instrumentation must not surface as a compaction failure.
|
||||
_, _ = options.DebugSvc.UpdateRun(
|
||||
finalizeCtx,
|
||||
chatdebug.UpdateRunParams{
|
||||
ID: run.ID,
|
||||
ChatID: options.ChatID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
},
|
||||
)
|
||||
chatdebug.CleanupStepCounter(run.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// generateCompactionSummary asks the model to summarize the
|
||||
// conversation so far. The provided messages should contain the
|
||||
// complete history (system prompt, user/assistant turns, tool
|
||||
@@ -370,7 +279,7 @@ func generateCompactionSummary(
|
||||
model fantasy.LanguageModel,
|
||||
messages []fantasy.Message,
|
||||
options CompactionOptions,
|
||||
) (summary string, err error) {
|
||||
) (string, error) {
|
||||
summaryPrompt := make([]fantasy.Message, 0, len(messages)+1)
|
||||
summaryPrompt = append(summaryPrompt, messages...)
|
||||
summaryPrompt = append(summaryPrompt, fantasy.Message{
|
||||
@@ -384,11 +293,6 @@ func generateCompactionSummary(
|
||||
summaryCtx, cancel := context.WithTimeout(ctx, options.Timeout)
|
||||
defer cancel()
|
||||
|
||||
summaryCtx, finishDebugRun := startCompactionDebugRun(summaryCtx, options)
|
||||
defer func() {
|
||||
finishDebugRun(err)
|
||||
}()
|
||||
|
||||
response, err := model.Generate(summaryCtx, fantasy.Call{
|
||||
Prompt: summaryPrompt,
|
||||
ToolChoice: &toolChoice,
|
||||
|
||||
@@ -2,168 +2,16 @@ package chatloop //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestStartCompactionDebugRun_DoesNotReportDebugErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newParentContext := func(chatID uuid.UUID) context.Context {
|
||||
return chatdebug.ContextWithRun(context.Background(), &chatdebug.RunContext{
|
||||
RunID: uuid.New(),
|
||||
ChatID: chatID,
|
||||
RootChatID: uuid.New(),
|
||||
ParentChatID: uuid.New(),
|
||||
ModelConfigID: uuid.New(),
|
||||
TriggerMessageID: 41,
|
||||
HistoryTipMessageID: 42,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Provider: "fake-provider",
|
||||
Model: "fake-model",
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("CreateRun", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
chatID := uuid.New()
|
||||
reportedErr := make(chan error, 1)
|
||||
|
||||
db.EXPECT().InsertChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.InsertChatDebugRunParams{}),
|
||||
).Return(database.ChatDebugRun{}, xerrors.New("insert compaction debug run"))
|
||||
|
||||
ctx := newParentContext(chatID)
|
||||
compactionCtx, finish := startCompactionDebugRun(ctx, CompactionOptions{
|
||||
DebugSvc: svc,
|
||||
ChatID: chatID,
|
||||
OnError: func(err error) {
|
||||
reportedErr <- err
|
||||
},
|
||||
})
|
||||
require.Same(t, ctx, compactionCtx)
|
||||
finish(nil)
|
||||
select {
|
||||
case err := <-reportedErr:
|
||||
t.Fatalf("unexpected OnError callback: %v", err)
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FinalizeRunAggregatesSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
chatID := uuid.New()
|
||||
runID := uuid.New()
|
||||
usageJSON, err := json.Marshal(fantasy.Usage{InputTokens: 7, OutputTokens: 3})
|
||||
require.NoError(t, err)
|
||||
attemptsJSON, err := json.Marshal([]chatdebug.Attempt{{
|
||||
Status: "completed",
|
||||
Method: "POST",
|
||||
Path: "/v1/messages",
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
db.EXPECT().InsertChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.InsertChatDebugRunParams{}),
|
||||
).Return(database.ChatDebugRun{ //nolint:exhaustruct // Test only needs IDs.
|
||||
ID: runID,
|
||||
ChatID: chatID,
|
||||
}, nil)
|
||||
db.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), runID).Return([]database.ChatDebugStep{{
|
||||
ID: uuid.New(),
|
||||
RunID: runID,
|
||||
ChatID: chatID,
|
||||
Status: string(chatdebug.StatusCompleted),
|
||||
Usage: pqtype.NullRawMessage{RawMessage: usageJSON, Valid: true},
|
||||
Attempts: attemptsJSON,
|
||||
}}, nil)
|
||||
db.EXPECT().UpdateChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{}),
|
||||
).DoAndReturn(func(_ context.Context, params database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
require.Equal(t, runID, params.ID)
|
||||
require.True(t, params.Summary.Valid)
|
||||
require.JSONEq(t, `{"endpoint_label":"POST /v1/messages","step_count":1,"total_input_tokens":7,"total_output_tokens":3}`,
|
||||
string(params.Summary.RawMessage))
|
||||
return database.ChatDebugRun{ID: runID, ChatID: chatID}, nil
|
||||
})
|
||||
|
||||
ctx := newParentContext(chatID)
|
||||
compactionCtx, finish := startCompactionDebugRun(ctx, CompactionOptions{
|
||||
DebugSvc: svc,
|
||||
ChatID: chatID,
|
||||
})
|
||||
require.NotSame(t, ctx, compactionCtx)
|
||||
finish(nil)
|
||||
})
|
||||
|
||||
t.Run("FinalizeRun", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
svc := chatdebug.NewService(db, testutil.Logger(t), nil)
|
||||
chatID := uuid.New()
|
||||
reportedErr := make(chan error, 1)
|
||||
runID := uuid.New()
|
||||
|
||||
db.EXPECT().InsertChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.InsertChatDebugRunParams{}),
|
||||
).Return(database.ChatDebugRun{ //nolint:exhaustruct // Test only needs IDs.
|
||||
ID: runID,
|
||||
ChatID: chatID,
|
||||
}, nil)
|
||||
db.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), runID).Return(nil, xerrors.New("aggregate compaction debug run"))
|
||||
db.EXPECT().UpdateChatDebugRun(
|
||||
gomock.Any(),
|
||||
gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{}),
|
||||
).Return(database.ChatDebugRun{}, xerrors.New("finalize compaction debug run"))
|
||||
|
||||
ctx := newParentContext(chatID)
|
||||
compactionCtx, finish := startCompactionDebugRun(ctx, CompactionOptions{
|
||||
DebugSvc: svc,
|
||||
ChatID: chatID,
|
||||
OnError: func(err error) {
|
||||
reportedErr <- err
|
||||
},
|
||||
})
|
||||
require.NotSame(t, ctx, compactionCtx)
|
||||
finish(nil)
|
||||
select {
|
||||
case err := <-reportedErr:
|
||||
t.Fatalf("unexpected OnError callback: %v", err)
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRun_Compaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -174,9 +22,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
var persistedCompaction CompactionResult
|
||||
const summaryText = "summary text for compaction"
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
|
||||
@@ -191,7 +39,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
GenerateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
require.NotEmpty(t, call.Prompt)
|
||||
lastPrompt := call.Prompt[len(call.Prompt)-1]
|
||||
require.Equal(t, fantasy.MessageRoleUser, lastPrompt.Role)
|
||||
@@ -259,9 +107,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
// and the tool-result part publishes after Persist.
|
||||
var callOrder []string
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
|
||||
@@ -276,7 +124,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
callOrder = append(callOrder, "generate")
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
@@ -341,9 +189,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
|
||||
publishCalled := false
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
@@ -392,9 +240,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
|
||||
const summaryText = "compacted summary"
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCallCount
|
||||
streamCallCount++
|
||||
@@ -439,7 +287,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
}), nil
|
||||
}
|
||||
},
|
||||
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
@@ -498,9 +346,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
|
||||
const summaryText = "compacted summary for skip test"
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCallCount
|
||||
streamCallCount++
|
||||
@@ -545,7 +393,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
}), nil
|
||||
}
|
||||
},
|
||||
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
@@ -594,9 +442,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
t.Run("ErrorsAreReported", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
@@ -607,7 +455,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return nil, xerrors.New("generate failed")
|
||||
},
|
||||
}
|
||||
@@ -663,9 +511,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
textMessage(fantasy.MessageRoleUser, "compacted user"),
|
||||
}
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCallCount
|
||||
streamCallCount++
|
||||
@@ -708,7 +556,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
}), nil
|
||||
}
|
||||
},
|
||||
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
@@ -769,9 +617,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
|
||||
const summaryText = "post-run compacted summary"
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCallCount
|
||||
streamCallCount++
|
||||
@@ -811,7 +659,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
}), nil
|
||||
}
|
||||
},
|
||||
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
@@ -875,9 +723,9 @@ func TestRun_Compaction(t *testing.T) {
|
||||
// The LLM calls a dynamic tool. Usage is above the
|
||||
// compaction threshold so compaction should fire even
|
||||
// though the chatloop exits via ErrDynamicToolCall.
|
||||
model := &chattest.FakeModel{
|
||||
ProviderName: "fake",
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "my_dynamic_tool"},
|
||||
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"query": "test"}`},
|
||||
@@ -898,7 +746,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
|
||||
@@ -2,7 +2,6 @@ package chatprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
@@ -1115,15 +1114,13 @@ func CoderHeadersFromIDs(
|
||||
// language model client using the provided provider credentials. The
|
||||
// userAgent is sent as the User-Agent header on every outgoing LLM
|
||||
// API request. extraHeaders, when non-nil, are sent as additional
|
||||
// HTTP headers on every request. httpClient, when non-nil, is used for
|
||||
// all provider HTTP requests.
|
||||
// HTTP headers on every request.
|
||||
func ModelFromConfig(
|
||||
providerHint string,
|
||||
modelName string,
|
||||
providerKeys ProviderAPIKeys,
|
||||
userAgent string,
|
||||
extraHeaders map[string]string,
|
||||
httpClient *http.Client,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
provider, modelID, err := ResolveModelWithProviderHint(modelName, providerHint)
|
||||
if err != nil {
|
||||
@@ -1149,9 +1146,6 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyanthropic.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasyanthropic.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyanthropic.New(options...)
|
||||
case fantasyazure.Name:
|
||||
if baseURL == "" {
|
||||
@@ -1166,9 +1160,6 @@ func ModelFromConfig(
|
||||
if len(extraHeaders) > 0 {
|
||||
azureOpts = append(azureOpts, fantasyazure.WithHeaders(extraHeaders))
|
||||
}
|
||||
if httpClient != nil {
|
||||
azureOpts = append(azureOpts, fantasyazure.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyazure.New(azureOpts...)
|
||||
case fantasybedrock.Name:
|
||||
bedrockOpts := []fantasybedrock.Option{
|
||||
@@ -1178,9 +1169,6 @@ func ModelFromConfig(
|
||||
if len(extraHeaders) > 0 {
|
||||
bedrockOpts = append(bedrockOpts, fantasybedrock.WithHeaders(extraHeaders))
|
||||
}
|
||||
if httpClient != nil {
|
||||
bedrockOpts = append(bedrockOpts, fantasybedrock.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasybedrock.New(bedrockOpts...)
|
||||
case fantasygoogle.Name:
|
||||
options := []fantasygoogle.Option{
|
||||
@@ -1193,9 +1181,6 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasygoogle.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasygoogle.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasygoogle.New(options...)
|
||||
case fantasyopenai.Name:
|
||||
options := []fantasyopenai.Option{
|
||||
@@ -1209,9 +1194,6 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyopenai.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasyopenai.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyopenai.New(options...)
|
||||
case fantasyopenaicompat.Name:
|
||||
options := []fantasyopenaicompat.Option{
|
||||
@@ -1224,9 +1206,6 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyopenaicompat.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasyopenaicompat.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyopenaicompat.New(options...)
|
||||
case fantasyopenrouter.Name:
|
||||
routerOpts := []fantasyopenrouter.Option{
|
||||
@@ -1236,9 +1215,6 @@ func ModelFromConfig(
|
||||
if len(extraHeaders) > 0 {
|
||||
routerOpts = append(routerOpts, fantasyopenrouter.WithHeaders(extraHeaders))
|
||||
}
|
||||
if httpClient != nil {
|
||||
routerOpts = append(routerOpts, fantasyopenrouter.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyopenrouter.New(routerOpts...)
|
||||
case fantasyvercel.Name:
|
||||
options := []fantasyvercel.Option{
|
||||
@@ -1251,9 +1227,6 @@ func ModelFromConfig(
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyvercel.WithBaseURL(baseURL))
|
||||
}
|
||||
if httpClient != nil {
|
||||
options = append(options, fantasyvercel.WithHTTPClient(httpClient))
|
||||
}
|
||||
providerClient, err = fantasyvercel.New(options...)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unsupported model provider %q", provider)
|
||||
|
||||
@@ -181,12 +181,6 @@ func TestResolveUserProviderKeys(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return fn(req)
|
||||
}
|
||||
|
||||
func TestReasoningEffortFromChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -783,7 +777,7 @@ func TestModelFromConfig_ExtraHeaders(t *testing.T) {
|
||||
BaseURLByProvider: map[string]string{"openai": serverURL},
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), headers, nil)
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), headers)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(ctx, fantasy.Call{
|
||||
@@ -814,7 +808,7 @@ func TestModelFromConfig_ExtraHeaders(t *testing.T) {
|
||||
BaseURLByProvider: map[string]string{"anthropic": serverURL},
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig("anthropic", "claude-sonnet-4-20250514", keys, chatprovider.UserAgent(), headers, nil)
|
||||
model, err := chatprovider.ModelFromConfig("anthropic", "claude-sonnet-4-20250514", keys, chatprovider.UserAgent(), headers)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(ctx, fantasy.Call{
|
||||
@@ -850,7 +844,7 @@ func TestModelFromConfig_NilExtraHeaders(t *testing.T) {
|
||||
BaseURLByProvider: map[string]string{"openai": serverURL},
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), nil, nil)
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, chatprovider.UserAgent(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(ctx, fantasy.Call{
|
||||
@@ -865,48 +859,6 @@ func TestModelFromConfig_NilExtraHeaders(t *testing.T) {
|
||||
_ = testutil.TryReceive(ctx, t, called)
|
||||
}
|
||||
|
||||
func TestModelFromConfig_HTTPClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
called := make(chan struct{})
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
assert.Equal(t, "true", req.Header.Get("X-Test-Transport"))
|
||||
close(called)
|
||||
return chattest.OpenAINonStreamingResponse("hello")
|
||||
})
|
||||
|
||||
keys := chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{"openai": "test-key"},
|
||||
BaseURLByProvider: map[string]string{"openai": serverURL},
|
||||
}
|
||||
client := &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
cloned := req.Clone(req.Context())
|
||||
cloned.Header = req.Header.Clone()
|
||||
cloned.Header.Set("X-Test-Transport", "true")
|
||||
return http.DefaultTransport.RoundTrip(cloned)
|
||||
})}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
"openai",
|
||||
"gpt-4",
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
nil,
|
||||
client,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = model.Generate(ctx, fantasy.Call{
|
||||
Prompt: []fantasy.Message{{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_ = testutil.TryReceive(ctx, t, called)
|
||||
}
|
||||
|
||||
func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ func TestModelFromConfig_UserAgent(t *testing.T) {
|
||||
BaseURLByProvider: map[string]string{"openai": serverURL},
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, expectedUA, nil, nil)
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, expectedUA, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make a real call so Fantasy sends an HTTP request to the
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
package chattest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// FakeModel is a configurable test double for fantasy.LanguageModel.
|
||||
// When a method function is nil, the method returns a safe empty
|
||||
// response.
|
||||
type FakeModel struct {
|
||||
ProviderName string
|
||||
ModelName string
|
||||
GenerateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
|
||||
StreamFn func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
|
||||
GenerateObjectFn func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error)
|
||||
StreamObjectFn func(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error)
|
||||
}
|
||||
|
||||
var _ fantasy.LanguageModel = (*FakeModel)(nil)
|
||||
|
||||
func (m *FakeModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
if m.GenerateFn == nil {
|
||||
return &fantasy.Response{}, nil
|
||||
}
|
||||
return m.GenerateFn(ctx, call)
|
||||
}
|
||||
|
||||
func (m *FakeModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
if m.StreamFn == nil {
|
||||
return fantasy.StreamResponse(func(func(fantasy.StreamPart) bool) {}), nil
|
||||
}
|
||||
return m.StreamFn(ctx, call)
|
||||
}
|
||||
|
||||
func (m *FakeModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
if m.GenerateObjectFn == nil {
|
||||
return &fantasy.ObjectResponse{}, nil
|
||||
}
|
||||
return m.GenerateObjectFn(ctx, call)
|
||||
}
|
||||
|
||||
func (m *FakeModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
if m.StreamObjectFn == nil {
|
||||
return fantasy.ObjectStreamResponse(func(func(fantasy.ObjectStreamPart) bool) {}), nil
|
||||
}
|
||||
return m.StreamObjectFn(ctx, call)
|
||||
}
|
||||
|
||||
func (m *FakeModel) Provider() string { return m.ProviderName }
|
||||
func (m *FakeModel) Model() string { return m.ModelName }
|
||||
+11
-317
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -22,7 +21,6 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
@@ -107,173 +105,35 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
messages []database.ChatMessage,
|
||||
fallbackProvider string,
|
||||
fallbackModelName string,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
generatedTitle *generatedChatTitle,
|
||||
logger slog.Logger,
|
||||
debugSvc *chatdebug.Service,
|
||||
) {
|
||||
input, ok := titleInput(chat, messages)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
|
||||
|
||||
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
type candidateDescriptor struct {
|
||||
provider string
|
||||
model string
|
||||
lm fantasy.LanguageModel
|
||||
}
|
||||
|
||||
// Build candidate list: preferred lightweight models first,
|
||||
// then the user's chat model as last resort.
|
||||
candidates := make([]candidateDescriptor, 0, len(preferredTitleModels)+1)
|
||||
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
|
||||
for _, c := range preferredTitleModels {
|
||||
m, err := chatprovider.ModelFromConfig(
|
||||
c.provider, c.model, keys, chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
candidates = append(candidates, candidateDescriptor{
|
||||
provider: c.provider,
|
||||
model: c.model,
|
||||
lm: m,
|
||||
})
|
||||
candidates = append(candidates, m)
|
||||
}
|
||||
}
|
||||
candidates = append(candidates, candidateDescriptor{
|
||||
provider: fallbackProvider,
|
||||
model: fallbackModelName,
|
||||
lm: fallbackModel,
|
||||
})
|
||||
|
||||
var historyTipMessageID int64
|
||||
if len(messages) > 0 {
|
||||
historyTipMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
var triggerMessageID int64
|
||||
for _, message := range messages {
|
||||
if message.Role == database.ChatMessageRoleUser {
|
||||
triggerMessageID = message.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
seedSummary := chatdebug.SeedSummary(
|
||||
chatdebug.TruncateLabel(input, chatdebug.MaxLabelLength),
|
||||
)
|
||||
|
||||
candidates = append(candidates, fallbackModel)
|
||||
var lastErr error
|
||||
for _, candidate := range candidates {
|
||||
candidateModel := candidate.lm
|
||||
candidateCtx := titleCtx
|
||||
var debugRun *database.ChatDebugRun
|
||||
if debugEnabled {
|
||||
run, err := debugSvc.CreateRun(titleCtx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindTitleGeneration,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: candidate.provider,
|
||||
Model: candidate.model,
|
||||
Summary: seedSummary,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to create title debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", candidate.provider),
|
||||
slog.F("model", candidate.model),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
debugRun = &run
|
||||
candidateCtx = chatdebug.ContextWithRun(
|
||||
candidateCtx,
|
||||
&chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindTitleGeneration,
|
||||
Provider: candidate.provider,
|
||||
Model: candidate.model,
|
||||
},
|
||||
)
|
||||
debugModel, err := newQuickgenDebugModel(
|
||||
chat,
|
||||
keys,
|
||||
debugSvc,
|
||||
candidate.provider,
|
||||
candidate.model,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to build title debug model",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", candidate.provider),
|
||||
slog.F("model", candidate.model),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
candidateModel = debugModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
title, err := generateTitle(candidateCtx, candidateModel, input)
|
||||
if debugRun != nil {
|
||||
status := chatdebug.StatusCompleted
|
||||
switch {
|
||||
case err == nil:
|
||||
// keep completed
|
||||
case errors.Is(err, context.Canceled):
|
||||
status = chatdebug.StatusInterrupted
|
||||
default:
|
||||
status = chatdebug.StatusError
|
||||
}
|
||||
finalizeCtx, finalizeCancel := context.WithTimeout(
|
||||
context.WithoutCancel(ctx), 10*time.Second,
|
||||
)
|
||||
finalSummary := seedSummary
|
||||
if aggregated, aggErr := debugSvc.AggregateRunSummary(
|
||||
finalizeCtx,
|
||||
debugRun.ID,
|
||||
seedSummary,
|
||||
); aggErr != nil {
|
||||
logger.Warn(ctx, "failed to aggregate debug run summary",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(aggErr),
|
||||
)
|
||||
} else {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
if _, updateErr := debugSvc.UpdateRun(
|
||||
finalizeCtx,
|
||||
chatdebug.UpdateRunParams{
|
||||
ID: debugRun.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
},
|
||||
); updateErr != nil {
|
||||
logger.Warn(ctx, "failed to finalize title debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(updateErr),
|
||||
)
|
||||
}
|
||||
chatdebug.CleanupStepCounter(debugRun.ID)
|
||||
finalizeCancel()
|
||||
}
|
||||
for _, model := range candidates {
|
||||
title, err := generateTitle(titleCtx, model, input)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
logger.Debug(ctx, "title model candidate failed",
|
||||
@@ -311,41 +171,6 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
}
|
||||
}
|
||||
|
||||
func newQuickgenDebugModel(
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
debugSvc *chatdebug.Service,
|
||||
provider string,
|
||||
model string,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
debugModel, err := chatprovider.ModelFromConfig(
|
||||
provider,
|
||||
model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
httpClient,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if debugModel == nil {
|
||||
return nil, xerrors.Errorf(
|
||||
"create model for %s/%s returned nil",
|
||||
provider,
|
||||
model,
|
||||
)
|
||||
}
|
||||
|
||||
return chatdebug.WrapModel(debugModel, debugSvc, chatdebug.RecorderOptions{
|
||||
ChatID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// generateTitle calls the model with a title-generation system prompt
|
||||
// and returns the normalized result. It retries transient LLM errors
|
||||
// (rate limits, overloaded, etc.) with exponential backoff.
|
||||
@@ -746,160 +571,30 @@ func generatePushSummary(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
assistantText string,
|
||||
fallbackProvider string,
|
||||
fallbackModelName string,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
logger slog.Logger,
|
||||
debugSvc *chatdebug.Service,
|
||||
triggerMessageID int64,
|
||||
historyTipMessageID int64,
|
||||
) string {
|
||||
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
|
||||
|
||||
summaryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
assistantText = truncateRunes(assistantText, maxConversationContextRunes)
|
||||
input := "Chat title: " + chat.Title + "\n\nAgent's last message:\n" + assistantText
|
||||
|
||||
type candidateDescriptor struct {
|
||||
provider string
|
||||
model string
|
||||
lm fantasy.LanguageModel
|
||||
}
|
||||
|
||||
candidates := make([]candidateDescriptor, 0, len(preferredTitleModels)+1)
|
||||
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
|
||||
for _, c := range preferredTitleModels {
|
||||
m, err := chatprovider.ModelFromConfig(
|
||||
c.provider, c.model, keys, chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
candidates = append(candidates, candidateDescriptor{
|
||||
provider: c.provider,
|
||||
model: c.model,
|
||||
lm: m,
|
||||
})
|
||||
candidates = append(candidates, m)
|
||||
}
|
||||
}
|
||||
candidates = append(candidates, candidateDescriptor{
|
||||
provider: fallbackProvider,
|
||||
model: fallbackModelName,
|
||||
lm: fallbackModel,
|
||||
})
|
||||
candidates = append(candidates, fallbackModel)
|
||||
|
||||
pushSeedSummary := chatdebug.SeedSummary("Push summary")
|
||||
|
||||
for _, candidate := range candidates {
|
||||
candidateModel := candidate.lm
|
||||
candidateCtx := summaryCtx
|
||||
var debugRun *database.ChatDebugRun
|
||||
if debugEnabled {
|
||||
run, err := debugSvc.CreateRun(summaryCtx, chatdebug.CreateRunParams{
|
||||
ChatID: chat.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindQuickgen,
|
||||
Status: chatdebug.StatusInProgress,
|
||||
Provider: candidate.provider,
|
||||
Model: candidate.model,
|
||||
Summary: pushSeedSummary,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to create quickgen debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", candidate.provider),
|
||||
slog.F("model", candidate.model),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
debugRun = &run
|
||||
candidateCtx = chatdebug.ContextWithRun(
|
||||
candidateCtx,
|
||||
&chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: chatdebug.KindQuickgen,
|
||||
Provider: candidate.provider,
|
||||
Model: candidate.model,
|
||||
},
|
||||
)
|
||||
debugModel, err := newQuickgenDebugModel(
|
||||
chat,
|
||||
keys,
|
||||
debugSvc,
|
||||
candidate.provider,
|
||||
candidate.model,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to build quickgen debug model",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", candidate.provider),
|
||||
slog.F("model", candidate.model),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else {
|
||||
candidateModel = debugModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
summary, err := generateShortText(
|
||||
candidateCtx,
|
||||
candidateModel,
|
||||
pushSummaryPrompt,
|
||||
input,
|
||||
)
|
||||
if debugRun != nil {
|
||||
status := chatdebug.StatusCompleted
|
||||
switch {
|
||||
case err == nil:
|
||||
// keep completed
|
||||
case errors.Is(err, context.Canceled):
|
||||
status = chatdebug.StatusInterrupted
|
||||
default:
|
||||
status = chatdebug.StatusError
|
||||
}
|
||||
finalizeCtx, finalizeCancel := context.WithTimeout(
|
||||
context.WithoutCancel(ctx), 10*time.Second,
|
||||
)
|
||||
finalSummary := pushSeedSummary
|
||||
if aggregated, aggErr := debugSvc.AggregateRunSummary(
|
||||
finalizeCtx,
|
||||
debugRun.ID,
|
||||
pushSeedSummary,
|
||||
); aggErr != nil {
|
||||
logger.Warn(ctx, "failed to aggregate debug run summary",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(aggErr),
|
||||
)
|
||||
} else {
|
||||
finalSummary = aggregated
|
||||
}
|
||||
if _, updateErr := debugSvc.UpdateRun(
|
||||
finalizeCtx,
|
||||
chatdebug.UpdateRunParams{
|
||||
ID: debugRun.ID,
|
||||
ChatID: chat.ID,
|
||||
Status: status,
|
||||
Summary: finalSummary,
|
||||
FinishedAt: time.Now(),
|
||||
},
|
||||
); updateErr != nil {
|
||||
logger.Warn(ctx, "failed to finalize quickgen debug run",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("run_id", debugRun.ID),
|
||||
slog.Error(updateErr),
|
||||
)
|
||||
}
|
||||
chatdebug.CleanupStepCounter(debugRun.ID)
|
||||
finalizeCancel()
|
||||
}
|
||||
for _, model := range candidates {
|
||||
summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "push summary model candidate failed",
|
||||
slog.Error(err),
|
||||
@@ -915,8 +610,7 @@ func generatePushSummary(
|
||||
|
||||
// generateShortText calls a model with a system prompt and user
|
||||
// input, returning a cleaned-up short text response. It reuses the
|
||||
// same retry logic as title generation. Retries can therefore
|
||||
// produce multiple debug steps for a single quickgen run.
|
||||
// same retry logic as title generation.
|
||||
func generateShortText(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"charm.land/fantasy"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -375,8 +375,8 @@ func Test_generateManualTitle_UsesTimeout(t *testing.T) {
|
||||
),
|
||||
}
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
model := &stubModel{
|
||||
generateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
deadline, ok := ctx.Deadline()
|
||||
require.True(t, ok, "manual title generation should set a deadline")
|
||||
require.WithinDuration(
|
||||
@@ -413,8 +413,8 @@ func Test_generateManualTitle_TruncatesFirstUserInput(t *testing.T) {
|
||||
),
|
||||
}
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
model := &stubModel{
|
||||
generateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
require.Len(t, call.Prompt, 2)
|
||||
systemText, ok := call.Prompt[0].Content[0].(fantasy.TextPart)
|
||||
require.True(t, ok)
|
||||
@@ -447,8 +447,8 @@ func Test_generateManualTitle_ReturnsUsageForEmptyNormalizedTitle(t *testing.T)
|
||||
),
|
||||
}
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
model := &stubModel{
|
||||
generateObjectFn: func(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return &fantasy.ObjectResponse{
|
||||
Object: map[string]any{"title": "\"\""},
|
||||
Usage: fantasy.Usage{
|
||||
@@ -504,8 +504,8 @@ func Test_selectPreferredConfiguredShortTextModelConfig(t *testing.T) {
|
||||
func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
model := &stubModel{
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.TextContent{Text: " \"Quoted summary\" "},
|
||||
@@ -520,6 +520,53 @@ func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) {
|
||||
require.Equal(t, "Quoted summary", text)
|
||||
}
|
||||
|
||||
type stubModel struct {
|
||||
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
|
||||
generateObjectFn func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error)
|
||||
}
|
||||
|
||||
func (m *stubModel) Generate(
|
||||
ctx context.Context,
|
||||
call fantasy.Call,
|
||||
) (*fantasy.Response, error) {
|
||||
if m.generateFn == nil {
|
||||
return nil, xerrors.New("generate not implemented")
|
||||
}
|
||||
return m.generateFn(ctx, call)
|
||||
}
|
||||
|
||||
func (*stubModel) Stream(
|
||||
context.Context,
|
||||
fantasy.Call,
|
||||
) (fantasy.StreamResponse, error) {
|
||||
return nil, xerrors.New("stream not implemented")
|
||||
}
|
||||
|
||||
func (m *stubModel) GenerateObject(
|
||||
ctx context.Context,
|
||||
call fantasy.ObjectCall,
|
||||
) (*fantasy.ObjectResponse, error) {
|
||||
if m.generateObjectFn == nil {
|
||||
return nil, xerrors.New("generate object not implemented")
|
||||
}
|
||||
return m.generateObjectFn(ctx, call)
|
||||
}
|
||||
|
||||
func (*stubModel) StreamObject(
|
||||
context.Context,
|
||||
fantasy.ObjectCall,
|
||||
) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, xerrors.New("stream object not implemented")
|
||||
}
|
||||
|
||||
func (*stubModel) Provider() string {
|
||||
return "test"
|
||||
}
|
||||
|
||||
func (*stubModel) Model() string {
|
||||
return "test"
|
||||
}
|
||||
|
||||
func mustChatMessage(
|
||||
t *testing.T,
|
||||
role database.ChatMessageRole,
|
||||
|
||||
@@ -547,148 +547,6 @@ type UpdateChatDesktopEnabledRequest struct {
|
||||
EnableDesktop bool `json:"enable_desktop"`
|
||||
}
|
||||
|
||||
// ChatDebugLoggingAdminSettings describes the runtime admin setting
|
||||
// that allows users to opt into chat debug logging.
|
||||
type ChatDebugLoggingAdminSettings struct {
|
||||
AllowUsers bool `json:"allow_users"`
|
||||
ForcedByDeployment bool `json:"forced_by_deployment"`
|
||||
}
|
||||
|
||||
// UserChatDebugLoggingSettings describes whether debug logging is
|
||||
// active for the current user and whether the user may control it.
|
||||
type UserChatDebugLoggingSettings struct {
|
||||
DebugLoggingEnabled bool `json:"debug_logging_enabled"`
|
||||
UserToggleAllowed bool `json:"user_toggle_allowed"`
|
||||
ForcedByDeployment bool `json:"forced_by_deployment"`
|
||||
}
|
||||
|
||||
// UpdateChatDebugLoggingAllowUsersRequest is the admin request to
|
||||
// toggle whether users may opt into chat debug logging.
|
||||
type UpdateChatDebugLoggingAllowUsersRequest struct {
|
||||
AllowUsers bool `json:"allow_users"`
|
||||
}
|
||||
|
||||
// UpdateUserChatDebugLoggingRequest is the per-user request to
|
||||
// opt into or out of chat debug logging.
|
||||
type UpdateUserChatDebugLoggingRequest struct {
|
||||
DebugLoggingEnabled bool `json:"debug_logging_enabled"`
|
||||
}
|
||||
|
||||
// ChatDebugStatus enumerates the lifecycle states shared by debug
|
||||
// runs and steps. These values must match the literals used in
|
||||
// FinalizeStaleChatDebugRows and all insert/update callers.
|
||||
type ChatDebugStatus string
|
||||
|
||||
const (
|
||||
ChatDebugStatusInProgress ChatDebugStatus = "in_progress"
|
||||
ChatDebugStatusCompleted ChatDebugStatus = "completed"
|
||||
ChatDebugStatusError ChatDebugStatus = "error"
|
||||
ChatDebugStatusInterrupted ChatDebugStatus = "interrupted"
|
||||
)
|
||||
|
||||
// AllChatDebugStatuses contains every ChatDebugStatus value.
|
||||
// Update this when adding new constants above.
|
||||
var AllChatDebugStatuses = []ChatDebugStatus{
|
||||
ChatDebugStatusInProgress,
|
||||
ChatDebugStatusCompleted,
|
||||
ChatDebugStatusError,
|
||||
ChatDebugStatusInterrupted,
|
||||
}
|
||||
|
||||
// ChatDebugRunKind labels the operation that produced the debug
|
||||
// run. Each value corresponds to a distinct call-site in chatd.
|
||||
type ChatDebugRunKind string
|
||||
|
||||
const (
|
||||
ChatDebugRunKindChatTurn ChatDebugRunKind = "chat_turn"
|
||||
ChatDebugRunKindTitleGeneration ChatDebugRunKind = "title_generation"
|
||||
ChatDebugRunKindQuickgen ChatDebugRunKind = "quickgen"
|
||||
ChatDebugRunKindCompaction ChatDebugRunKind = "compaction"
|
||||
)
|
||||
|
||||
// AllChatDebugRunKinds contains every ChatDebugRunKind value.
|
||||
// Update this when adding new constants above.
|
||||
var AllChatDebugRunKinds = []ChatDebugRunKind{
|
||||
ChatDebugRunKindChatTurn,
|
||||
ChatDebugRunKindTitleGeneration,
|
||||
ChatDebugRunKindQuickgen,
|
||||
ChatDebugRunKindCompaction,
|
||||
}
|
||||
|
||||
// ChatDebugStepOperation labels the model interaction type for a
|
||||
// debug step.
|
||||
type ChatDebugStepOperation string
|
||||
|
||||
const (
|
||||
ChatDebugStepOperationStream ChatDebugStepOperation = "stream"
|
||||
ChatDebugStepOperationGenerate ChatDebugStepOperation = "generate"
|
||||
)
|
||||
|
||||
// AllChatDebugStepOperations contains every ChatDebugStepOperation
|
||||
// value. Update this when adding new constants above.
|
||||
var AllChatDebugStepOperations = []ChatDebugStepOperation{
|
||||
ChatDebugStepOperationStream,
|
||||
ChatDebugStepOperationGenerate,
|
||||
}
|
||||
|
||||
// ChatDebugRunSummary is a lightweight run entry for list endpoints.
|
||||
type ChatDebugRunSummary struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
Kind ChatDebugRunKind `json:"kind"`
|
||||
Status ChatDebugStatus `json:"status"`
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
Summary map[string]any `json:"summary"`
|
||||
StartedAt time.Time `json:"started_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatDebugRun is the detailed run response including steps.
|
||||
// This type is consumed by the run-detail handler added in a later
|
||||
// PR in this stack; it is forward-declared here so that all SDK
|
||||
// types live in the same schema-layer commit.
|
||||
type ChatDebugRun struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
|
||||
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
TriggerMessageID *int64 `json:"trigger_message_id,omitempty"`
|
||||
HistoryTipMessageID *int64 `json:"history_tip_message_id,omitempty"`
|
||||
Kind ChatDebugRunKind `json:"kind"`
|
||||
Status ChatDebugStatus `json:"status"`
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
Summary map[string]any `json:"summary"`
|
||||
StartedAt time.Time `json:"started_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty" format:"date-time"`
|
||||
Steps []ChatDebugStep `json:"steps"`
|
||||
}
|
||||
|
||||
// ChatDebugStep is a single step within a debug run.
|
||||
type ChatDebugStep struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
RunID uuid.UUID `json:"run_id" format:"uuid"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
StepNumber int32 `json:"step_number"`
|
||||
Operation ChatDebugStepOperation `json:"operation"`
|
||||
Status ChatDebugStatus `json:"status"`
|
||||
HistoryTipMessageID *int64 `json:"history_tip_message_id,omitempty"`
|
||||
AssistantMessageID *int64 `json:"assistant_message_id,omitempty"`
|
||||
NormalizedRequest map[string]any `json:"normalized_request"`
|
||||
NormalizedResponse map[string]any `json:"normalized_response,omitempty"`
|
||||
Usage map[string]any `json:"usage,omitempty"`
|
||||
Attempts []map[string]any `json:"attempts"`
|
||||
Error map[string]any `json:"error,omitempty"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
StartedAt time.Time `json:"started_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty" format:"date-time"`
|
||||
}
|
||||
|
||||
// DefaultChatWorkspaceTTL is the default TTL for chat workspaces.
|
||||
// Zero means disabled — the template's own autostop setting applies.
|
||||
const DefaultChatWorkspaceTTL = 0
|
||||
@@ -2210,92 +2068,6 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv
|
||||
}), nil
|
||||
}
|
||||
|
||||
// GetChatDebugLogging returns the runtime admin setting that allows
|
||||
// users to opt into chat debug logging.
|
||||
func (c *ExperimentalClient) GetChatDebugLogging(ctx context.Context) (ChatDebugLoggingAdminSettings, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/debug-logging", nil)
|
||||
if err != nil {
|
||||
return ChatDebugLoggingAdminSettings{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatDebugLoggingAdminSettings{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp ChatDebugLoggingAdminSettings
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// UpdateChatDebugLogging updates the runtime admin setting that allows
|
||||
// users to opt into chat debug logging.
|
||||
func (c *ExperimentalClient) UpdateChatDebugLogging(ctx context.Context, req UpdateChatDebugLoggingAllowUsersRequest) error {
|
||||
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/debug-logging", req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserChatDebugLogging returns whether chat debug logging is active
|
||||
// for the current user and whether the user may change it.
|
||||
func (c *ExperimentalClient) GetUserChatDebugLogging(ctx context.Context) (UserChatDebugLoggingSettings, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/user-debug-logging", nil)
|
||||
if err != nil {
|
||||
return UserChatDebugLoggingSettings{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return UserChatDebugLoggingSettings{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp UserChatDebugLoggingSettings
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// UpdateUserChatDebugLogging updates the current user's chat debug
|
||||
// logging preference.
|
||||
func (c *ExperimentalClient) UpdateUserChatDebugLogging(ctx context.Context, req UpdateUserChatDebugLoggingRequest) error {
|
||||
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/user-debug-logging", req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatDebugRuns returns the debug runs for a chat.
|
||||
func (c *ExperimentalClient) GetChatDebugRuns(ctx context.Context, chatID uuid.UUID) ([]ChatDebugRunSummary, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/debug/runs", chatID), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var resp []ChatDebugRunSummary
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// GetChatDebugRun returns a debug run for a chat.
|
||||
func (c *ExperimentalClient) GetChatDebugRun(ctx context.Context, chatID uuid.UUID, runID uuid.UUID) (ChatDebugRun, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/debug/runs/%s", chatID, runID), nil)
|
||||
if err != nil {
|
||||
return ChatDebugRun{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatDebugRun{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp ChatDebugRun
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// GetChat returns a chat by ID.
|
||||
func (c *ExperimentalClient) GetChat(ctx context.Context, chatID uuid.UUID) (Chat, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s", chatID), nil)
|
||||
|
||||
+1
-12
@@ -3624,16 +3624,6 @@ Write out the current server config as YAML to stdout.`,
|
||||
YAML: "acquireBatchSize",
|
||||
Hidden: true, // Hidden because most operators should not need to modify this.
|
||||
},
|
||||
{
|
||||
Name: "Chat: Debug Logging Enabled",
|
||||
Description: "Force chat debug logging on for every chat, bypassing the runtime admin and user opt-in settings.",
|
||||
Flag: "chat-debug-logging-enabled",
|
||||
Env: "CODER_CHAT_DEBUG_LOGGING_ENABLED",
|
||||
Value: &c.AI.Chat.DebugLoggingEnabled,
|
||||
Default: "false",
|
||||
Group: &deploymentGroupChat,
|
||||
YAML: "debugLoggingEnabled",
|
||||
},
|
||||
// AI Bridge Options
|
||||
{
|
||||
Name: "AI Bridge Enabled",
|
||||
@@ -4100,8 +4090,7 @@ type AIBridgeProxyConfig struct {
|
||||
}
|
||||
|
||||
type ChatConfig struct {
|
||||
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
|
||||
DebugLoggingEnabled serpent.Bool `json:"debug_logging_enabled" typescript:",notnull"`
|
||||
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
|
||||
@@ -34,14 +34,16 @@ the most important.
|
||||
- [React](https://reactjs.org/) for the UI framework
|
||||
- [Typescript](https://www.typescriptlang.org/) to keep our sanity
|
||||
- [Vite](https://vitejs.dev/) to build the project
|
||||
- [Material V5](https://mui.com/material-ui/getting-started/) for UI components
|
||||
- [react-router](https://reactrouter.com/en/main) for routing
|
||||
- [TanStack Query](https://tanstack.com/query/v4/docs/react/overview) for
|
||||
- [TanStack Query v4](https://tanstack.com/query/v4/docs/react/overview) for
|
||||
fetching data
|
||||
- [Vitest](https://vitest.dev/) for integration testing
|
||||
- [axios](https://github.com/axios/axios) as fetching lib
|
||||
- [Playwright](https://playwright.dev/) for end-to-end (E2E) testing
|
||||
- [Jest](https://jestjs.io/) for integration testing
|
||||
- [Storybook](https://storybook.js.org/) and
|
||||
[Chromatic](https://www.chromatic.com/) for visual testing
|
||||
- [pnpm](https://pnpm.io/) as the package manager
|
||||
- [PNPM](https://pnpm.io/) as the package manager
|
||||
|
||||
## Structure
|
||||
|
||||
@@ -49,6 +51,7 @@ All UI-related code is in the `site` folder. Key directories include:
|
||||
|
||||
- **e2e** - End-to-end (E2E) tests
|
||||
- **src** - Source code
|
||||
- **mocks** - [Manual mocks](https://jestjs.io/docs/manual-mocks) used by Jest
|
||||
- **@types** - Custom types for dependencies that don't have defined types
|
||||
(largely code that has no server-side equivalent)
|
||||
- **api** - API function calls and types
|
||||
@@ -56,7 +59,7 @@ All UI-related code is in the `site` folder. Key directories include:
|
||||
- **components** - Reusable UI components without Coder specific business
|
||||
logic
|
||||
- **hooks** - Custom React hooks
|
||||
- **modules** - Coder specific logic and components related to multiple parts of the UI
|
||||
- **modules** - Coder-specific UI components
|
||||
- **pages** - Page-level components
|
||||
- **testHelpers** - Helper functions for integration testing
|
||||
- **theme** - theme configuration and color definitions
|
||||
@@ -283,9 +286,9 @@ local machine and forward the necessary ports to your workspace. At the end of
|
||||
the script, you will land _inside_ your workspace with environment variables set
|
||||
so you can simply execute the test (`pnpm run playwright:test`).
|
||||
|
||||
### Integration/Unit
|
||||
### Integration/Unit – Jest
|
||||
|
||||
We use unit and integration tests mostly for testing code that does _not_ pertain to React. Functions and classes that contain notable app logic, and which are well abstracted from React should have accompanying tests. If the logic is tightly coupled to a React component, a Storybook test or an E2E test is usually a better option.
|
||||
We use Jest mostly for testing code that does _not_ pertain to React. Functions and classes that contain notable app logic, and which are well abstracted from React should have accompanying tests. If the logic is tightly coupled to a React component, a Storybook test or an E2E test may be a better option depending on the scenario.
|
||||
|
||||
### Visual Testing – Storybook
|
||||
|
||||
|
||||
Generated
+1
-2
@@ -209,8 +209,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
|
||||
"structured_logging": true
|
||||
},
|
||||
"chat": {
|
||||
"acquire_batch_size": 0,
|
||||
"debug_logging_enabled": true
|
||||
"acquire_batch_size": 0
|
||||
}
|
||||
},
|
||||
"allow_workspace_renames": true,
|
||||
|
||||
Generated
+7
-12
@@ -1240,8 +1240,7 @@
|
||||
"structured_logging": true
|
||||
},
|
||||
"chat": {
|
||||
"acquire_batch_size": 0,
|
||||
"debug_logging_enabled": true
|
||||
"acquire_batch_size": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -2022,17 +2021,15 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
||||
|
||||
```json
|
||||
{
|
||||
"acquire_batch_size": 0,
|
||||
"debug_logging_enabled": true
|
||||
"acquire_batch_size": 0
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|-------------------------|---------|----------|--------------|-------------|
|
||||
| `acquire_batch_size` | integer | false | | |
|
||||
| `debug_logging_enabled` | boolean | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|----------------------|---------|----------|--------------|-------------|
|
||||
| `acquire_batch_size` | integer | false | | |
|
||||
|
||||
## codersdk.ChatRetentionDaysResponse
|
||||
|
||||
@@ -3264,8 +3261,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
"structured_logging": true
|
||||
},
|
||||
"chat": {
|
||||
"acquire_batch_size": 0,
|
||||
"debug_logging_enabled": true
|
||||
"acquire_batch_size": 0
|
||||
}
|
||||
},
|
||||
"allow_workspace_renames": true,
|
||||
@@ -3843,8 +3839,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
"structured_logging": true
|
||||
},
|
||||
"chat": {
|
||||
"acquire_batch_size": 0,
|
||||
"debug_logging_enabled": true
|
||||
"acquire_batch_size": 0
|
||||
}
|
||||
},
|
||||
"allow_workspace_renames": true,
|
||||
|
||||
Generated
-11
@@ -1702,17 +1702,6 @@ How often to reconcile workspace prebuilds state.
|
||||
|
||||
Hide AI tasks from the dashboard.
|
||||
|
||||
### --chat-debug-logging-enabled
|
||||
|
||||
| | |
|
||||
|-------------|------------------------------------------------|
|
||||
| Type | <code>bool</code> |
|
||||
| Environment | <code>$CODER_CHAT_DEBUG_LOGGING_ENABLED</code> |
|
||||
| YAML | <code>chat.debugLoggingEnabled</code> |
|
||||
| Default | <code>false</code> |
|
||||
|
||||
Force chat debug logging on for every chat, bypassing the runtime admin and user opt-in settings.
|
||||
|
||||
### --aibridge-enabled
|
||||
|
||||
| | |
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# 1.93.1
|
||||
FROM rust:slim@sha256:cf09adf8c3ebaba10779e5c23ff7fe4df4cccdab8a91f199b0c142c53fef3e1a AS rust-utils
|
||||
FROM rust:slim@sha256:a08d20a404f947ed358dfb63d1ee7e0b88ecad3c45ba9682ccbf2cb09c98acca AS rust-utils
|
||||
# Install rust helper programs
|
||||
ENV CARGO_INSTALL_ROOT=/tmp/
|
||||
# Use more reliable mirrors for Debian packages
|
||||
|
||||
@@ -416,7 +416,7 @@ module "vscode-web" {
|
||||
module "jetbrains" {
|
||||
count = contains(jsondecode(data.coder_parameter.ide_choices.value), "jetbrains") ? data.coder_workspace.me.start_count : 0
|
||||
source = "dev.registry.coder.com/coder/jetbrains/coder"
|
||||
version = "1.4.0"
|
||||
version = "1.3.1"
|
||||
agent_id = coder_agent.dev.id
|
||||
agent_name = "dev"
|
||||
folder = local.repo_dir
|
||||
@@ -922,7 +922,7 @@ resource "coder_script" "boundary_config_setup" {
|
||||
module "claude-code" {
|
||||
count = data.coder_task.me.enabled ? data.coder_workspace.me.start_count : 0
|
||||
source = "dev.registry.coder.com/coder/claude-code/coder"
|
||||
version = "4.9.2"
|
||||
version = "4.9.1"
|
||||
enable_boundary = true
|
||||
agent_id = coder_agent.dev.id
|
||||
workdir = local.repo_dir
|
||||
|
||||
@@ -212,13 +212,6 @@ AI BRIDGE PROXY OPTIONS:
|
||||
certificates not trusted by the system. If not provided, the system
|
||||
certificate pool is used.
|
||||
|
||||
CHAT OPTIONS:
|
||||
Configure the background chat processing daemon.
|
||||
|
||||
--chat-debug-logging-enabled bool, $CODER_CHAT_DEBUG_LOGGING_ENABLED (default: false)
|
||||
Force chat debug logging on for every chat, bypassing the runtime
|
||||
admin and user opt-in settings.
|
||||
|
||||
CLIENT OPTIONS:
|
||||
These options change the behavior of how clients interact with the Coder.
|
||||
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
|
||||
|
||||
@@ -33,7 +33,7 @@ data "coder_task" "me" {}
|
||||
module "claude-code" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
source = "registry.coder.com/coder/claude-code/coder"
|
||||
version = "4.9.2"
|
||||
version = "4.9.1"
|
||||
agent_id = coder_agent.main.id
|
||||
workdir = "/home/coder/projects"
|
||||
order = 999
|
||||
|
||||
@@ -130,7 +130,7 @@ require (
|
||||
github.com/coder/terraform-provider-coder/v2 v2.15.0
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/coder/wgtunnel v0.2.0
|
||||
github.com/coreos/go-oidc/v3 v3.18.0
|
||||
github.com/coreos/go-oidc/v3 v3.17.0
|
||||
github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf
|
||||
github.com/creack/pty v1.1.24
|
||||
github.com/dave/dst v0.27.2
|
||||
@@ -211,11 +211,11 @@ require (
|
||||
github.com/zclconf/go-cty-yaml v1.2.0
|
||||
go.mozilla.org/pkcs7 v0.9.0
|
||||
go.nhat.io/otelsql v0.16.0
|
||||
go.opentelemetry.io/otel v1.43.0
|
||||
go.opentelemetry.io/otel v1.42.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
|
||||
go.opentelemetry.io/otel/sdk v1.43.0
|
||||
go.opentelemetry.io/otel/trace v1.43.0
|
||||
go.opentelemetry.io/otel/sdk v1.42.0
|
||||
go.opentelemetry.io/otel/trace v1.42.0
|
||||
go.uber.org/atomic v1.11.0
|
||||
go.uber.org/goleak v1.3.1-0.20240429205332-517bace7cc29
|
||||
go.uber.org/mock v0.6.0
|
||||
@@ -231,7 +231,7 @@ require (
|
||||
golang.org/x/text v0.35.0
|
||||
golang.org/x/tools v0.43.0
|
||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da
|
||||
google.golang.org/api v0.275.0
|
||||
google.golang.org/api v0.274.0
|
||||
google.golang.org/grpc v1.80.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.74.0
|
||||
@@ -244,7 +244,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/auth v0.20.0 // indirect
|
||||
cloud.google.com/go/auth v0.18.2 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
filippo.io/edwards25519 v1.1.1 // indirect
|
||||
@@ -345,7 +345,7 @@ require (
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
@@ -458,8 +458,8 @@ require (
|
||||
go.opentelemetry.io/collector/pdata/pprofile v0.121.0 // indirect
|
||||
go.opentelemetry.io/collector/semconv v0.123.0 // indirect
|
||||
go.opentelemetry.io/contrib v1.19.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0
|
||||
go.opentelemetry.io/otel/metric v1.42.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.uber.org/zap v1.27.1 // indirect
|
||||
@@ -469,9 +469,9 @@ require (
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
google.golang.org/appengine v1.6.8 // indirect
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260316180232-0b37fe3546d5 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
gopkg.in/ini.v1 v1.67.1 // indirect
|
||||
howett.net/plist v1.0.0 // indirect
|
||||
kernel.org/pub/linux/libs/security/libcap/psx v1.2.77 // indirect
|
||||
@@ -628,7 +628,7 @@ require (
|
||||
github.com/zeebo/xxh3 v1.0.2 // indirect
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
go.yaml.in/yaml/v4 v4.0.0-rc.3 // indirect
|
||||
|
||||
@@ -4,8 +4,8 @@ cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4=
|
||||
cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
|
||||
cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q=
|
||||
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
|
||||
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
@@ -378,8 +378,8 @@ github.com/containerd/stargz-snapshotter/estargz v0.18.1 h1:cy2/lpgBXDA3cDKSyEfN
|
||||
github.com/containerd/stargz-snapshotter/estargz v0.18.1/go.mod h1:ALIEqa7B6oVDsrF37GkGN20SuvG/pIMm7FwP7ZmRb0Q=
|
||||
github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk=
|
||||
github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
|
||||
github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A=
|
||||
github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4=
|
||||
github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
|
||||
github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
|
||||
github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU=
|
||||
github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||
@@ -677,8 +677,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0 h1:fYQaUOiGwll0cGj7jmHT/0nPlcrZDFPrZRhTsoCr8hE=
|
||||
github.com/googleapis/gax-go/v2 v2.19.0/go.mod h1:w2ROXVdfGEVFXzmlciUU4EdjHgWvB5h2n6x/8XSTTJA=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo=
|
||||
@@ -1311,11 +1311,11 @@ go.opentelemetry.io/contrib/detectors/gcp v1.40.0 h1:Awaf8gmW99tZTOWqkLCOl6aw1/r
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.40.0/go.mod h1:99OY9ZCqyLkzJLTh5XhECpLRSxcZl+ZDKBEO+jMBFR4=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
|
||||
go.opentelemetry.io/otel v1.3.0/go.mod h1:PWIKzi6JCp7sM0k9yZ43VX+T345uNbAkDKwHVjb2PTs=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
|
||||
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I=
|
||||
@@ -1326,16 +1326,16 @@ go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0 h1:ZrPRak/kS4xI3A
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0/go.mod h1:3y6kQCWztq6hyW8Z9YxQDDm0Je9AJoFar2G0yDcmhRk=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.37.0 h1:SNhVp/9q4Go/XHBkQ1/d5u9P/U+L1yaGPoi0x+mStaI=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.37.0/go.mod h1:tx8OOlGH6R4kLV67YaYO44GFXloEjGPZuMjEkaaqIp4=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
|
||||
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
|
||||
go.opentelemetry.io/otel/sdk v1.3.0/go.mod h1:rIo4suHNhQwBIPg9axF8V9CA72Wz2mKF1teNrup8yzs=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
|
||||
go.opentelemetry.io/otel/trace v1.3.0/go.mod h1:c/VDhno8888bvQYmbYLqe41/Ldmr/KKunbvWM4/fEjk=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
|
||||
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
|
||||
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
|
||||
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
@@ -1514,19 +1514,19 @@ golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI=
|
||||
google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
|
||||
google.golang.org/api v0.274.0 h1:aYhycS5QQCwxHLwfEHRRLf9yNsfvp1JadKKWBE54RFA=
|
||||
google.golang.org/api v0.274.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
|
||||
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
|
||||
google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg=
|
||||
google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 h1:JNfk58HZ8lfmXbYK2vx/UvsqIL59TzByCxPIX4TDmsE=
|
||||
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5/go.mod h1:x5julN69+ED4PcFk/XWayw35O0lf/nGa4aNgODCmNmw=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260316180232-0b37fe3546d5 h1:CogIeEXn4qWYzzQU0QqvYBM8yDF9cFYzDq9ojSpv0Js=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260316180232-0b37fe3546d5/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
|
||||
+7
-7
@@ -1,17 +1,17 @@
|
||||
{
|
||||
"$schema": "https://unpkg.com/knip@5/schema.json",
|
||||
"entry": ["./src/index.tsx", "./src/serviceWorker.ts"],
|
||||
"project": [
|
||||
"./src/**/*.ts",
|
||||
"./src/**/*.tsx",
|
||||
"./test/**/*.ts",
|
||||
"./e2e/**/*.ts"
|
||||
],
|
||||
"project": ["./src/**/*.ts", "./src/**/*.tsx", "./e2e/**/*.ts"],
|
||||
"ignore": ["**/*Generated.ts", "src/api/chatModelOptions.ts"],
|
||||
"ignoreBinaries": ["protoc"],
|
||||
"ignoreDependencies": [
|
||||
"@babel/plugin-syntax-typescript",
|
||||
"@types/react-virtualized-auto-sizer",
|
||||
"babel-plugin-react-compiler",
|
||||
"jest_workaround",
|
||||
"ts-proto"
|
||||
]
|
||||
],
|
||||
"jest": {
|
||||
"entry": "./src/**/*.jest.{ts,tsx}"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
module.exports = {
|
||||
// Use a big timeout for CI.
|
||||
testTimeout: 20_000,
|
||||
maxWorkers: 8,
|
||||
projects: [
|
||||
{
|
||||
displayName: "test",
|
||||
roots: ["<rootDir>"],
|
||||
setupFiles: ["./jest.polyfills.js"],
|
||||
setupFilesAfterEnv: ["./jest.setup.ts"],
|
||||
extensionsToTreatAsEsm: [".ts"],
|
||||
transform: {
|
||||
"^.+\\.(t|j)sx?$": [
|
||||
"@swc/jest",
|
||||
{
|
||||
jsc: {
|
||||
transform: {
|
||||
react: {
|
||||
runtime: "automatic",
|
||||
importSource: "@emotion/react",
|
||||
},
|
||||
},
|
||||
experimental: {
|
||||
plugins: [["jest_workaround", {}]],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
testEnvironment: "jest-fixed-jsdom",
|
||||
testEnvironmentOptions: {
|
||||
customExportConditions: [""],
|
||||
},
|
||||
testRegex: "(/__tests__/.*|(\\.|/)(jest))\\.tsx?$",
|
||||
testPathIgnorePatterns: ["/node_modules/", "/e2e/"],
|
||||
transformIgnorePatterns: [],
|
||||
moduleDirectories: ["node_modules"],
|
||||
moduleNameMapper: {
|
||||
"\\.css$": "<rootDir>/src/testHelpers/styleMock.ts",
|
||||
"^@fontsource": "<rootDir>/src/testHelpers/styleMock.ts",
|
||||
"^@pierre/diffs/react$":
|
||||
"<rootDir>/src/testHelpers/pierreDiffsReactMock.tsx",
|
||||
},
|
||||
},
|
||||
],
|
||||
collectCoverageFrom: [
|
||||
// included files
|
||||
"<rootDir>/**/*.ts",
|
||||
"<rootDir>/**/*.tsx",
|
||||
// excluded files
|
||||
"!<rootDir>/**/*.stories.tsx",
|
||||
"!<rootDir>/_jest/**/*.*",
|
||||
"!<rootDir>/api.ts",
|
||||
"!<rootDir>/coverage/**/*.*",
|
||||
"!<rootDir>/e2e/**/*.*",
|
||||
"!<rootDir>/jest-runner.eslint.config.js",
|
||||
"!<rootDir>/jest.config.js",
|
||||
"!<rootDir>/out/**/*.*",
|
||||
"!<rootDir>/storybook-static/**/*.*",
|
||||
],
|
||||
};
|
||||
@@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Necessary for MSW
|
||||
*
|
||||
* @note The block below contains polyfills for Node.js globals
|
||||
* required for Jest to function when running JSDOM tests.
|
||||
* These HAVE to be require's and HAVE to be in this exact
|
||||
* order, since "undici" depends on the "TextEncoder" global API.
|
||||
*
|
||||
* Consider migrating to a more modern test runner if
|
||||
* you don't want to deal with this.
|
||||
*/
|
||||
const { TextDecoder, TextEncoder } = require("node:util");
|
||||
const { ReadableStream } = require("node:stream/web");
|
||||
|
||||
Object.defineProperties(globalThis, {
|
||||
TextDecoder: { value: TextDecoder },
|
||||
TextEncoder: { value: TextEncoder },
|
||||
ReadableStream: { value: ReadableStream },
|
||||
});
|
||||
|
||||
const { Blob, File } = require("node:buffer");
|
||||
const { fetch, Headers, FormData, Request, Response } = require("undici");
|
||||
|
||||
Object.defineProperties(globalThis, {
|
||||
fetch: { value: fetch, writable: true },
|
||||
Blob: { value: Blob },
|
||||
File: { value: File },
|
||||
Headers: { value: Headers },
|
||||
FormData: { value: FormData },
|
||||
Request: { value: Request },
|
||||
Response: { value: Response },
|
||||
matchMedia: {
|
||||
value: (query) => ({
|
||||
matches: false,
|
||||
media: query,
|
||||
onchange: null,
|
||||
addListener: jest.fn(),
|
||||
removeListener: jest.fn(),
|
||||
addEventListener: jest.fn(),
|
||||
removeEventListener: jest.fn(),
|
||||
dispatchEvent: jest.fn(),
|
||||
}),
|
||||
},
|
||||
});
|
||||
@@ -0,0 +1,80 @@
|
||||
import "@testing-library/jest-dom";
|
||||
import "jest-location-mock";
|
||||
import crypto from "node:crypto";
|
||||
import { cleanup } from "@testing-library/react";
|
||||
import { useMemo } from "react";
|
||||
import type { Region } from "#/api/typesGenerated";
|
||||
import type { ProxyLatencyReport } from "#/contexts/useProxyLatency";
|
||||
import { server } from "#/testHelpers/server";
|
||||
|
||||
// useProxyLatency does some http requests to determine latency.
|
||||
// This would fail unit testing, or at least make it very slow with
|
||||
// actual network requests. So just globally mock this hook.
|
||||
jest.mock("#/contexts/useProxyLatency", () => ({
|
||||
useProxyLatency: (proxies?: Region[]) => {
|
||||
// Must use `useMemo` here to avoid infinite loop.
|
||||
// Mocking the hook with a hook.
|
||||
const proxyLatencies = useMemo(() => {
|
||||
if (!proxies) {
|
||||
return {} as Record<string, ProxyLatencyReport>;
|
||||
}
|
||||
return proxies.reduce(
|
||||
(acc, proxy) => {
|
||||
acc[proxy.id] = {
|
||||
accurate: true,
|
||||
// Return a constant latency of 8ms.
|
||||
// If you make this random it could break stories.
|
||||
latencyMS: 8,
|
||||
at: new Date(),
|
||||
};
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, ProxyLatencyReport>,
|
||||
);
|
||||
}, [proxies]);
|
||||
|
||||
return { proxyLatencies, refetch: jest.fn() };
|
||||
},
|
||||
}));
|
||||
|
||||
global.scrollTo = jest.fn();
|
||||
|
||||
window.HTMLElement.prototype.scrollIntoView = jest.fn();
|
||||
// Polyfill pointer capture methods for JSDOM compatibility with Radix UI
|
||||
window.HTMLElement.prototype.hasPointerCapture = jest
|
||||
.fn()
|
||||
.mockReturnValue(false);
|
||||
window.HTMLElement.prototype.setPointerCapture = jest.fn();
|
||||
window.HTMLElement.prototype.releasePointerCapture = jest.fn();
|
||||
window.open = jest.fn();
|
||||
navigator.sendBeacon = jest.fn();
|
||||
|
||||
global.ResizeObserver = require("resize-observer-polyfill");
|
||||
|
||||
// Polyfill the getRandomValues that is used on utils/random.ts
|
||||
Object.defineProperty(global.self, "crypto", {
|
||||
value: {
|
||||
getRandomValues: crypto.randomFillSync,
|
||||
},
|
||||
});
|
||||
|
||||
// Establish API mocking before all tests through MSW.
|
||||
beforeAll(() =>
|
||||
server.listen({
|
||||
onUnhandledRequest: "warn",
|
||||
}),
|
||||
);
|
||||
|
||||
// Reset any request handlers that we may add during the tests,
|
||||
// so they don't affect other tests.
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
server.resetHandlers();
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
// Clean up after the tests are finished.
|
||||
afterAll(() => server.close());
|
||||
|
||||
// biome-ignore lint/complexity/noUselessEmptyExport: This is needed because we are compiling under `--isolatedModules`
|
||||
export {};
|
||||
+14
-3
@@ -28,10 +28,11 @@
|
||||
"storybook": "STORYBOOK=true storybook dev -p 6006",
|
||||
"storybook:build": "storybook build",
|
||||
"storybook:ci": "storybook build --test",
|
||||
"test": "vitest run --project=unit",
|
||||
"test": "vitest run --project=unit && jest",
|
||||
"test:storybook": "vitest --project=storybook",
|
||||
"test:ci": "vitest run --project=unit",
|
||||
"test:ci": "vitest run --project=unit && jest --silent",
|
||||
"test:watch": "vitest --project=unit",
|
||||
"test:watch-jest": "jest --watch",
|
||||
"stats": "STATS=true pnpm build && npx http-server ./stats -p 8081 -c-1",
|
||||
"update-emojis": "cp -rf ./node_modules/emoji-datasource-apple/img/apple/64/* ./static/emojis && cp -f ./node_modules/emoji-datasource-apple/img/apple/sheets-256/64.png ./static/emojis/spritesheet.png"
|
||||
},
|
||||
@@ -108,6 +109,7 @@
|
||||
"react-window": "1.8.11",
|
||||
"recharts": "2.15.4",
|
||||
"remark-gfm": "4.0.1",
|
||||
"resize-observer-polyfill": "1.5.1",
|
||||
"semver": "7.7.3",
|
||||
"sonner": "2.0.7",
|
||||
"streamdown": "2.5.0",
|
||||
@@ -116,6 +118,7 @@
|
||||
"tzdata": "1.0.46",
|
||||
"ua-parser-js": "1.0.41",
|
||||
"ufuzzy": "npm:@leeoniya/ufuzzy@1.0.10",
|
||||
"undici": "6.22.0",
|
||||
"unique-names-generator": "4.7.1",
|
||||
"uuid": "9.0.1",
|
||||
"websocket-ts": "2.2.1",
|
||||
@@ -135,6 +138,8 @@
|
||||
"@storybook/addon-themes": "10.3.3",
|
||||
"@storybook/addon-vitest": "10.3.3",
|
||||
"@storybook/react-vite": "10.3.3",
|
||||
"@swc/core": "1.3.38",
|
||||
"@swc/jest": "0.2.37",
|
||||
"@tailwindcss/typography": "0.5.19",
|
||||
"@testing-library/jest-dom": "6.9.1",
|
||||
"@testing-library/react": "14.3.1",
|
||||
@@ -144,6 +149,7 @@
|
||||
"@types/express": "4.17.17",
|
||||
"@types/file-saver": "2.0.7",
|
||||
"@types/humanize-duration": "3.27.4",
|
||||
"@types/jest": "29.5.14",
|
||||
"@types/lodash": "4.17.21",
|
||||
"@types/node": "20.19.25",
|
||||
"@types/novnc__novnc": "1.5.0",
|
||||
@@ -164,14 +170,18 @@
|
||||
"chromatic": "11.29.0",
|
||||
"dpdm": "3.14.0",
|
||||
"express": "4.21.2",
|
||||
"jest": "29.7.0",
|
||||
"jest-canvas-mock": "2.5.2",
|
||||
"jest-environment-jsdom": "29.5.0",
|
||||
"jest-fixed-jsdom": "0.0.11",
|
||||
"jest-location-mock": "2.0.0",
|
||||
"jest-websocket-mock": "2.5.0",
|
||||
"jest_workaround": "0.1.14",
|
||||
"jsdom": "27.2.0",
|
||||
"knip": "5.71.0",
|
||||
"msw": "2.4.8",
|
||||
"postcss": "8.5.6",
|
||||
"protobufjs": "7.5.4",
|
||||
"resize-observer-polyfill": "1.5.1",
|
||||
"rollup-plugin-visualizer": "7.0.1",
|
||||
"rxjs": "7.8.2",
|
||||
"ssh2": "1.17.0",
|
||||
@@ -214,6 +224,7 @@
|
||||
"storybook-addon-remix-react-router"
|
||||
],
|
||||
"onlyBuiltDependencies": [
|
||||
"@swc/core",
|
||||
"esbuild",
|
||||
"ssh2"
|
||||
]
|
||||
|
||||
Generated
+2316
-3
File diff suppressed because it is too large
Load Diff
@@ -1 +1 @@
|
||||
export default vi.fn();
|
||||
export default jest.fn();
|
||||
|
||||
Generated
-151
@@ -1236,7 +1236,6 @@ export const ChatCompactionThresholdKeyPrefix =
|
||||
// From codersdk/deployment.go
|
||||
export interface ChatConfig {
|
||||
readonly acquire_batch_size: number;
|
||||
readonly debug_logging_enabled: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -1364,127 +1363,6 @@ export interface ChatCostUsersResponse {
|
||||
readonly users: readonly ChatCostUserRollup[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatDebugLoggingAdminSettings describes the runtime admin setting
|
||||
* that allows users to opt into chat debug logging.
|
||||
*/
|
||||
export interface ChatDebugLoggingAdminSettings {
|
||||
readonly allow_users: boolean;
|
||||
readonly forced_by_deployment: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatDebugRun is the detailed run response including steps.
|
||||
* This type is consumed by the run-detail handler added in a later
|
||||
* PR in this stack; it is forward-declared here so that all SDK
|
||||
* types live in the same schema-layer commit.
|
||||
*/
|
||||
export interface ChatDebugRun {
|
||||
readonly id: string;
|
||||
readonly chat_id: string;
|
||||
readonly root_chat_id?: string;
|
||||
readonly parent_chat_id?: string;
|
||||
readonly model_config_id?: string;
|
||||
readonly trigger_message_id?: number;
|
||||
readonly history_tip_message_id?: number;
|
||||
readonly kind: ChatDebugRunKind;
|
||||
readonly status: ChatDebugStatus;
|
||||
readonly provider?: string;
|
||||
readonly model?: string;
|
||||
// empty interface{} type, falling back to unknown
|
||||
readonly summary: Record<string, unknown>;
|
||||
readonly started_at: string;
|
||||
readonly updated_at: string;
|
||||
readonly finished_at?: string;
|
||||
readonly steps: readonly ChatDebugStep[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatDebugRunKind =
|
||||
| "chat_turn"
|
||||
| "compaction"
|
||||
| "quickgen"
|
||||
| "title_generation";
|
||||
|
||||
export const ChatDebugRunKinds: ChatDebugRunKind[] = [
|
||||
"chat_turn",
|
||||
"compaction",
|
||||
"quickgen",
|
||||
"title_generation",
|
||||
];
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatDebugRunSummary is a lightweight run entry for list endpoints.
|
||||
*/
|
||||
export interface ChatDebugRunSummary {
|
||||
readonly id: string;
|
||||
readonly chat_id: string;
|
||||
readonly kind: ChatDebugRunKind;
|
||||
readonly status: ChatDebugStatus;
|
||||
readonly provider?: string;
|
||||
readonly model?: string;
|
||||
// empty interface{} type, falling back to unknown
|
||||
readonly summary: Record<string, unknown>;
|
||||
readonly started_at: string;
|
||||
readonly updated_at: string;
|
||||
readonly finished_at?: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatDebugStatus =
|
||||
| "completed"
|
||||
| "error"
|
||||
| "in_progress"
|
||||
| "interrupted";
|
||||
|
||||
export const ChatDebugStatuses: ChatDebugStatus[] = [
|
||||
"completed",
|
||||
"error",
|
||||
"in_progress",
|
||||
"interrupted",
|
||||
];
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatDebugStep is a single step within a debug run.
|
||||
*/
|
||||
export interface ChatDebugStep {
|
||||
readonly id: string;
|
||||
readonly run_id: string;
|
||||
readonly chat_id: string;
|
||||
readonly step_number: number;
|
||||
readonly operation: ChatDebugStepOperation;
|
||||
readonly status: ChatDebugStatus;
|
||||
readonly history_tip_message_id?: number;
|
||||
readonly assistant_message_id?: number;
|
||||
// empty interface{} type, falling back to unknown
|
||||
readonly normalized_request: Record<string, unknown>;
|
||||
// empty interface{} type, falling back to unknown
|
||||
readonly normalized_response?: Record<string, unknown>;
|
||||
// empty interface{} type, falling back to unknown
|
||||
readonly usage?: Record<string, unknown>;
|
||||
// empty interface{} type, falling back to unknown
|
||||
readonly attempts: readonly Record<string, unknown>[];
|
||||
// empty interface{} type, falling back to unknown
|
||||
readonly error?: Record<string, unknown>;
|
||||
// empty interface{} type, falling back to unknown
|
||||
readonly metadata: Record<string, unknown>;
|
||||
readonly started_at: string;
|
||||
readonly updated_at: string;
|
||||
readonly finished_at?: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
export type ChatDebugStepOperation = "generate" | "stream";
|
||||
|
||||
export const ChatDebugStepOperations: ChatDebugStepOperation[] = [
|
||||
"generate",
|
||||
"stream",
|
||||
];
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatDesktopEnabledResponse is the response for getting the desktop setting.
|
||||
@@ -7482,15 +7360,6 @@ export interface UpdateAppearanceConfig {
|
||||
readonly announcement_banners: readonly BannerConfig[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatDebugLoggingAllowUsersRequest is the admin request to
|
||||
* toggle whether users may opt into chat debug logging.
|
||||
*/
|
||||
export interface UpdateChatDebugLoggingAllowUsersRequest {
|
||||
readonly allow_users: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateChatDesktopEnabledRequest is the request to update the desktop setting.
|
||||
@@ -7794,15 +7663,6 @@ export interface UpdateUserChatCompactionThresholdRequest {
|
||||
readonly threshold_percent: number;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UpdateUserChatDebugLoggingRequest is the per-user request to
|
||||
* opt into or out of chat debug logging.
|
||||
*/
|
||||
export interface UpdateUserChatDebugLoggingRequest {
|
||||
readonly debug_logging_enabled: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/notifications.go
|
||||
export interface UpdateUserNotificationPreferences {
|
||||
readonly template_disabled_map: Record<string, boolean>;
|
||||
@@ -8102,17 +7962,6 @@ export interface UserChatCustomPrompt {
|
||||
readonly custom_prompt: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UserChatDebugLoggingSettings describes whether debug logging is
|
||||
* active for the current user and whether the user may control it.
|
||||
*/
|
||||
export interface UserChatDebugLoggingSettings {
|
||||
readonly debug_logging_enabled: boolean;
|
||||
readonly user_toggle_allowed: boolean;
|
||||
readonly forced_by_deployment: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UserChatProviderConfig is a summary of a provider that allows
|
||||
|
||||
@@ -37,7 +37,6 @@ const ComboboxWithHooks = ({
|
||||
optionsList?: SelectFilterOption[];
|
||||
}) => {
|
||||
const [value, setValue] = useState<string | undefined>(undefined);
|
||||
const [inputValue, setInputValue] = useState("");
|
||||
const selectedOption = optionsList.find((opt) => opt.value === value);
|
||||
|
||||
return (
|
||||
@@ -49,11 +48,7 @@ const ComboboxWithHooks = ({
|
||||
/>
|
||||
</ComboboxTrigger>
|
||||
<ComboboxContent className="w-60">
|
||||
<ComboboxInput
|
||||
placeholder="Search..."
|
||||
value={inputValue}
|
||||
onValueChange={setInputValue}
|
||||
/>
|
||||
<ComboboxInput placeholder="Search..." />
|
||||
<ComboboxList>
|
||||
{optionsList.map((option) => (
|
||||
<ComboboxItem key={option.value} value={option.value}>
|
||||
|
||||
@@ -127,7 +127,6 @@ export const ComboboxContent = ({
|
||||
};
|
||||
|
||||
export const ComboboxInput = CommandInput;
|
||||
|
||||
export const ComboboxList = CommandList;
|
||||
|
||||
export const ComboboxItem = ({
|
||||
|
||||
@@ -71,8 +71,6 @@ export const SelectFilter: FC<SelectFilterProps> = ({
|
||||
minWidth: width,
|
||||
}}
|
||||
align="end"
|
||||
// We want the backend to handle the filtering, not the client.
|
||||
shouldFilter={false}
|
||||
>
|
||||
{selectFilterSearch}
|
||||
<ComboboxList
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import { useMemo, useRef, useState } from "react";
|
||||
import { keepPreviousData, useQuery } from "react-query";
|
||||
import type { SelectFilterOption } from "#/components/Filter/SelectFilter";
|
||||
import { useDebouncedValue } from "#/hooks/debounce";
|
||||
|
||||
const FILTER_DEBOUNCE_MS = 300;
|
||||
|
||||
export type UseFilterMenuOptions = {
|
||||
id: string;
|
||||
@@ -28,7 +25,6 @@ export const useFilterMenu = ({
|
||||
{},
|
||||
);
|
||||
const [query, setQuery] = useState("");
|
||||
const debouncedQuery = useDebouncedValue(query, FILTER_DEBOUNCE_MS);
|
||||
const selectedOptionQuery = useQuery({
|
||||
queryKey: [id, "autocomplete", "selected", value],
|
||||
queryFn: () => {
|
||||
@@ -48,15 +44,11 @@ export const useFilterMenu = ({
|
||||
});
|
||||
const selectedOption = selectedOptionQuery.data;
|
||||
const searchOptionsQuery = useQuery({
|
||||
queryKey: [id, "autocomplete", "search", debouncedQuery],
|
||||
queryFn: () => getOptions(debouncedQuery),
|
||||
queryKey: [id, "autocomplete", "search", query],
|
||||
queryFn: () => getOptions(query),
|
||||
enabled,
|
||||
});
|
||||
const searchOptions = useMemo(() => {
|
||||
if (searchOptionsQuery.isFetching) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const isDataLoaded =
|
||||
searchOptionsQuery.isFetched && selectedOptionQuery.isFetched;
|
||||
|
||||
@@ -85,7 +77,6 @@ export const useFilterMenu = ({
|
||||
query,
|
||||
searchOptionsQuery.data,
|
||||
searchOptionsQuery.isFetched,
|
||||
searchOptionsQuery.isFetching,
|
||||
selectedOption,
|
||||
]);
|
||||
|
||||
|
||||
@@ -97,8 +97,8 @@ export const IconField: FC<IconFieldProps> = ({
|
||||
Unfortunately, React doesn't provide an API to start warming a lazy component,
|
||||
so we just have to sneak it into the DOM, which is kind of annoying, but means
|
||||
that users shouldn't ever spend time waiting for it to load.
|
||||
- Except we don't do it when running tests, because it would make them
|
||||
slower anyway. */}
|
||||
- Except we don't do it when running tests, because Jest doesn't define
|
||||
`IntersectionObserver`, and it would make them slower anyway. */}
|
||||
{process.env.NODE_ENV !== "test" && (
|
||||
<div className="sr-only" aria-hidden="true">
|
||||
<Suspense>
|
||||
|
||||
@@ -49,7 +49,8 @@ export const ThemeProvider: FC<PropsWithChildren> = ({ children }) => {
|
||||
setPreferredColorScheme(event.matches ? "light" : "dark");
|
||||
};
|
||||
|
||||
// `addEventListener` here is a recent API that isn't mocked in tests.
|
||||
// `addEventListener` here is a recent API that only _very_ up-to-date
|
||||
// browsers support, and that isn't mocked in Jest.
|
||||
themeQuery.addEventListener?.("change", listener);
|
||||
return () => {
|
||||
themeQuery.removeEventListener?.("change", listener);
|
||||
|
||||
@@ -70,11 +70,11 @@ function setupMockClipboard(isSecure: boolean): SetupMockClipboardResult {
|
||||
// Don't need these other methods for any of the tests; read and write are
|
||||
// both synchronous and slower than the promise-based methods, so ideally
|
||||
// we won't ever need to call them in the hook logic
|
||||
addEventListener: vi.fn(),
|
||||
removeEventListener: vi.fn(),
|
||||
dispatchEvent: vi.fn(),
|
||||
read: vi.fn(),
|
||||
write: vi.fn(),
|
||||
addEventListener: jest.fn(),
|
||||
removeEventListener: jest.fn(),
|
||||
dispatchEvent: jest.fn(),
|
||||
read: jest.fn(),
|
||||
write: jest.fn(),
|
||||
};
|
||||
|
||||
return {
|
||||
@@ -145,19 +145,19 @@ describe.each(secureContextValues)("useClipboard - secure: %j", (isSecure) => {
|
||||
} = setupMockClipboard(isSecure);
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
jest.useFakeTimers();
|
||||
|
||||
// Can't use vi.spyOn here because there's no guarantee that the mock
|
||||
// Can't use jest.spyOn here because there's no guarantee that the mock
|
||||
// browser environment actually implements execCommand. Trying to spy on an
|
||||
// undefined value will throw an error
|
||||
global.document.execCommand = mockExecCommand;
|
||||
|
||||
vi.spyOn(window, "navigator", "get").mockImplementation(() => ({
|
||||
jest.spyOn(window, "navigator", "get").mockImplementation(() => ({
|
||||
...originalNavigator,
|
||||
clipboard: mockClipboard,
|
||||
}));
|
||||
|
||||
vi.spyOn(console, "error").mockImplementation((errorValue, ...rest) => {
|
||||
jest.spyOn(console, "error").mockImplementation((errorValue, ...rest) => {
|
||||
const canIgnore =
|
||||
errorValue instanceof Error &&
|
||||
errorValue.message === COPY_FAILED_MESSAGE;
|
||||
@@ -169,9 +169,9 @@ describe.each(secureContextValues)("useClipboard - secure: %j", (isSecure) => {
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.runAllTimers();
|
||||
vi.useRealTimers();
|
||||
vi.resetAllMocks();
|
||||
jest.runAllTimers();
|
||||
jest.useRealTimers();
|
||||
jest.resetAllMocks();
|
||||
global.document.execCommand = originalExecCommand;
|
||||
|
||||
// Still have to reset the mock clipboard state because the same mock values
|
||||
@@ -193,7 +193,7 @@ describe.each(secureContextValues)("useClipboard - secure: %j", (isSecure) => {
|
||||
// tests more annoying. Getting around that by waiting for all timeouts to
|
||||
// wrap up, but note that the value of showCopiedSuccess will become false
|
||||
// after runAllTimersAsync finishes
|
||||
await act(() => vi.runAllTimersAsync());
|
||||
await act(() => jest.runAllTimersAsync());
|
||||
|
||||
const clipboardText = getClipboardText();
|
||||
expect(clipboardText).toEqual(textToCopy);
|
||||
@@ -214,7 +214,7 @@ describe.each(secureContextValues)("useClipboard - secure: %j", (isSecure) => {
|
||||
|
||||
it("Should notify the user of an error using the provided callback", async () => {
|
||||
const textToCopy = "birds";
|
||||
const onError = vi.fn();
|
||||
const onError = jest.fn();
|
||||
const { result } = renderUseClipboard({ onError });
|
||||
|
||||
setSimulateFailure(true);
|
||||
@@ -223,7 +223,7 @@ describe.each(secureContextValues)("useClipboard - secure: %j", (isSecure) => {
|
||||
});
|
||||
|
||||
it("Should dispatch a new toast message to the global snackbar when errors happen while no error callback is provided to the hook", async () => {
|
||||
const toastErrorSpy = vi.spyOn(toast, "error");
|
||||
const toastErrorSpy = jest.spyOn(toast, "error");
|
||||
const textToCopy = "crow";
|
||||
const { result } = renderUseClipboard();
|
||||
|
||||
@@ -239,7 +239,7 @@ describe.each(secureContextValues)("useClipboard - secure: %j", (isSecure) => {
|
||||
// Snackbar state transitions that you might get if the hook uses the
|
||||
// default
|
||||
const textToCopy = "hamster";
|
||||
const { result } = renderUseClipboard({ onError: vi.fn() });
|
||||
const { result } = renderUseClipboard({ onError: jest.fn() });
|
||||
|
||||
setSimulateFailure(true);
|
||||
await act(() => result.current.copyToClipboard(textToCopy));
|
||||
@@ -264,7 +264,7 @@ describe.each(secureContextValues)("useClipboard - secure: %j", (isSecure) => {
|
||||
// inside of useEffect calls without having to think about dependencies too
|
||||
// much
|
||||
it("Ensures that the copyToClipboard function always maintains a stable reference across all re-renders", async () => {
|
||||
const initialOnError = vi.fn();
|
||||
const initialOnError = jest.fn();
|
||||
const { result, rerender } = renderUseClipboard({
|
||||
onError: initialOnError,
|
||||
clearErrorOnSuccess: true,
|
||||
@@ -278,7 +278,7 @@ describe.each(secureContextValues)("useClipboard - secure: %j", (isSecure) => {
|
||||
|
||||
// Re-render with new onError prop and then swap back to simplify
|
||||
// testing
|
||||
rerender({ onError: vi.fn() });
|
||||
rerender({ onError: jest.fn() });
|
||||
expect(result.current.copyToClipboard).toBe(initialCopy);
|
||||
rerender({ onError: initialOnError });
|
||||
|
||||
@@ -310,13 +310,13 @@ describe.each(secureContextValues)("useClipboard - secure: %j", (isSecure) => {
|
||||
});
|
||||
|
||||
it("Always uses the most up-to-date onError prop", async () => {
|
||||
const initialOnError = vi.fn();
|
||||
const initialOnError = jest.fn();
|
||||
const { result, rerender } = renderUseClipboard({
|
||||
onError: initialOnError,
|
||||
});
|
||||
setSimulateFailure(true);
|
||||
|
||||
const secondOnError = vi.fn();
|
||||
const secondOnError = jest.fn();
|
||||
rerender({ onError: secondOnError });
|
||||
await act(() => result.current.copyToClipboard("dummy-text"));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { screen, spyOn, userEvent, within } from "storybook/test";
|
||||
import { expect, screen, spyOn, userEvent, within } from "storybook/test";
|
||||
import { API } from "#/api/api";
|
||||
import { getPreferredProxy } from "#/contexts/ProxyContext";
|
||||
import { chromatic } from "#/testHelpers/chromatic";
|
||||
@@ -57,6 +57,13 @@ export const HasError: Story = {
|
||||
agent: undefined,
|
||||
},
|
||||
},
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const moreActionsButton = canvas.getByRole("button", {
|
||||
name: "Dev Container actions",
|
||||
});
|
||||
expect(moreActionsButton).toBeVisible();
|
||||
},
|
||||
};
|
||||
|
||||
export const NoPorts: Story = {};
|
||||
@@ -123,6 +130,13 @@ export const NoContainerOrSubAgent: Story = {
|
||||
},
|
||||
subAgents: [],
|
||||
},
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const moreActionsButton = canvas.getByRole("button", {
|
||||
name: "Dev Container actions",
|
||||
});
|
||||
expect(moreActionsButton).toBeVisible();
|
||||
},
|
||||
};
|
||||
|
||||
export const NoContainerOrAgentOrName: Story = {
|
||||
|
||||
@@ -274,7 +274,7 @@ export const AgentDevcontainerCard: FC<AgentDevcontainerCardProps> = ({
|
||||
/>
|
||||
)}
|
||||
|
||||
{showDevcontainerControls && (
|
||||
{!isTransitioning && (
|
||||
<AgentDevcontainerMoreActions
|
||||
deleteDevContainer={deleteDevcontainerMutation.mutate}
|
||||
/>
|
||||
|
||||
@@ -5,7 +5,6 @@ import {
|
||||
EllipsisIcon,
|
||||
PlayIcon,
|
||||
SquareCheckBigIcon,
|
||||
TriangleAlertIcon,
|
||||
} from "lucide-react";
|
||||
import {
|
||||
type FC,
|
||||
@@ -26,7 +25,6 @@ import type {
|
||||
} from "#/api/typesGenerated";
|
||||
import { CheckIcon } from "#/components/AnimatedIcons/Check";
|
||||
import { ChevronDownIcon } from "#/components/AnimatedIcons/ChevronDown";
|
||||
import { Badge } from "#/components/Badge/Badge";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
@@ -47,8 +45,6 @@ import { useKebabMenu } from "#/components/Tabs/utils/useKebabMenu";
|
||||
import { useProxy } from "#/contexts/ProxyContext";
|
||||
import { useClipboard } from "#/hooks/useClipboard";
|
||||
import { useFeatureVisibility } from "#/modules/dashboard/useFeatureVisibility";
|
||||
import { getAgentHealthIssues } from "#/modules/workspaces/health";
|
||||
import { AgentAlert } from "#/pages/WorkspacePage/AgentAlert";
|
||||
import { AppStatuses } from "#/pages/WorkspacePage/AppStatuses";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { AgentApps, organizeAgentApps } from "./AgentApps/AgentApps";
|
||||
@@ -139,12 +135,9 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
const showVSCode = hasVSCodeApp && !browser_only;
|
||||
|
||||
const hasStartupFeatures = Boolean(agent.logs_length);
|
||||
const healthIssues = getAgentHealthIssues(agent);
|
||||
const hasAgentIssues = healthIssues.length > 0;
|
||||
const { proxy } = useProxy();
|
||||
const [showLogs, setShowLogs] = useState(
|
||||
(["starting", "start_timeout"].includes(agent.lifecycle_state) ||
|
||||
hasAgentIssues) &&
|
||||
["starting", "start_timeout"].includes(agent.lifecycle_state) &&
|
||||
hasStartupFeatures,
|
||||
);
|
||||
const agentLogs = useAgentLogs({ agentId: agent.id, enabled: showLogs });
|
||||
@@ -153,11 +146,8 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
const [bottomOfLogs, setBottomOfLogs] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
setShowLogs(
|
||||
(agent.lifecycle_state !== "ready" || hasAgentIssues) &&
|
||||
hasStartupFeatures,
|
||||
);
|
||||
}, [agent.lifecycle_state, hasAgentIssues, hasStartupFeatures]);
|
||||
setShowLogs(agent.lifecycle_state !== "ready" && hasStartupFeatures);
|
||||
}, [agent.lifecycle_state, hasStartupFeatures]);
|
||||
|
||||
// This is a layout effect to remove flicker when we're scrolling to the bottom.
|
||||
// biome-ignore lint/correctness/useExhaustiveDependencies: consider refactoring
|
||||
@@ -218,6 +208,7 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
agent,
|
||||
Boolean(hasDevcontainerErrors || shouldShowWildcardWarning),
|
||||
);
|
||||
|
||||
const [selectedLogTab, setSelectedLogTab] = useState("all");
|
||||
const sourceLogTabs = agent.log_sources
|
||||
.filter((logSource) => {
|
||||
@@ -468,37 +459,20 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
<AgentMetadata initialMetadata={initialMetadata} agent={agent} />
|
||||
</div>
|
||||
|
||||
<section className="border-0 border-t border-solid border-border">
|
||||
<div className="px-4 py-2 relative">
|
||||
<Button
|
||||
variant="subtle"
|
||||
onClick={() => setShowLogs((v) => !v)}
|
||||
className="after:content-[''] after:absolute after:inset-0"
|
||||
>
|
||||
<ChevronDownIcon open={showLogs} />
|
||||
<span>Logs</span>
|
||||
{healthIssues.length > 0 && (
|
||||
<Badge variant="warning" size="xs" className="ml-1.5">
|
||||
<TriangleAlertIcon />
|
||||
<span>{healthIssues.length}</span>
|
||||
</Badge>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
<Collapse in={showLogs || (!hasStartupFeatures && hasAgentIssues)}>
|
||||
<div className={cn("px-4", hasStartupFeatures ? "pb-4" : "py-4")}>
|
||||
{healthIssues.length > 0 && (
|
||||
<div className="mb-4 flex flex-col gap-3">
|
||||
{healthIssues.map((issue) => (
|
||||
<AgentAlert
|
||||
key={`${issue.title}-${issue.detail}`}
|
||||
{...issue}
|
||||
troubleshootingURL={agent.troubleshooting_url}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{hasStartupFeatures && hasAnyLogs && (
|
||||
{hasStartupFeatures && (
|
||||
<section className="border-0 border-t border-solid border-border">
|
||||
<div className="px-4 py-2 relative">
|
||||
<Button
|
||||
variant="subtle"
|
||||
onClick={() => setShowLogs((v) => !v)}
|
||||
className="after:content-[''] after:absolute after:inset-0"
|
||||
>
|
||||
<ChevronDownIcon open={showLogs} />
|
||||
<span>Logs</span>
|
||||
</Button>
|
||||
</div>
|
||||
<Collapse in={showLogs}>
|
||||
<div className="px-4 pb-4">
|
||||
<div className="border border-solid rounded-md overflow-clip">
|
||||
<Tabs
|
||||
className="-mx-px -mt-px"
|
||||
@@ -611,10 +585,10 @@ export const AgentRow: FC<AgentRowProps> = ({
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Collapse>
|
||||
</section>
|
||||
</div>
|
||||
</Collapse>
|
||||
</section>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { expect, fn, userEvent, waitFor, within } from "storybook/test";
|
||||
import type { Mock } from "vitest";
|
||||
import { agentLogsKey } from "#/api/queries/workspaces";
|
||||
import type { WorkspaceAgentLog } from "#/api/typesGenerated";
|
||||
import { MockWorkspaceAgent } from "#/testHelpers/entities";
|
||||
@@ -42,7 +41,7 @@ export const ClickOnDownload: Story = {
|
||||
`${MockWorkspaceAgent.name}-logs.txt`,
|
||||
),
|
||||
);
|
||||
const blob: Blob = (args.download as Mock).mock.calls[0][0];
|
||||
const blob: Blob = (args.download as jest.Mock).mock.calls[0][0];
|
||||
await expect(blob.type).toEqual("text/plain");
|
||||
},
|
||||
};
|
||||
|
||||
+10
-10
@@ -84,12 +84,12 @@ describe("useAgentContainers", () => {
|
||||
});
|
||||
|
||||
it("handles parsing errors from WebSocket", async () => {
|
||||
const toastErrorSpy = vi.spyOn(toast, "error");
|
||||
const watchAgentContainersSpy = vi.spyOn(API, "watchAgentContainers");
|
||||
const toastErrorSpy = jest.spyOn(toast, "error");
|
||||
const watchAgentContainersSpy = jest.spyOn(API, "watchAgentContainers");
|
||||
|
||||
const mockSocket = {
|
||||
addEventListener: vi.fn(),
|
||||
close: vi.fn(),
|
||||
addEventListener: jest.fn(),
|
||||
close: jest.fn(),
|
||||
};
|
||||
watchAgentContainersSpy.mockReturnValue(
|
||||
mockSocket as unknown as OneWayWebSocket<WorkspaceAgentListContainersResponse>,
|
||||
@@ -146,12 +146,12 @@ describe("useAgentContainers", () => {
|
||||
});
|
||||
|
||||
it("handles WebSocket errors", async () => {
|
||||
const toastErrorSpy = vi.spyOn(toast, "error");
|
||||
const watchAgentContainersSpy = vi.spyOn(API, "watchAgentContainers");
|
||||
const toastErrorSpy = jest.spyOn(toast, "error");
|
||||
const watchAgentContainersSpy = jest.spyOn(API, "watchAgentContainers");
|
||||
|
||||
const mockSocket = {
|
||||
addEventListener: vi.fn(),
|
||||
close: vi.fn(),
|
||||
addEventListener: jest.fn(),
|
||||
close: jest.fn(),
|
||||
};
|
||||
watchAgentContainersSpy.mockReturnValue(
|
||||
mockSocket as unknown as OneWayWebSocket<WorkspaceAgentListContainersResponse>,
|
||||
@@ -204,7 +204,7 @@ describe("useAgentContainers", () => {
|
||||
});
|
||||
|
||||
it("does not establish WebSocket connection when agent is not connected", () => {
|
||||
const watchAgentContainersSpy = vi.spyOn(API, "watchAgentContainers");
|
||||
const watchAgentContainersSpy = jest.spyOn(API, "watchAgentContainers");
|
||||
|
||||
const disconnectedAgent = {
|
||||
...MockWorkspaceAgent,
|
||||
@@ -222,7 +222,7 @@ describe("useAgentContainers", () => {
|
||||
});
|
||||
|
||||
it("does not establish WebSocket connection when dev container feature is not enabled", async () => {
|
||||
const watchAgentContainersSpy = vi.spyOn(API, "watchAgentContainers");
|
||||
const watchAgentContainersSpy = jest.spyOn(API, "watchAgentContainers");
|
||||
|
||||
server.use(
|
||||
http.get(
|
||||
+7
-8
@@ -1,7 +1,6 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { act } from "react";
|
||||
import { toast } from "sonner";
|
||||
import type { MockInstance } from "vitest";
|
||||
import * as apiModule from "#/api/api";
|
||||
import type { WorkspaceAgentLog } from "#/api/typesGenerated";
|
||||
import { MockWorkspaceAgent } from "#/testHelpers/entities";
|
||||
@@ -46,7 +45,7 @@ type MountHookOptions = Readonly<{
|
||||
type MountHookResult = Readonly<{
|
||||
serverResult: ServerResult;
|
||||
rerender: (props: { agentId: string; enabled: boolean }) => void;
|
||||
toastError: MockInstance;
|
||||
toastError: jest.SpyInstance;
|
||||
|
||||
// Note: the `current` property is only "halfway" readonly; the value is
|
||||
// readonly, but the key is still mutable
|
||||
@@ -57,8 +56,9 @@ function mountHook(options: MountHookOptions): MountHookResult {
|
||||
const { initialAgentId, enabled = true } = options;
|
||||
const serverResult: ServerResult = { current: undefined };
|
||||
|
||||
vi.spyOn(apiModule, "watchWorkspaceAgentLogs").mockImplementation(
|
||||
(agentId, params) => {
|
||||
jest
|
||||
.spyOn(apiModule, "watchWorkspaceAgentLogs")
|
||||
.mockImplementation((agentId, params) => {
|
||||
return new OneWayWebSocket({
|
||||
apiRoute: `/api/v2/workspaceagents/${agentId}/logs`,
|
||||
searchParams: new URLSearchParams({
|
||||
@@ -71,11 +71,10 @@ function mountHook(options: MountHookOptions): MountHookResult {
|
||||
return mockSocket;
|
||||
},
|
||||
});
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
void vi.spyOn(console, "error").mockImplementation(() => {});
|
||||
const toastError = vi.spyOn(toast, "error");
|
||||
void jest.spyOn(console, "error").mockImplementation(() => {});
|
||||
const toastError = jest.spyOn(toast, "error");
|
||||
|
||||
const { result: hookResult, rerender } = renderHook(
|
||||
(props) => useAgentLogs(props),
|
||||
+5
-5
@@ -67,10 +67,10 @@ const mockRequiredParameter = createMockParameter({
|
||||
});
|
||||
|
||||
describe("DynamicParameter", () => {
|
||||
const mockOnChange = vi.fn();
|
||||
const mockOnChange = jest.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Input Parameter", () => {
|
||||
@@ -800,7 +800,7 @@ describe("DynamicParameter", () => {
|
||||
});
|
||||
|
||||
it("calls onChange when numeric value changes (debounced)", () => {
|
||||
vi.useFakeTimers();
|
||||
jest.useFakeTimers();
|
||||
render(
|
||||
<DynamicParameter
|
||||
parameter={mockNumberInputParameter}
|
||||
@@ -813,11 +813,11 @@ describe("DynamicParameter", () => {
|
||||
fireEvent.change(input, { target: { value: "7" } });
|
||||
|
||||
act(() => {
|
||||
vi.runAllTimers();
|
||||
jest.runAllTimers();
|
||||
});
|
||||
|
||||
expect(mockOnChange).toHaveBeenCalledWith("7");
|
||||
vi.useRealTimers();
|
||||
jest.useRealTimers();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { expect, fn, userEvent, waitFor, within } from "storybook/test";
|
||||
import type { Mock } from "vitest";
|
||||
import { agentLogsKey, buildLogsKey } from "#/api/queries/workspaces";
|
||||
import { MockWorkspace, MockWorkspaceAgent } from "#/testHelpers/entities";
|
||||
import { withDesktopViewport } from "#/testHelpers/storybook";
|
||||
@@ -62,7 +61,7 @@ export const DownloadLogs: Story = {
|
||||
`${MockWorkspace.name}-logs.zip`,
|
||||
),
|
||||
);
|
||||
const blob: Blob = (args.download as Mock).mock.calls[0][0];
|
||||
const blob: Blob = (args.download as jest.Mock).mock.calls[0][0];
|
||||
await expect(blob.type).toEqual("application/zip");
|
||||
},
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user