feat: track ai seat usage (#22682)

When a user uses an AI feature, we record them in the `ai_seat_state` as consuming a seat. 

Added in debouching to prevent excessive writes to the db for this feature. There is no need for frequent updates.
This commit is contained in:
Steven Masley
2026-03-16 12:36:26 -05:00
committed by GitHub
parent cabb611fd9
commit abf59ee7a6
10 changed files with 262 additions and 7 deletions
+38
View File
@@ -0,0 +1,38 @@
// Package aiseats is the AGPL version the package.
// The actual implementation is in `enterprise/aiseats`.
package aiseats
import (
"context"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
)
type Reason struct {
EventType database.AiSeatUsageReason
Description string
}
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
func ReasonAIBridge(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description}
}
// ReasonTask constructs a reason for usage originating from tasks.
func ReasonTask(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonTask, Description: description}
}
// SeatTracker records AI seat consumption state.
type SeatTracker interface {
// RecordUsage does not return an error to prevent blocking the user from using
// AI features. This method is used to record usage, not enforce it.
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
}
// Noop is an AGPL seat tracker that does nothing.
type Noop struct{}
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}
+6
View File
@@ -44,6 +44,7 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/aiseats"
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
"github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/coderd/audit"
@@ -630,6 +631,8 @@ func New(options *Options) *API {
dbRolluper: options.DatabaseRolluper,
ProfileCollector: defaultProfileCollector{},
}
api.AISeatTracker = aiseats.Noop{}
api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider(
ctx,
options.Logger.Named("workspaceapps"),
@@ -2033,6 +2036,8 @@ type API struct {
dbRolluper *dbrollup.Rolluper
// chatDaemon handles background processing of pending chats.
chatDaemon *chatd.Server
// AISeatTracker records AI seat usage.
AISeatTracker aiseats.SeatTracker
// gitSyncWorker refreshes stale chat diff statuses in the
// background.
gitSyncWorker *gitsync.Worker
@@ -2245,6 +2250,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
provisionerdserver.Options{
OIDCConfig: api.OIDCConfig,
ExternalAuthConfigs: api.ExternalAuthConfigs,
AISeatTracker: api.AISeatTracker,
Clock: api.Clock,
HeartbeatFn: options.heartbeatFn,
},
@@ -28,6 +28,7 @@ import (
protobuf "google.golang.org/protobuf/proto"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/aiseats"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
@@ -76,6 +77,7 @@ const (
type Options struct {
OIDCConfig promoauth.OAuth2Config
ExternalAuthConfigs []*externalauth.Config
AISeatTracker aiseats.SeatTracker
// Clock for testing
Clock quartz.Clock
@@ -120,6 +122,7 @@ type server struct {
NotificationsEnqueuer notifications.Enqueuer
PrebuildsOrchestrator *atomic.Pointer[prebuilds.ReconciliationOrchestrator]
UsageInserter *atomic.Pointer[usage.Inserter]
AISeatTracker aiseats.SeatTracker
Experiments codersdk.Experiments
OIDCConfig promoauth.OAuth2Config
@@ -215,6 +218,9 @@ func NewServer(
if err := tags.Valid(); err != nil {
return nil, xerrors.Errorf("invalid tags: %w", err)
}
if options.AISeatTracker == nil {
options.AISeatTracker = aiseats.Noop{}
}
if options.AcquireJobLongPollDur == 0 {
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
}
@@ -253,6 +259,7 @@ func NewServer(
heartbeatFn: options.HeartbeatFn,
PrebuildsOrchestrator: prebuildsOrchestrator,
UsageInserter: usageInserter,
AISeatTracker: options.AISeatTracker,
metrics: metrics,
Experiments: experiments,
}
@@ -2437,6 +2444,12 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro
})
}
// Record AI seat usage for successful task workspace builds.
if workspaceBuild.Transition == database.WorkspaceTransitionStart && workspace.TaskID.Valid {
s.AISeatTracker.RecordUsage(ctx, workspace.OwnerID,
aiseats.ReasonTask("task workspace build succeeded"))
}
if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
// Track resource replacements, if there are any.
orchestrator := s.PrebuildsOrchestrator.Load()
@@ -15,6 +15,7 @@ import (
"google.golang.org/protobuf/types/known/structpb"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/aiseats"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -81,10 +82,12 @@ type Server struct {
coderMCPConfig *proto.MCPServerConfig // may be nil if not available
structuredLogging bool
aiSeatTracker aiseats.SeatTracker
}
func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, accessURL string,
bridgeCfg codersdk.AIBridgeConfig, externalAuthConfigs []*externalauth.Config, experiments codersdk.Experiments,
aiSeatTracker aiseats.SeatTracker,
) (*Server, error) {
eac := make(map[string]*externalauth.Config, len(externalAuthConfigs))
@@ -102,6 +105,7 @@ func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, ac
logger: logger,
externalAuthConfigs: eac,
structuredLogging: bridgeCfg.StructuredLogging.Value(),
aiSeatTracker: aiSeatTracker,
}
if bridgeCfg.InjectCoderMCPTools {
@@ -184,6 +188,8 @@ func (s *Server) RecordInterception(ctx context.Context, in *proto.RecordInterce
return nil, xerrors.Errorf("start interception: %w", err)
}
reason := aiseats.ReasonAIBridge("provider=" + in.Provider + ", model=" + in.Model)
s.aiSeatTracker.RecordUsage(ctx, initID, reason)
return &proto.RecordInterceptionResponse{}, nil
}
@@ -24,6 +24,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogjson"
"github.com/coder/coder/v2/coderd/aiseats"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
@@ -176,7 +177,7 @@ func TestAuthorization(t *testing.T) {
tc.mocksFn(db, apiKey, user)
}
srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments)
srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{})
require.NoError(t, err)
require.NotNil(t, srv)
@@ -268,7 +269,7 @@ func TestGetMCPServerConfigs(t *testing.T) {
accessURL := "https://my-cool-deployment.com"
srv, err := aibridgedserver.NewServer(t.Context(), db, logger, accessURL, codersdk.AIBridgeConfig{
InjectCoderMCPTools: serpent.Bool(!tc.disableCoderMCPInjection),
}, tc.externalAuthConfigs, tc.experiments)
}, tc.externalAuthConfigs, tc.experiments, aiseats.Noop{})
require.NoError(t, err)
require.NotNil(t, srv)
@@ -318,7 +319,7 @@ func TestGetMCPServerAccessTokensBatch(t *testing.T) {
{
ID: "3",
},
}, requiredExperiments)
}, requiredExperiments, aiseats.Noop{})
require.NoError(t, err)
require.NotNil(t, srv)
@@ -1014,7 +1015,7 @@ func testRecordMethod[Req any, Resp any](
}
ctx := testutil.Context(t, testutil.WaitLong)
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments)
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{})
require.NoError(t, err)
resp, err := callMethod(srv, ctx, tc.request)
@@ -1309,7 +1310,7 @@ func TestStructuredLogging(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{
StructuredLogging: serpent.Bool(tc.structuredLogging),
}, nil, requiredExperiments)
}, nil, requiredExperiments, aiseats.Noop{})
require.NoError(t, err)
err = tc.recordFn(srv, ctx, interceptionID)
@@ -1351,7 +1352,7 @@ func TestInferredThreadsByToolCalls(t *testing.T) {
user := dbgen.User(t, db, database.User{})
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments)
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{})
require.NoError(t, err)
aID := uuid.New()
+91
View File
@@ -0,0 +1,91 @@
package aiseats
import (
"context"
"sync"
"time"
"github.com/google/uuid"
"cdr.dev/slog/v3"
agplaiseats "github.com/coder/coder/v2/coderd/aiseats"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/quartz"
)
type store interface {
UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) error
}
// throttleInterval is the minimum time between DB writes for the same user. This
// is to prevent ai seat tracking from consuming more db resources.
//
// These events are not critical to be recorded in real time, so we can afford to
// skip almost all of them. The first write is the most important, as it
// indicates a seat is consumed. Subsequent writes are purely informative and has
// no functional impact.
const (
throttleInterval = 6 * time.Hour
// failedThrottleInterval exists to prevent a transient error from causing no
// usage to be recorded. Still debounce.
failedThrottleInterval = 30 * time.Minute
)
// SeatTracker records current AI seat state for users.
type SeatTracker struct {
db store
logger slog.Logger
clock quartz.Clock
mu sync.RWMutex
retryAfter map[uuid.UUID]time.Time
}
func New(db store, logger slog.Logger, clock quartz.Clock) *SeatTracker {
if clock == nil {
clock = quartz.NewReal()
}
return &SeatTracker{db: db, logger: logger, clock: clock, retryAfter: make(map[uuid.UUID]time.Time)}
}
// skipRecord returns true when the user is still in the retry cooldown
// window and we should skip a DB write attempt.
func (t *SeatTracker) skipRecord(userID uuid.UUID, now time.Time) bool {
t.mu.RLock()
defer t.mu.RUnlock()
retryAfter, ok := t.retryAfter[userID]
return ok && now.Before(retryAfter)
}
// recordThrottle sets the next time when DB writes for this user are allowed.
func (t *SeatTracker) recordThrottle(userID uuid.UUID, now time.Time, d time.Duration) {
t.mu.Lock()
defer t.mu.Unlock()
t.retryAfter[userID] = now.Add(d)
}
// RecordUsage will record the AI seat usage for the user. There is a race condition between
// checking if the user should be recorded or throttled and actually recording. This is fine, as
// it just means we record the usage twice.
// The throttle just exists to prevent excessive database queries.
func (t *SeatTracker) RecordUsage(ctx context.Context, userID uuid.UUID, reason agplaiseats.Reason) {
now := t.clock.Now()
if t.skipRecord(userID, now) {
return
}
err := t.db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
UserID: userID,
FirstUsedAt: now,
LastEventType: reason.EventType,
LastEventDescription: reason.Description,
})
if err != nil {
t.logger.Warn(ctx, "upsert AI seat state", slog.Error(err), slog.F("user_id", userID), slog.F("event_type", reason.EventType))
t.recordThrottle(userID, now, failedThrottleInterval)
return
}
t.recordThrottle(userID, now, throttleInterval)
}
+94
View File
@@ -0,0 +1,94 @@
package aiseats_test
import (
"testing"
"time"
"github.com/stretchr/testify/require"
agplaiseats "github.com/coder/coder/v2/coderd/aiseats"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
enterpriseaiseats "github.com/coder/coder/v2/enterprise/aiseats"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestSeatTrackerDB(t *testing.T) {
t.Parallel()
t.Run("ActiveUserRecorded", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
tracker := enterpriseaiseats.New(db, testutil.Logger(t), clock)
user := dbgen.User(t, db, database.User{Status: database.UserStatusActive})
tracker.RecordUsage(ctx, user.ID, agplaiseats.ReasonAIBridge("active user event"))
count, err := db.GetActiveAISeatCount(ctx)
require.NoError(t, err)
require.EqualValues(t, 1, count)
})
t.Run("InactiveUsersExcluded", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
tracker := enterpriseaiseats.New(db, testutil.Logger(t), quartz.NewMock(t))
dormantUser := dbgen.User(t, db, database.User{Status: database.UserStatusDormant})
tracker.RecordUsage(ctx, dormantUser.ID, agplaiseats.ReasonTask("dormant user event"))
suspendedUser := dbgen.User(t, db, database.User{Status: database.UserStatusSuspended})
tracker.RecordUsage(ctx, suspendedUser.ID, agplaiseats.ReasonTask("suspended user event"))
count, err := db.GetActiveAISeatCount(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, count)
})
t.Run("StatusTransitions", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
tracker := enterpriseaiseats.New(db, testutil.Logger(t), quartz.NewMock(t))
user := dbgen.User(t, db, database.User{Status: database.UserStatusActive})
tracker.RecordUsage(ctx, user.ID, agplaiseats.ReasonAIBridge("status transition"))
count, err := db.GetActiveAISeatCount(ctx)
require.NoError(t, err)
require.EqualValues(t, 1, count)
_, err = db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
ID: user.ID,
Status: database.UserStatusDormant,
UpdatedAt: dbtime.Now(),
UserIsSeen: false,
})
require.NoError(t, err)
count, err = db.GetActiveAISeatCount(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, count)
_, err = db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
ID: user.ID,
Status: database.UserStatusActive,
UpdatedAt: dbtime.Now().Add(time.Second),
UserIsSeen: false,
})
require.NoError(t, err)
count, err = db.GetActiveAISeatCount(ctx)
require.NoError(t, err)
require.EqualValues(t, 1, count)
})
}
+1 -1
View File
@@ -48,7 +48,7 @@ func (api *API) CreateInMemoryAIBridgeServer(dialCtx context.Context) (client ai
mux := drpcmux.New()
srv, err := aibridgedserver.NewServer(api.ctx, api.Database, api.Logger.Named("aibridgedserver"),
api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.AGPL.Experiments)
api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.AGPL.Experiments, api.aiSeatTracker)
if err != nil {
return nil, err
}
+5
View File
@@ -45,6 +45,7 @@ import (
agplusage "github.com/coder/coder/v2/coderd/usage"
"github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/aiseats"
entchatd "github.com/coder/coder/v2/enterprise/coderd/chatd"
"github.com/coder/coder/v2/enterprise/coderd/connectionlog"
"github.com/coder/coder/v2/enterprise/coderd/dbauthz"
@@ -217,7 +218,10 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
},
})
api.aiSeatTracker = aiseats.New(options.Database, api.Logger.Named("aiseats"), quartz.NewReal())
api.AGPL = coderd.New(options.Options)
api.AGPL.AISeatTracker = api.aiSeatTracker
defer func() {
if err != nil {
_ = api.Close()
@@ -785,6 +789,7 @@ type API struct {
aibridgedHandler http.Handler
aibridgeproxydHandler http.Handler
aiSeatTracker *aiseats.SeatTracker
}
// writeEntitlementWarningsHeader writes the entitlement warnings to the response header
+1
View File
@@ -356,6 +356,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
provisionerdserver.Options{
ExternalAuthConfigs: api.ExternalAuthConfigs,
OIDCConfig: api.OIDCConfig,
AISeatTracker: api.AGPL.AISeatTracker,
Clock: api.Clock,
},
api.NotificationsEnqueuer,