Compare commits

...

5 Commits

Author SHA1 Message Date
Michael Suchacz d0e9e4595e fix(chat): preserve null cost writes 2026-03-16 01:19:59 +00:00
Michael Suchacz 3fd76a1bc9 test(chats): add rollout safety tests for migration 000439 2026-03-16 00:54:08 +00:00
Michael Suchacz ce6e383f5c fix(chats): make chat cost migration rollout-safe 2026-03-16 00:54:08 +00:00
Michael Suchacz 9bea4e098d fix(migrations): correct 000439 rollback ordering 2026-03-15 23:40:36 +00:00
Michael Suchacz a0519fa1e9 refactor(chat): simplify chat cost storage 2026-03-15 23:21:52 +00:00
13 changed files with 490 additions and 114 deletions
+14 -12
View File
@@ -9,23 +9,24 @@ import (
// Returns cost in micros -- millionths of a dollar, rounded up to the next
// whole microdollar.
// Returns nil when pricing is not configured or when all priced usage fields
// are nil, allowing callers to distinguish "zero cost" from "unpriced".
// Returns valid=false when pricing is not configured or when all priced usage
// fields are nil, allowing callers to distinguish "zero cost" from
// "unpriced".
func CalculateTotalCostMicros(
usage codersdk.ChatMessageUsage,
cost *codersdk.ModelCostConfig,
) *int64 {
) (micros int64, valid bool) {
if cost == nil {
return nil
return 0, false
}
// A cost config with no prices set means pricing is effectively
// unconfigured — return nil (unpriced) rather than zero.
// unconfigured — return valid=false (unpriced) rather than zero.
if cost.InputPricePerMillionTokens == nil &&
cost.OutputPricePerMillionTokens == nil &&
cost.CacheReadPricePerMillionTokens == nil &&
cost.CacheWritePricePerMillionTokens == nil {
return nil
return 0, false
}
if usage.InputTokens == nil &&
@@ -33,7 +34,7 @@ func CalculateTotalCostMicros(
usage.ReasoningTokens == nil &&
usage.CacheCreationTokens == nil &&
usage.CacheReadTokens == nil {
return nil
return 0, false
}
// OutputTokens already includes reasoning tokens per provider
@@ -41,14 +42,15 @@ func CalculateTotalCostMicros(
// reasoning_tokens). Adding ReasoningTokens here would
// double-count.
// Preserve nil when usage exists only in categories without configured
// pricing, so callers can distinguish "unpriced" from "priced at zero".
// Preserve valid=false when usage exists only in categories without
// configured pricing, so callers can distinguish "unpriced" from
// "priced at zero".
hasMatchingPrice := (usage.InputTokens != nil && cost.InputPricePerMillionTokens != nil) ||
(usage.OutputTokens != nil && cost.OutputPricePerMillionTokens != nil) ||
(usage.CacheReadTokens != nil && cost.CacheReadPricePerMillionTokens != nil) ||
(usage.CacheCreationTokens != nil && cost.CacheWritePricePerMillionTokens != nil)
if !hasMatchingPrice {
return nil
return 0, false
}
inputMicros := calcCost(usage.InputTokens, cost.InputPricePerMillionTokens)
@@ -60,8 +62,8 @@ func CalculateTotalCostMicros(
Add(outputMicros).
Add(cacheReadMicros).
Add(cacheWriteMicros)
rounded := total.Ceil().IntPart()
return &rounded
return total.Ceil().IntPart(), true
}
// calcCost returns the cost in fractional microdollars (millionths of a USD)
+64 -72
View File
@@ -15,19 +15,21 @@ func TestCalculateTotalCostMicros(t *testing.T) {
t.Parallel()
tests := []struct {
name string
usage codersdk.ChatMessageUsage
cost *codersdk.ModelCostConfig
want *int64
name string
usage codersdk.ChatMessageUsage
cost *codersdk.ModelCostConfig
wantMicros int64
wantValid bool
}{
{
name: "nil cost returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: nil,
want: nil,
name: "nil cost returns unpriced",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: nil,
wantMicros: 0,
wantValid: false,
},
{
name: "all priced usage fields nil returns nil",
name: "all priced usage fields nil returns unpriced",
usage: codersdk.ChatMessageUsage{
TotalTokens: ptr.Ref[int64](1234),
ContextLimit: ptr.Ref[int64](8192),
@@ -35,31 +37,29 @@ func TestCalculateTotalCostMicros(t *testing.T) {
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: nil,
wantMicros: 0,
wantValid: false,
},
{
name: "sub-micro total rounds up to 1",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")),
},
want: ptr.Ref[int64](1),
name: "sub-micro total rounds up to 1",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
cost: &codersdk.ModelCostConfig{InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01"))},
wantMicros: 1,
wantValid: true,
},
{
name: "simple input only",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: ptr.Ref[int64](3000),
name: "simple input only",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3"))},
wantMicros: 3000,
wantValid: true,
},
{
name: "simple output only",
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: ptr.Ref[int64](7500),
name: "simple output only",
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
cost: &codersdk.ModelCostConfig{OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15"))},
wantMicros: 7500,
wantValid: true,
},
{
name: "reasoning tokens included in output total",
@@ -67,26 +67,23 @@ func TestCalculateTotalCostMicros(t *testing.T) {
OutputTokens: ptr.Ref[int64](500),
ReasoningTokens: ptr.Ref[int64](200),
},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: ptr.Ref[int64](7500),
cost: &codersdk.ModelCostConfig{OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15"))},
wantMicros: 7500,
wantValid: true,
},
{
name: "cache read tokens",
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
cost: &codersdk.ModelCostConfig{
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")),
},
want: ptr.Ref[int64](3000),
name: "cache read tokens",
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
cost: &codersdk.ModelCostConfig{CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3"))},
wantMicros: 3000,
wantValid: true,
},
{
name: "cache creation tokens",
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
cost: &codersdk.ModelCostConfig{
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")),
},
want: ptr.Ref[int64](18750),
name: "cache creation tokens",
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
cost: &codersdk.ModelCostConfig{CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75"))},
wantMicros: 18750,
wantValid: true,
},
{
name: "full mixed usage totals all components exactly",
@@ -105,7 +102,8 @@ func TestCalculateTotalCostMicros(t *testing.T) {
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")),
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")),
},
want: ptr.Ref[int64](2005),
wantMicros: 2005,
wantValid: true,
},
{
name: "partial pricing only input contributes",
@@ -116,32 +114,30 @@ func TestCalculateTotalCostMicros(t *testing.T) {
CacheReadTokens: ptr.Ref[int64](500),
CacheCreationTokens: ptr.Ref[int64](250),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")),
},
want: ptr.Ref[int64](3085),
cost: &codersdk.ModelCostConfig{InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5"))},
wantMicros: 3085,
wantValid: true,
},
{
name: "zero tokens with pricing returns zero pointer",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: ptr.Ref[int64](0),
name: "zero tokens with pricing returns zero cost",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
cost: &codersdk.ModelCostConfig{InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3"))},
wantMicros: 0,
wantValid: true,
},
{
name: "usage only in unpriced categories returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: nil,
name: "usage only in unpriced categories returns unpriced",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15"))},
wantMicros: 0,
wantValid: false,
},
{
name: "non nil usage with empty cost config returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
cost: &codersdk.ModelCostConfig{},
want: nil,
name: "non nil usage with empty cost config returns unpriced",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
cost: &codersdk.ModelCostConfig{},
wantMicros: 0,
wantValid: false,
},
}
@@ -150,14 +146,10 @@ func TestCalculateTotalCostMicros(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
micros, valid := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
if tt.want == nil {
require.Nil(t, got)
} else {
require.NotNil(t, got)
require.Equal(t, *tt.want, *got)
}
require.Equal(t, tt.wantValid, valid)
require.Equal(t, tt.wantMicros, micros)
})
}
}
+23 -10
View File
@@ -298,6 +298,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
})
if err != nil {
return xerrors.Errorf("insert system message: %w", err)
@@ -327,6 +328,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
Compressed: sql.NullBool{},
})
if err != nil {
@@ -917,6 +919,7 @@ func insertUserMessageAndSetPending(
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
Compressed: sql.NullBool{},
})
if err != nil {
@@ -1977,6 +1980,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
Compressed: sql.NullBool{},
})
if insertErr != nil {
@@ -2383,7 +2387,8 @@ func (p *Server) runChat(
}
}
totalCostMicros := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost)
totalCostMicros, costValid := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost)
totalCostMicrosValue, costValidValue := persistedMessageCost(totalCostMicros, costValid)
assistantMessage, insertErr := tx.InsertChatMessage(persistCtx, database.InsertChatMessageParams{
ChatID: chat.ID,
@@ -2407,11 +2412,12 @@ func (p *Server) runChat(
CacheReadTokens: usageNullInt64(step.Usage.CacheReadTokens, hasUsage),
ContextLimit: step.ContextLimit,
Compressed: sql.NullBool{},
// TotalCostMicros is nullable: NULL means "unpriced"
// (pricing config was missing or no priced token
// breakdown available), while 0 means "priced at
// zero cost" (e.g., a free model).
TotalCostMicros: usageNullInt64Ptr(totalCostMicros),
// cost_valid=true means priced (including zero-cost),
// false means unpriced. Keep total_cost_micros NULL for
// unpriced writes so older readers still infer pricing
// correctly during mixed-version rollout.
TotalCostMicros: totalCostMicrosValue,
CostValid: costValidValue,
})
if insertErr != nil {
return xerrors.Errorf("insert assistant message: %w", insertErr)
@@ -2442,6 +2448,7 @@ func (p *Server) runChat(
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
Compressed: sql.NullBool{},
})
if insertErr != nil {
@@ -2817,6 +2824,7 @@ func (p *Server) persistChatContextSummary(
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
})
if txErr != nil {
return xerrors.Errorf("insert hidden summary message: %w", txErr)
@@ -2842,6 +2850,7 @@ func (p *Server) persistChatContextSummary(
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
})
if txErr != nil {
return xerrors.Errorf("insert summary tool call message: %w", txErr)
@@ -2868,6 +2877,7 @@ func (p *Server) persistChatContextSummary(
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
})
if txErr != nil {
return xerrors.Errorf("insert summary tool result message: %w", txErr)
@@ -2995,11 +3005,14 @@ func usageNullInt64(value int64, valid bool) sql.NullInt64 {
}
}
func usageNullInt64Ptr(v *int64) sql.NullInt64 {
if v == nil {
return sql.NullInt64{}
//nolint:revive // Boolean controls SQL NULL validity.
func persistedMessageCost(totalCostMicros int64, costValid bool) (sql.NullInt64, sql.NullBool) {
costValidValue := sql.NullBool{Bool: costValid, Valid: true}
if !costValid {
return sql.NullInt64{}, costValidValue
}
return sql.NullInt64{Int64: *v, Valid: true}
return sql.NullInt64{Int64: totalCostMicros, Valid: true}, costValidValue
}
func refreshChatWorkspaceSnapshot(
+82
View File
@@ -588,6 +588,8 @@ func TestEditMessageRejectsNonUserMessage(t *testing.T) {
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
})
require.NoError(t, err)
@@ -936,6 +938,8 @@ func TestSubscribeAfterMessageID(t *testing.T) {
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
})
require.NoError(t, err)
@@ -959,6 +963,8 @@ func TestSubscribeAfterMessageID(t *testing.T) {
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
})
require.NoError(t, err)
@@ -1405,6 +1411,82 @@ func setOpenAIProviderBaseURL(
require.NoError(t, err)
}
func TestSuccessfulChatPersistsNullCostForUnpricedAssistantMessages(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("title")
}
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("done")...,
)
})
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := chatd.New(chatd.Config{
Logger: logger,
Database: db,
ReplicaID: uuid.New(),
Pubsub: ps,
PendingChatAcquireInterval: 10 * time.Millisecond,
InFlightChatStaleAfter: testutil.WaitSuperLong,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
user, model := seedChatDependencies(ctx, t, db)
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
_, err := db.UpdateChatModelConfig(ctx, database.UpdateChatModelConfigParams{
Provider: model.Provider,
Model: model.Model,
DisplayName: model.DisplayName,
UpdatedBy: model.UpdatedBy,
Enabled: model.Enabled,
IsDefault: model.IsDefault,
ContextLimit: model.ContextLimit,
CompressionThreshold: model.CompressionThreshold,
Options: json.RawMessage(`{"cost":{"input_price_per_million_tokens":"0.15"}}`),
ID: model.ID,
})
require.NoError(t, err)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "unpriced-cost-write-test",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
if dbErr != nil {
return false
}
return fromDB.Status == database.ChatStatusWaiting && !fromDB.WorkerID.Valid
}, testutil.IntervalFast)
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Len(t, messages, 2)
assistantMessage := messages[1]
require.Equal(t, database.ChatMessageRoleAssistant, assistantMessage.Role)
require.False(t, assistantMessage.TotalCostMicros.Valid)
require.Equal(t, int64(0), assistantMessage.TotalCostMicros.Int64)
require.True(t, assistantMessage.CostValid.Valid)
require.False(t, assistantMessage.CostValid.Bool)
}
func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
t.Parallel()
+83
View File
@@ -3510,6 +3510,7 @@ func seedChatCostFixture(t *testing.T) chatCostTestFixture {
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
CostValid: sql.NullBool{Bool: true, Valid: true},
})
require.NoError(t, err)
}
@@ -3602,6 +3603,7 @@ func TestChatCostSummary_AdminDrilldown(t *testing.T) {
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 750, Valid: true},
CostValid: sql.NullBool{Bool: true, Valid: true},
})
require.NoError(t, err)
@@ -3652,6 +3654,7 @@ func TestChatCostUsers(t *testing.T) {
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 300, Valid: true},
CostValid: sql.NullBool{Bool: true, Valid: true},
})
require.NoError(t, err)
@@ -3669,6 +3672,7 @@ func TestChatCostUsers(t *testing.T) {
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 800, Valid: true},
CostValid: sql.NullBool{Bool: true, Valid: true},
})
require.NoError(t, err)
@@ -3743,6 +3747,7 @@ func TestChatCostSummary_DateRange(t *testing.T) {
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
CostValid: sql.NullBool{Bool: true, Valid: true},
})
require.NoError(t, err)
@@ -3798,6 +3803,7 @@ func TestChatCostSummary_UnpricedMessages(t *testing.T) {
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
CostValid: sql.NullBool{Bool: true, Valid: true},
})
require.NoError(t, err)
@@ -3809,6 +3815,7 @@ func TestChatCostSummary_UnpricedMessages(t *testing.T) {
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 75, Valid: true},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
})
require.NoError(t, err)
@@ -3822,6 +3829,82 @@ func TestChatCostSummary_UnpricedMessages(t *testing.T) {
require.Equal(t, int64(125), summary.TotalOutputTokens)
}
func TestChatCostSummary_MixedVersionRows(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: "mixed version test",
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
ContentVersion: 1,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
CostValid: sql.NullBool{Bool: true, Valid: true},
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
ContentVersion: 1,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 300, Valid: true},
CostValid: sql.NullBool{},
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
ContentVersion: 1,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 150, Valid: true},
OutputTokens: sql.NullInt64{Int64: 75, Valid: true},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{Bool: false, Valid: true},
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
ContentVersion: 1,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 25, Valid: true},
OutputTokens: sql.NullInt64{Int64: 10, Valid: true},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
})
require.NoError(t, err)
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{})
require.NoError(t, err)
require.Equal(t, int64(800), summary.TotalCostMicros)
require.Equal(t, int64(2), summary.PricedMessageCount)
require.Equal(t, int64(2), summary.UnpricedMessageCount)
require.Equal(t, int64(475), summary.TotalInputTokens)
require.Equal(t, int64(235), summary.TotalOutputTokens)
}
func requireChatModelPricing(
t *testing.T,
actual *codersdk.ChatModelCallConfig,
+2 -1
View File
@@ -1239,7 +1239,8 @@ CREATE TABLE chat_messages (
compressed boolean DEFAULT false NOT NULL,
created_by uuid,
content_version smallint NOT NULL,
total_cost_micros bigint
total_cost_micros bigint,
cost_valid boolean
);
CREATE SEQUENCE chat_messages_id_seq
@@ -0,0 +1,5 @@
-- Restore NULL cost for rows that new code marked as unpriced.
UPDATE chat_messages SET total_cost_micros = NULL WHERE cost_valid = false;
-- Drop cost_valid column.
ALTER TABLE chat_messages DROP COLUMN cost_valid;
@@ -0,0 +1,9 @@
-- Add cost_valid as a nullable column with no default for
-- mixed-version rollout compatibility. Old writers that do not
-- know about this column will insert NULL, which the summary
-- query interprets via COALESCE as falling back to the
-- total_cost_micros IS NOT NULL heuristic.
ALTER TABLE chat_messages ADD COLUMN cost_valid boolean;
-- Backfill: mark existing rows based on whether they have a cost.
UPDATE chat_messages SET cost_valid = (total_cost_micros IS NOT NULL);
+174
View File
@@ -465,6 +465,180 @@ func TestMigration000362AggregateUsageEvents(t *testing.T) {
}
}
func TestMigration000439ChatCostRolloutSafety(t *testing.T) {
t.Parallel()
if testing.Short() {
t.SkipNow()
return
}
const migrationVersion = 439
sqlDB := testSQLDB(t)
next, err := migrations.Stepper(sqlDB)
require.NoError(t, err)
for {
version, more, err := next()
require.NoError(t, err)
if !more {
t.Fatalf("migration %d not found", migrationVersion)
}
if version == migrationVersion-1 {
break
}
}
ctx := testutil.Context(t, testutil.WaitSuperLong)
orgID := uuid.New()
userID := uuid.New()
chatProviderID := uuid.New()
modelConfigID := uuid.New()
chatID := uuid.New()
_, err = sqlDB.ExecContext(ctx, `
INSERT INTO organizations (id, name, display_name, description, icon, created_at, updated_at, is_default)
VALUES ($1, 'test-org', 'Test Org', '', '', NOW(), NOW(), false)
`, orgID)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `
INSERT INTO users (id, email, username, name, hashed_password, created_at, updated_at, status, rbac_roles, login_type)
VALUES ($1, 'test@test.com', 'testuser', 'Test User', 'xxx', NOW(), NOW(), 'active', '{}', 'password')
`, userID)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `
INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles)
VALUES ($1, $2, NOW(), NOW(), '{}')
`, orgID, userID)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `
INSERT INTO chat_providers (
id,
provider,
display_name,
api_key,
api_key_key_id,
enabled,
created_at,
updated_at
) VALUES (
$1,
'openai',
'OpenAI',
'',
NULL,
true,
NOW(),
NOW()
)
`, chatProviderID)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `
INSERT INTO chat_model_configs (
id,
display_name,
provider,
model,
enabled,
context_limit,
compression_threshold,
created_at,
updated_at
) VALUES (
$1,
'test model',
'openai',
'gpt-4',
true,
200000,
70,
NOW(),
NOW()
)
`, modelConfigID)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `
INSERT INTO chats (id, owner_id, last_model_config_id, title, created_at, updated_at)
VALUES ($1, $2, $3, 'test chat', NOW(), NOW())
`, chatID, userID, modelConfigID)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `
INSERT INTO chat_messages (chat_id, role, content_version, visibility, total_cost_micros, created_at)
VALUES ($1, 'assistant', 1, 'both', 500, NOW())
`, chatID)
require.NoError(t, err)
_, err = sqlDB.ExecContext(ctx, `
INSERT INTO chat_messages (chat_id, role, content_version, visibility, total_cost_micros, created_at)
VALUES ($1, 'assistant', 1, 'both', NULL, NOW())
`, chatID)
require.NoError(t, err)
version, _, err := next()
require.NoError(t, err)
require.EqualValues(t, migrationVersion, version)
rows, err := sqlDB.QueryContext(ctx, `
SELECT total_cost_micros, cost_valid
FROM chat_messages
WHERE chat_id = $1
ORDER BY id
`, chatID)
require.NoError(t, err)
defer rows.Close()
var messages []struct {
totalCost sql.NullInt64
costValid sql.NullBool
}
for rows.Next() {
var message struct {
totalCost sql.NullInt64
costValid sql.NullBool
}
err = rows.Scan(&message.totalCost, &message.costValid)
require.NoError(t, err)
messages = append(messages, message)
}
require.NoError(t, rows.Err())
require.Len(t, messages, 2)
require.True(t, messages[0].totalCost.Valid)
require.Equal(t, int64(500), messages[0].totalCost.Int64)
require.True(t, messages[0].costValid.Valid)
require.True(t, messages[0].costValid.Bool)
require.False(t, messages[1].totalCost.Valid)
require.True(t, messages[1].costValid.Valid)
require.False(t, messages[1].costValid.Bool)
_, err = sqlDB.ExecContext(ctx, `
INSERT INTO chat_messages (chat_id, role, content_version, visibility, total_cost_micros, created_at)
VALUES ($1, 'assistant', 1, 'both', NULL, NOW())
`, chatID)
require.NoError(t, err)
var oldStyleTotalCost sql.NullInt64
var oldStyleCostValid sql.NullBool
err = sqlDB.QueryRowContext(ctx, `
SELECT total_cost_micros, cost_valid
FROM chat_messages
WHERE chat_id = $1
ORDER BY id DESC
LIMIT 1
`, chatID).Scan(&oldStyleTotalCost, &oldStyleCostValid)
require.NoError(t, err)
require.False(t, oldStyleTotalCost.Valid)
require.False(t, oldStyleCostValid.Valid)
}
func TestMigration000387MigrateTaskWorkspaces(t *testing.T) {
t.Parallel()
+1
View File
@@ -4085,6 +4085,7 @@ type ChatMessage struct {
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
ContentVersion int16 `db:"content_version" json:"content_version"`
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
CostValid sql.NullBool `db:"cost_valid" json:"cost_valid"`
}
type ChatModelConfig struct {
+7 -5
View File
@@ -9052,11 +9052,13 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
) database.ChatMessage {
t.Helper()
msg, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: chatID,
Role: role,
ContentVersion: chatprompt.CurrentContentVersion,
Visibility: vis,
Compressed: sql.NullBool{Bool: compressed, Valid: true},
ChatID: chatID,
Role: role,
ContentVersion: chatprompt.CurrentContentVersion,
Visibility: vis,
Compressed: sql.NullBool{Bool: compressed, Valid: true},
TotalCostMicros: sql.NullInt64{},
CostValid: sql.NullBool{},
Content: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`"` + content + `"`),
Valid: true,
+20 -10
View File
@@ -3563,10 +3563,10 @@ const getChatCostSummary = `-- name: GetChatCostSummary :one
SELECT
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.total_cost_micros IS NOT NULL
WHERE COALESCE(cm.cost_valid, cm.total_cost_micros IS NOT NULL)
)::bigint AS priced_message_count,
COUNT(*) FILTER (
WHERE cm.total_cost_micros IS NULL
WHERE NOT COALESCE(cm.cost_valid, cm.total_cost_micros IS NOT NULL)
AND (
cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL
@@ -3721,7 +3721,7 @@ func (q *sqlQuerier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds [
const getChatMessageByID = `-- name: GetChatMessageByID :one
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
FROM
chat_messages
WHERE
@@ -3750,13 +3750,14 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
&i.CostValid,
)
return i, err
}
const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
FROM
chat_messages
WHERE
@@ -3800,6 +3801,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
&i.CostValid,
); err != nil {
return nil, err
}
@@ -3831,7 +3833,7 @@ WITH latest_compressed_summary AS (
1
)
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
FROM
chat_messages
WHERE
@@ -3899,6 +3901,7 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
&i.CostValid,
); err != nil {
return nil, err
}
@@ -4043,7 +4046,7 @@ func (q *sqlQuerier) GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerI
const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
FROM
chat_messages
WHERE
@@ -4082,6 +4085,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
&i.CostValid,
)
return i, err
}
@@ -4228,7 +4232,8 @@ INSERT INTO chat_messages (
cache_read_tokens,
context_limit,
compressed,
total_cost_micros
total_cost_micros,
cost_valid
) VALUES (
$1::uuid,
$2::uuid,
@@ -4245,10 +4250,11 @@ INSERT INTO chat_messages (
$13::bigint,
$14::bigint,
COALESCE($15::boolean, FALSE),
$16::bigint
$16::bigint,
$17::boolean
)
RETURNING
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
`
type InsertChatMessageParams struct {
@@ -4268,6 +4274,7 @@ type InsertChatMessageParams struct {
ContextLimit sql.NullInt64 `db:"context_limit" json:"context_limit"`
Compressed sql.NullBool `db:"compressed" json:"compressed"`
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
CostValid sql.NullBool `db:"cost_valid" json:"cost_valid"`
}
func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error) {
@@ -4288,6 +4295,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag
arg.ContextLimit,
arg.Compressed,
arg.TotalCostMicros,
arg.CostValid,
)
var i ChatMessage
err := row.Scan(
@@ -4309,6 +4317,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
&i.CostValid,
)
return i, err
}
@@ -4444,7 +4453,7 @@ SET
WHERE
id = $3::bigint
RETURNING
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, cost_valid
`
type UpdateChatMessageByIDParams struct {
@@ -4475,6 +4484,7 @@ func (q *sqlQuerier) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMe
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
&i.CostValid,
)
return i, err
}
+6 -4
View File
@@ -182,7 +182,8 @@ INSERT INTO chat_messages (
cache_read_tokens,
context_limit,
compressed,
total_cost_micros
total_cost_micros,
cost_valid
) VALUES (
@chat_id::uuid,
sqlc.narg('created_by')::uuid,
@@ -199,7 +200,8 @@ INSERT INTO chat_messages (
sqlc.narg('cache_read_tokens')::bigint,
sqlc.narg('context_limit')::bigint,
COALESCE(sqlc.narg('compressed')::boolean, FALSE),
sqlc.narg('total_cost_micros')::bigint
sqlc.narg('total_cost_micros')::bigint,
sqlc.narg('cost_valid')::boolean
)
RETURNING
*;
@@ -516,10 +518,10 @@ WHERE
SELECT
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.total_cost_micros IS NOT NULL
WHERE COALESCE(cm.cost_valid, cm.total_cost_micros IS NOT NULL)
)::bigint AS priced_message_count,
COUNT(*) FILTER (
WHERE cm.total_cost_micros IS NULL
WHERE NOT COALESCE(cm.cost_valid, cm.total_cost_micros IS NOT NULL)
AND (
cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL