Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d0e9e4595e | |||
| 3fd76a1bc9 | |||
| ce6e383f5c | |||
| 9bea4e098d | |||
| a0519fa1e9 |
@@ -9,23 +9,24 @@ import (
|
||||
|
||||
// Returns cost in micros -- millionths of a dollar, rounded up to the next
|
||||
// whole microdollar.
|
||||
// Returns nil when pricing is not configured or when all priced usage fields
|
||||
// are nil, allowing callers to distinguish "zero cost" from "unpriced".
|
||||
// Returns valid=false when pricing is not configured or when all priced usage
|
||||
// fields are nil, allowing callers to distinguish "zero cost" from
|
||||
// "unpriced".
|
||||
func CalculateTotalCostMicros(
|
||||
usage codersdk.ChatMessageUsage,
|
||||
cost *codersdk.ModelCostConfig,
|
||||
) *int64 {
|
||||
) (micros int64, valid bool) {
|
||||
if cost == nil {
|
||||
return nil
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// A cost config with no prices set means pricing is effectively
|
||||
// unconfigured — return nil (unpriced) rather than zero.
|
||||
// unconfigured — return valid=false (unpriced) rather than zero.
|
||||
if cost.InputPricePerMillionTokens == nil &&
|
||||
cost.OutputPricePerMillionTokens == nil &&
|
||||
cost.CacheReadPricePerMillionTokens == nil &&
|
||||
cost.CacheWritePricePerMillionTokens == nil {
|
||||
return nil
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if usage.InputTokens == nil &&
|
||||
@@ -33,7 +34,7 @@ func CalculateTotalCostMicros(
|
||||
usage.ReasoningTokens == nil &&
|
||||
usage.CacheCreationTokens == nil &&
|
||||
usage.CacheReadTokens == nil {
|
||||
return nil
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// OutputTokens already includes reasoning tokens per provider
|
||||
@@ -41,14 +42,15 @@ func CalculateTotalCostMicros(
|
||||
// reasoning_tokens). Adding ReasoningTokens here would
|
||||
// double-count.
|
||||
|
||||
// Preserve nil when usage exists only in categories without configured
|
||||
// pricing, so callers can distinguish "unpriced" from "priced at zero".
|
||||
// Preserve valid=false when usage exists only in categories without
|
||||
// configured pricing, so callers can distinguish "unpriced" from
|
||||
// "priced at zero".
|
||||
hasMatchingPrice := (usage.InputTokens != nil && cost.InputPricePerMillionTokens != nil) ||
|
||||
(usage.OutputTokens != nil && cost.OutputPricePerMillionTokens != nil) ||
|
||||
(usage.CacheReadTokens != nil && cost.CacheReadPricePerMillionTokens != nil) ||
|
||||
(usage.CacheCreationTokens != nil && cost.CacheWritePricePerMillionTokens != nil)
|
||||
if !hasMatchingPrice {
|
||||
return nil
|
||||
return 0, false
|
||||
}
|
||||
|
||||
inputMicros := calcCost(usage.InputTokens, cost.InputPricePerMillionTokens)
|
||||
@@ -60,8 +62,8 @@ func CalculateTotalCostMicros(
|
||||
Add(outputMicros).
|
||||
Add(cacheReadMicros).
|
||||
Add(cacheWriteMicros)
|
||||
rounded := total.Ceil().IntPart()
|
||||
return &rounded
|
||||
|
||||
return total.Ceil().IntPart(), true
|
||||
}
|
||||
|
||||
// calcCost returns the cost in fractional microdollars (millionths of a USD)
|
||||
|
||||
@@ -15,19 +15,21 @@ func TestCalculateTotalCostMicros(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
usage codersdk.ChatMessageUsage
|
||||
cost *codersdk.ModelCostConfig
|
||||
want *int64
|
||||
name string
|
||||
usage codersdk.ChatMessageUsage
|
||||
cost *codersdk.ModelCostConfig
|
||||
wantMicros int64
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "nil cost returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: nil,
|
||||
want: nil,
|
||||
name: "nil cost returns unpriced",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: nil,
|
||||
wantMicros: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "all priced usage fields nil returns nil",
|
||||
name: "all priced usage fields nil returns unpriced",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
TotalTokens: ptr.Ref[int64](1234),
|
||||
ContextLimit: ptr.Ref[int64](8192),
|
||||
@@ -35,31 +37,29 @@ func TestCalculateTotalCostMicros(t *testing.T) {
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: nil,
|
||||
wantMicros: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "sub-micro total rounds up to 1",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")),
|
||||
},
|
||||
want: ptr.Ref[int64](1),
|
||||
name: "sub-micro total rounds up to 1",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
|
||||
cost: &codersdk.ModelCostConfig{InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01"))},
|
||||
wantMicros: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "simple input only",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: ptr.Ref[int64](3000),
|
||||
name: "simple input only",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3"))},
|
||||
wantMicros: 3000,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "simple output only",
|
||||
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: ptr.Ref[int64](7500),
|
||||
name: "simple output only",
|
||||
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
|
||||
cost: &codersdk.ModelCostConfig{OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15"))},
|
||||
wantMicros: 7500,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "reasoning tokens included in output total",
|
||||
@@ -67,26 +67,23 @@ func TestCalculateTotalCostMicros(t *testing.T) {
|
||||
OutputTokens: ptr.Ref[int64](500),
|
||||
ReasoningTokens: ptr.Ref[int64](200),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: ptr.Ref[int64](7500),
|
||||
cost: &codersdk.ModelCostConfig{OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15"))},
|
||||
wantMicros: 7500,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "cache read tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")),
|
||||
},
|
||||
want: ptr.Ref[int64](3000),
|
||||
name: "cache read tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
|
||||
cost: &codersdk.ModelCostConfig{CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3"))},
|
||||
wantMicros: 3000,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "cache creation tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")),
|
||||
},
|
||||
want: ptr.Ref[int64](18750),
|
||||
name: "cache creation tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
|
||||
cost: &codersdk.ModelCostConfig{CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75"))},
|
||||
wantMicros: 18750,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "full mixed usage totals all components exactly",
|
||||
@@ -105,7 +102,8 @@ func TestCalculateTotalCostMicros(t *testing.T) {
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")),
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")),
|
||||
},
|
||||
want: ptr.Ref[int64](2005),
|
||||
wantMicros: 2005,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "partial pricing only input contributes",
|
||||
@@ -116,32 +114,30 @@ func TestCalculateTotalCostMicros(t *testing.T) {
|
||||
CacheReadTokens: ptr.Ref[int64](500),
|
||||
CacheCreationTokens: ptr.Ref[int64](250),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")),
|
||||
},
|
||||
want: ptr.Ref[int64](3085),
|
||||
cost: &codersdk.ModelCostConfig{InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5"))},
|
||||
wantMicros: 3085,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "zero tokens with pricing returns zero pointer",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: ptr.Ref[int64](0),
|
||||
name: "zero tokens with pricing returns zero cost",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
|
||||
cost: &codersdk.ModelCostConfig{InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3"))},
|
||||
wantMicros: 0,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "usage only in unpriced categories returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: nil,
|
||||
name: "usage only in unpriced categories returns unpriced",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15"))},
|
||||
wantMicros: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "non nil usage with empty cost config returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
|
||||
cost: &codersdk.ModelCostConfig{},
|
||||
want: nil,
|
||||
name: "non nil usage with empty cost config returns unpriced",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
|
||||
cost: &codersdk.ModelCostConfig{},
|
||||
wantMicros: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -150,14 +146,10 @@ func TestCalculateTotalCostMicros(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
|
||||
micros, valid := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
|
||||
|
||||
if tt.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, *tt.want, *got)
|
||||
}
|
||||
require.Equal(t, tt.wantValid, valid)
|
||||
require.Equal(t, tt.wantMicros, micros)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+23
-10
@@ -298,6 +298,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert system message: %w", err)
|
||||
@@ -327,6 +328,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -917,6 +919,7 @@ func insertUserMessageAndSetPending(
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1977,6 +1980,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
if insertErr != nil {
|
||||
@@ -2383,7 +2387,8 @@ func (p *Server) runChat(
|
||||
}
|
||||
}
|
||||
|
||||
totalCostMicros := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost)
|
||||
totalCostMicros, costValid := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost)
|
||||
totalCostMicrosValue, costValidValue := persistedMessageCost(totalCostMicros, costValid)
|
||||
|
||||
assistantMessage, insertErr := tx.InsertChatMessage(persistCtx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
@@ -2407,11 +2412,12 @@ func (p *Server) runChat(
|
||||
CacheReadTokens: usageNullInt64(step.Usage.CacheReadTokens, hasUsage),
|
||||
ContextLimit: step.ContextLimit,
|
||||
Compressed: sql.NullBool{},
|
||||
// TotalCostMicros is nullable: NULL means "unpriced"
|
||||
// (pricing config was missing or no priced token
|
||||
// breakdown available), while 0 means "priced at
|
||||
// zero cost" (e.g., a free model).
|
||||
TotalCostMicros: usageNullInt64Ptr(totalCostMicros),
|
||||
// cost_valid=true means priced (including zero-cost),
|
||||
// false means unpriced. Keep total_cost_micros NULL for
|
||||
// unpriced writes so older readers still infer pricing
|
||||
// correctly during mixed-version rollout.
|
||||
TotalCostMicros: totalCostMicrosValue,
|
||||
CostValid: costValidValue,
|
||||
})
|
||||
if insertErr != nil {
|
||||
return xerrors.Errorf("insert assistant message: %w", insertErr)
|
||||
@@ -2442,6 +2448,7 @@ func (p *Server) runChat(
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
if insertErr != nil {
|
||||
@@ -2817,6 +2824,7 @@ func (p *Server) persistChatContextSummary(
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
if txErr != nil {
|
||||
return xerrors.Errorf("insert hidden summary message: %w", txErr)
|
||||
@@ -2842,6 +2850,7 @@ func (p *Server) persistChatContextSummary(
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
if txErr != nil {
|
||||
return xerrors.Errorf("insert summary tool call message: %w", txErr)
|
||||
@@ -2868,6 +2877,7 @@ func (p *Server) persistChatContextSummary(
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
if txErr != nil {
|
||||
return xerrors.Errorf("insert summary tool result message: %w", txErr)
|
||||
@@ -2995,11 +3005,14 @@ func usageNullInt64(value int64, valid bool) sql.NullInt64 {
|
||||
}
|
||||
}
|
||||
|
||||
func usageNullInt64Ptr(v *int64) sql.NullInt64 {
|
||||
if v == nil {
|
||||
return sql.NullInt64{}
|
||||
//nolint:revive // Boolean controls SQL NULL validity.
|
||||
func persistedMessageCost(totalCostMicros int64, costValid bool) (sql.NullInt64, sql.NullBool) {
|
||||
costValidValue := sql.NullBool{Bool: costValid, Valid: true}
|
||||
if !costValid {
|
||||
return sql.NullInt64{}, costValidValue
|
||||
}
|
||||
return sql.NullInt64{Int64: *v, Valid: true}
|
||||
|
||||
return sql.NullInt64{Int64: totalCostMicros, Valid: true}, costValidValue
|
||||
}
|
||||
|
||||
func refreshChatWorkspaceSnapshot(
|
||||
|
||||
@@ -588,6 +588,8 @@ func TestEditMessageRejectsNonUserMessage(t *testing.T) {
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -936,6 +938,8 @@ func TestSubscribeAfterMessageID(t *testing.T) {
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -959,6 +963,8 @@ func TestSubscribeAfterMessageID(t *testing.T) {
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1405,6 +1411,82 @@ func setOpenAIProviderBaseURL(
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSuccessfulChatPersistsNullCostForUnpricedAssistantMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("done")...,
|
||||
)
|
||||
})
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
_, err := db.UpdateChatModelConfig(ctx, database.UpdateChatModelConfigParams{
|
||||
Provider: model.Provider,
|
||||
Model: model.Model,
|
||||
DisplayName: model.DisplayName,
|
||||
UpdatedBy: model.UpdatedBy,
|
||||
Enabled: model.Enabled,
|
||||
IsDefault: model.IsDefault,
|
||||
ContextLimit: model.ContextLimit,
|
||||
CompressionThreshold: model.CompressionThreshold,
|
||||
Options: json.RawMessage(`{"cost":{"input_price_per_million_tokens":"0.15"}}`),
|
||||
ID: model.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "unpriced-cost-write-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusWaiting && !fromDB.WorkerID.Valid
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 2)
|
||||
|
||||
assistantMessage := messages[1]
|
||||
require.Equal(t, database.ChatMessageRoleAssistant, assistantMessage.Role)
|
||||
require.False(t, assistantMessage.TotalCostMicros.Valid)
|
||||
require.Equal(t, int64(0), assistantMessage.TotalCostMicros.Int64)
|
||||
require.True(t, assistantMessage.CostValid.Valid)
|
||||
require.False(t, assistantMessage.CostValid.Bool)
|
||||
}
|
||||
|
||||
func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -3510,6 +3510,7 @@ func seedChatCostFixture(t *testing.T) chatCostTestFixture {
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
CostValid: sql.NullBool{Bool: true, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -3602,6 +3603,7 @@ func TestChatCostSummary_AdminDrilldown(t *testing.T) {
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 750, Valid: true},
|
||||
CostValid: sql.NullBool{Bool: true, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3652,6 +3654,7 @@ func TestChatCostUsers(t *testing.T) {
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 300, Valid: true},
|
||||
CostValid: sql.NullBool{Bool: true, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3669,6 +3672,7 @@ func TestChatCostUsers(t *testing.T) {
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 800, Valid: true},
|
||||
CostValid: sql.NullBool{Bool: true, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3743,6 +3747,7 @@ func TestChatCostSummary_DateRange(t *testing.T) {
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
CostValid: sql.NullBool{Bool: true, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3798,6 +3803,7 @@ func TestChatCostSummary_UnpricedMessages(t *testing.T) {
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
CostValid: sql.NullBool{Bool: true, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3809,6 +3815,7 @@ func TestChatCostSummary_UnpricedMessages(t *testing.T) {
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 75, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3822,6 +3829,82 @@ func TestChatCostSummary_UnpricedMessages(t *testing.T) {
|
||||
require.Equal(t, int64(125), summary.TotalOutputTokens)
|
||||
}
|
||||
|
||||
func TestChatCostSummary_MixedVersionRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "mixed version test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
ContentVersion: 1,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
CostValid: sql.NullBool{Bool: true, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
ContentVersion: 1,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 300, Valid: true},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
ContentVersion: 1,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 150, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 75, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{Bool: false, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
ContentVersion: 1,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 25, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 10, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(800), summary.TotalCostMicros)
|
||||
require.Equal(t, int64(2), summary.PricedMessageCount)
|
||||
require.Equal(t, int64(2), summary.UnpricedMessageCount)
|
||||
require.Equal(t, int64(475), summary.TotalInputTokens)
|
||||
require.Equal(t, int64(235), summary.TotalOutputTokens)
|
||||
}
|
||||
|
||||
func requireChatModelPricing(
|
||||
t *testing.T,
|
||||
actual *codersdk.ChatModelCallConfig,
|
||||
|
||||
Generated
+2
-1
@@ -1239,7 +1239,8 @@ CREATE TABLE chat_messages (
|
||||
compressed boolean DEFAULT false NOT NULL,
|
||||
created_by uuid,
|
||||
content_version smallint NOT NULL,
|
||||
total_cost_micros bigint
|
||||
total_cost_micros bigint,
|
||||
cost_valid boolean
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_messages_id_seq
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Restore NULL cost for rows that new code marked as unpriced.
|
||||
UPDATE chat_messages SET total_cost_micros = NULL WHERE cost_valid = false;
|
||||
|
||||
-- Drop cost_valid column.
|
||||
ALTER TABLE chat_messages DROP COLUMN cost_valid;
|
||||
@@ -0,0 +1,9 @@
|
||||
-- Add cost_valid as a nullable column with no default for
|
||||
-- mixed-version rollout compatibility. Old writers that do not
|
||||
-- know about this column will insert NULL, which the summary
|
||||
-- query interprets via COALESCE as falling back to the
|
||||
-- total_cost_micros IS NOT NULL heuristic.
|
||||
ALTER TABLE chat_messages ADD COLUMN cost_valid boolean;
|
||||
|
||||
-- Backfill: mark existing rows based on whether they have a cost.
|
||||
UPDATE chat_messages SET cost_valid = (total_cost_micros IS NOT NULL);
|
||||
@@ -465,6 +465,180 @@ func TestMigration000362AggregateUsageEvents(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigration000439ChatCostRolloutSafety(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
return
|
||||
}
|
||||
|
||||
const migrationVersion = 439
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
|
||||
next, err := migrations.Stepper(sqlDB)
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
version, more, err := next()
|
||||
require.NoError(t, err)
|
||||
if !more {
|
||||
t.Fatalf("migration %d not found", migrationVersion)
|
||||
}
|
||||
if version == migrationVersion-1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
orgID := uuid.New()
|
||||
userID := uuid.New()
|
||||
chatProviderID := uuid.New()
|
||||
modelConfigID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
|
||||
_, err = sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO organizations (id, name, display_name, description, icon, created_at, updated_at, is_default)
|
||||
VALUES ($1, 'test-org', 'Test Org', '', '', NOW(), NOW(), false)
|
||||
`, orgID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO users (id, email, username, name, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
|
||||
VALUES ($1, 'test@test.com', 'testuser', 'Test User', 'xxx', NOW(), NOW(), 'active', '{}', 'password')
|
||||
`, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles)
|
||||
VALUES ($1, $2, NOW(), NOW(), '{}')
|
||||
`, orgID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_providers (
|
||||
id,
|
||||
provider,
|
||||
display_name,
|
||||
api_key,
|
||||
api_key_key_id,
|
||||
enabled,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
$1,
|
||||
'openai',
|
||||
'OpenAI',
|
||||
'',
|
||||
NULL,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
)
|
||||
`, chatProviderID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_model_configs (
|
||||
id,
|
||||
display_name,
|
||||
provider,
|
||||
model,
|
||||
enabled,
|
||||
context_limit,
|
||||
compression_threshold,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
$1,
|
||||
'test model',
|
||||
'openai',
|
||||
'gpt-4',
|
||||
true,
|
||||
200000,
|
||||
70,
|
||||
NOW(),
|
||||
NOW()
|
||||
)
|
||||
`, modelConfigID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chats (id, owner_id, last_model_config_id, title, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, 'test chat', NOW(), NOW())
|
||||
`, chatID, userID, modelConfigID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_messages (chat_id, role, content_version, visibility, total_cost_micros, created_at)
|
||||
VALUES ($1, 'assistant', 1, 'both', 500, NOW())
|
||||
`, chatID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_messages (chat_id, role, content_version, visibility, total_cost_micros, created_at)
|
||||
VALUES ($1, 'assistant', 1, 'both', NULL, NOW())
|
||||
`, chatID)
|
||||
require.NoError(t, err)
|
||||
|
||||
version, _, err := next()
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, migrationVersion, version)
|
||||
|
||||
rows, err := sqlDB.QueryContext(ctx, `
|
||||
SELECT total_cost_micros, cost_valid
|
||||
FROM chat_messages
|
||||
WHERE chat_id = $1
|
||||
ORDER BY id
|
||||
`, chatID)
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
var messages []struct {
|
||||
totalCost sql.NullInt64
|
||||
costValid sql.NullBool
|
||||
}
|
||||
for rows.Next() {
|
||||
var message struct {
|
||||
totalCost sql.NullInt64
|
||||
costValid sql.NullBool
|
||||
}
|
||||
err = rows.Scan(&message.totalCost, &message.costValid)
|
||||
require.NoError(t, err)
|
||||
messages = append(messages, message)
|
||||
}
|
||||
require.NoError(t, rows.Err())
|
||||
require.Len(t, messages, 2)
|
||||
|
||||
require.True(t, messages[0].totalCost.Valid)
|
||||
require.Equal(t, int64(500), messages[0].totalCost.Int64)
|
||||
require.True(t, messages[0].costValid.Valid)
|
||||
require.True(t, messages[0].costValid.Bool)
|
||||
|
||||
require.False(t, messages[1].totalCost.Valid)
|
||||
require.True(t, messages[1].costValid.Valid)
|
||||
require.False(t, messages[1].costValid.Bool)
|
||||
|
||||
_, err = sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_messages (chat_id, role, content_version, visibility, total_cost_micros, created_at)
|
||||
VALUES ($1, 'assistant', 1, 'both', NULL, NOW())
|
||||
`, chatID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var oldStyleTotalCost sql.NullInt64
|
||||
var oldStyleCostValid sql.NullBool
|
||||
err = sqlDB.QueryRowContext(ctx, `
|
||||
SELECT total_cost_micros, cost_valid
|
||||
FROM chat_messages
|
||||
WHERE chat_id = $1
|
||||
ORDER BY id DESC
|
||||
LIMIT 1
|
||||
`, chatID).Scan(&oldStyleTotalCost, &oldStyleCostValid)
|
||||
require.NoError(t, err)
|
||||
require.False(t, oldStyleTotalCost.Valid)
|
||||
require.False(t, oldStyleCostValid.Valid)
|
||||
}
|
||||
|
||||
func TestMigration000387MigrateTaskWorkspaces(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -4085,6 +4085,7 @@ type ChatMessage struct {
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
ContentVersion int16 `db:"content_version" json:"content_version"`
|
||||
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
|
||||
CostValid sql.NullBool `db:"cost_valid" json:"cost_valid"`
|
||||
}
|
||||
|
||||
type ChatModelConfig struct {
|
||||
|
||||
@@ -9052,11 +9052,13 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
|
||||
) database.ChatMessage {
|
||||
t.Helper()
|
||||
msg, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
Role: role,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Visibility: vis,
|
||||
Compressed: sql.NullBool{Bool: compressed, Valid: true},
|
||||
ChatID: chatID,
|
||||
Role: role,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Visibility: vis,
|
||||
Compressed: sql.NullBool{Bool: compressed, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
CostValid: sql.NullBool{},
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`"` + content + `"`),
|
||||
Valid: true,
|
||||
|
||||
@@ -3563,10 +3563,10 @@ const getChatCostSummary = `-- name: GetChatCostSummary :one
|
||||
SELECT
|
||||
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
|
||||
COUNT(*) FILTER (
|
||||
WHERE cm.total_cost_micros IS NOT NULL
|
||||
WHERE COALESCE(cm.cost_valid, cm.total_cost_micros IS NOT NULL)
|
||||
)::bigint AS priced_message_count,
|
||||
COUNT(*) FILTER (
|
||||
WHERE cm.total_cost_micros IS NULL
|
||||
WHERE NOT COALESCE(cm.cost_valid, cm.total_cost_micros IS NOT NULL)
|
||||
AND (
|
||||
cm.input_tokens IS NOT NULL
|
||||
OR cm.output_tokens IS NOT NULL
|
||||
@@ -3721,7 +3721,7 @@ func (q *sqlQuerier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds [
|
||||
|
||||
const getChatMessageByID = `-- name: GetChatMessageByID :one
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -3750,13 +3750,14 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
&i.CostValid,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -3800,6 +3801,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
&i.CostValid,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3831,7 +3833,7 @@ WITH latest_compressed_summary AS (
|
||||
1
|
||||
)
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -3899,6 +3901,7 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
&i.CostValid,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4043,7 +4046,7 @@ func (q *sqlQuerier) GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerI
|
||||
|
||||
const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -4082,6 +4085,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
&i.CostValid,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -4228,7 +4232,8 @@ INSERT INTO chat_messages (
|
||||
cache_read_tokens,
|
||||
context_limit,
|
||||
compressed,
|
||||
total_cost_micros
|
||||
total_cost_micros,
|
||||
cost_valid
|
||||
) VALUES (
|
||||
$1::uuid,
|
||||
$2::uuid,
|
||||
@@ -4245,10 +4250,11 @@ INSERT INTO chat_messages (
|
||||
$13::bigint,
|
||||
$14::bigint,
|
||||
COALESCE($15::boolean, FALSE),
|
||||
$16::bigint
|
||||
$16::bigint,
|
||||
$17::boolean
|
||||
)
|
||||
RETURNING
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
|
||||
`
|
||||
|
||||
type InsertChatMessageParams struct {
|
||||
@@ -4268,6 +4274,7 @@ type InsertChatMessageParams struct {
|
||||
ContextLimit sql.NullInt64 `db:"context_limit" json:"context_limit"`
|
||||
Compressed sql.NullBool `db:"compressed" json:"compressed"`
|
||||
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
|
||||
CostValid sql.NullBool `db:"cost_valid" json:"cost_valid"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error) {
|
||||
@@ -4288,6 +4295,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag
|
||||
arg.ContextLimit,
|
||||
arg.Compressed,
|
||||
arg.TotalCostMicros,
|
||||
arg.CostValid,
|
||||
)
|
||||
var i ChatMessage
|
||||
err := row.Scan(
|
||||
@@ -4309,6 +4317,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
&i.CostValid,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -4444,7 +4453,7 @@ SET
|
||||
WHERE
|
||||
id = $3::bigint
|
||||
RETURNING
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
|
||||
`
|
||||
|
||||
type UpdateChatMessageByIDParams struct {
|
||||
@@ -4475,6 +4484,7 @@ func (q *sqlQuerier) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMe
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
&i.CostValid,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -182,7 +182,8 @@ INSERT INTO chat_messages (
|
||||
cache_read_tokens,
|
||||
context_limit,
|
||||
compressed,
|
||||
total_cost_micros
|
||||
total_cost_micros,
|
||||
cost_valid
|
||||
) VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('created_by')::uuid,
|
||||
@@ -199,7 +200,8 @@ INSERT INTO chat_messages (
|
||||
sqlc.narg('cache_read_tokens')::bigint,
|
||||
sqlc.narg('context_limit')::bigint,
|
||||
COALESCE(sqlc.narg('compressed')::boolean, FALSE),
|
||||
sqlc.narg('total_cost_micros')::bigint
|
||||
sqlc.narg('total_cost_micros')::bigint,
|
||||
sqlc.narg('cost_valid')::boolean
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -516,10 +518,10 @@ WHERE
|
||||
SELECT
|
||||
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
|
||||
COUNT(*) FILTER (
|
||||
WHERE cm.total_cost_micros IS NOT NULL
|
||||
WHERE COALESCE(cm.cost_valid, cm.total_cost_micros IS NOT NULL)
|
||||
)::bigint AS priced_message_count,
|
||||
COUNT(*) FILTER (
|
||||
WHERE cm.total_cost_micros IS NULL
|
||||
WHERE NOT COALESCE(cm.cost_valid, cm.total_cost_micros IS NOT NULL)
|
||||
AND (
|
||||
cm.input_tokens IS NOT NULL
|
||||
OR cm.output_tokens IS NOT NULL
|
||||
|
||||
Reference in New Issue
Block a user