feat: track ai seat usage (#22682)
When a user uses an AI feature, we record them in the `ai_seat_state` as consuming a seat. Added in debouching to prevent excessive writes to the db for this feature. There is no need for frequent updates.
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
// Package aiseats is the AGPL version the package.
|
||||
// The actual implementation is in `enterprise/aiseats`.
|
||||
package aiseats
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
type Reason struct {
|
||||
EventType database.AiSeatUsageReason
|
||||
Description string
|
||||
}
|
||||
|
||||
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
|
||||
func ReasonAIBridge(description string) Reason {
|
||||
return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description}
|
||||
}
|
||||
|
||||
// ReasonTask constructs a reason for usage originating from tasks.
|
||||
func ReasonTask(description string) Reason {
|
||||
return Reason{EventType: database.AiSeatUsageReasonTask, Description: description}
|
||||
}
|
||||
|
||||
// SeatTracker records AI seat consumption state.
|
||||
type SeatTracker interface {
|
||||
// RecordUsage does not return an error to prevent blocking the user from using
|
||||
// AI features. This method is used to record usage, not enforce it.
|
||||
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
|
||||
}
|
||||
|
||||
// Noop is an AGPL seat tracker that does nothing.
|
||||
type Noop struct{}
|
||||
|
||||
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}
|
||||
@@ -44,6 +44,7 @@ import (
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/aiseats"
|
||||
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
|
||||
"github.com/coder/coder/v2/coderd/appearance"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
@@ -630,6 +631,8 @@ func New(options *Options) *API {
|
||||
dbRolluper: options.DatabaseRolluper,
|
||||
ProfileCollector: defaultProfileCollector{},
|
||||
}
|
||||
api.AISeatTracker = aiseats.Noop{}
|
||||
|
||||
api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider(
|
||||
ctx,
|
||||
options.Logger.Named("workspaceapps"),
|
||||
@@ -2033,6 +2036,8 @@ type API struct {
|
||||
dbRolluper *dbrollup.Rolluper
|
||||
// chatDaemon handles background processing of pending chats.
|
||||
chatDaemon *chatd.Server
|
||||
// AISeatTracker records AI seat usage.
|
||||
AISeatTracker aiseats.SeatTracker
|
||||
// gitSyncWorker refreshes stale chat diff statuses in the
|
||||
// background.
|
||||
gitSyncWorker *gitsync.Worker
|
||||
@@ -2245,6 +2250,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
|
||||
provisionerdserver.Options{
|
||||
OIDCConfig: api.OIDCConfig,
|
||||
ExternalAuthConfigs: api.ExternalAuthConfigs,
|
||||
AISeatTracker: api.AISeatTracker,
|
||||
Clock: api.Clock,
|
||||
HeartbeatFn: options.heartbeatFn,
|
||||
},
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
protobuf "google.golang.org/protobuf/proto"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/aiseats"
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -76,6 +77,7 @@ const (
|
||||
type Options struct {
|
||||
OIDCConfig promoauth.OAuth2Config
|
||||
ExternalAuthConfigs []*externalauth.Config
|
||||
AISeatTracker aiseats.SeatTracker
|
||||
|
||||
// Clock for testing
|
||||
Clock quartz.Clock
|
||||
@@ -120,6 +122,7 @@ type server struct {
|
||||
NotificationsEnqueuer notifications.Enqueuer
|
||||
PrebuildsOrchestrator *atomic.Pointer[prebuilds.ReconciliationOrchestrator]
|
||||
UsageInserter *atomic.Pointer[usage.Inserter]
|
||||
AISeatTracker aiseats.SeatTracker
|
||||
Experiments codersdk.Experiments
|
||||
|
||||
OIDCConfig promoauth.OAuth2Config
|
||||
@@ -215,6 +218,9 @@ func NewServer(
|
||||
if err := tags.Valid(); err != nil {
|
||||
return nil, xerrors.Errorf("invalid tags: %w", err)
|
||||
}
|
||||
if options.AISeatTracker == nil {
|
||||
options.AISeatTracker = aiseats.Noop{}
|
||||
}
|
||||
if options.AcquireJobLongPollDur == 0 {
|
||||
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
|
||||
}
|
||||
@@ -253,6 +259,7 @@ func NewServer(
|
||||
heartbeatFn: options.HeartbeatFn,
|
||||
PrebuildsOrchestrator: prebuildsOrchestrator,
|
||||
UsageInserter: usageInserter,
|
||||
AISeatTracker: options.AISeatTracker,
|
||||
metrics: metrics,
|
||||
Experiments: experiments,
|
||||
}
|
||||
@@ -2437,6 +2444,12 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro
|
||||
})
|
||||
}
|
||||
|
||||
// Record AI seat usage for successful task workspace builds.
|
||||
if workspaceBuild.Transition == database.WorkspaceTransitionStart && workspace.TaskID.Valid {
|
||||
s.AISeatTracker.RecordUsage(ctx, workspace.OwnerID,
|
||||
aiseats.ReasonTask("task workspace build succeeded"))
|
||||
}
|
||||
|
||||
if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
// Track resource replacements, if there are any.
|
||||
orchestrator := s.PrebuildsOrchestrator.Load()
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/aiseats"
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -81,10 +82,12 @@ type Server struct {
|
||||
|
||||
coderMCPConfig *proto.MCPServerConfig // may be nil if not available
|
||||
structuredLogging bool
|
||||
aiSeatTracker aiseats.SeatTracker
|
||||
}
|
||||
|
||||
func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, accessURL string,
|
||||
bridgeCfg codersdk.AIBridgeConfig, externalAuthConfigs []*externalauth.Config, experiments codersdk.Experiments,
|
||||
aiSeatTracker aiseats.SeatTracker,
|
||||
) (*Server, error) {
|
||||
eac := make(map[string]*externalauth.Config, len(externalAuthConfigs))
|
||||
|
||||
@@ -102,6 +105,7 @@ func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, ac
|
||||
logger: logger,
|
||||
externalAuthConfigs: eac,
|
||||
structuredLogging: bridgeCfg.StructuredLogging.Value(),
|
||||
aiSeatTracker: aiSeatTracker,
|
||||
}
|
||||
|
||||
if bridgeCfg.InjectCoderMCPTools {
|
||||
@@ -184,6 +188,8 @@ func (s *Server) RecordInterception(ctx context.Context, in *proto.RecordInterce
|
||||
return nil, xerrors.Errorf("start interception: %w", err)
|
||||
}
|
||||
|
||||
reason := aiseats.ReasonAIBridge("provider=" + in.Provider + ", model=" + in.Model)
|
||||
s.aiSeatTracker.RecordUsage(ctx, initID, reason)
|
||||
return &proto.RecordInterceptionResponse{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogjson"
|
||||
"github.com/coder/coder/v2/coderd/aiseats"
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
@@ -176,7 +177,7 @@ func TestAuthorization(t *testing.T) {
|
||||
tc.mocksFn(db, apiKey, user)
|
||||
}
|
||||
|
||||
srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments)
|
||||
srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
@@ -268,7 +269,7 @@ func TestGetMCPServerConfigs(t *testing.T) {
|
||||
accessURL := "https://my-cool-deployment.com"
|
||||
srv, err := aibridgedserver.NewServer(t.Context(), db, logger, accessURL, codersdk.AIBridgeConfig{
|
||||
InjectCoderMCPTools: serpent.Bool(!tc.disableCoderMCPInjection),
|
||||
}, tc.externalAuthConfigs, tc.experiments)
|
||||
}, tc.externalAuthConfigs, tc.experiments, aiseats.Noop{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
@@ -318,7 +319,7 @@ func TestGetMCPServerAccessTokensBatch(t *testing.T) {
|
||||
{
|
||||
ID: "3",
|
||||
},
|
||||
}, requiredExperiments)
|
||||
}, requiredExperiments, aiseats.Noop{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
@@ -1014,7 +1015,7 @@ func testRecordMethod[Req any, Resp any](
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments)
|
||||
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := callMethod(srv, ctx, tc.request)
|
||||
@@ -1309,7 +1310,7 @@ func TestStructuredLogging(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{
|
||||
StructuredLogging: serpent.Bool(tc.structuredLogging),
|
||||
}, nil, requiredExperiments)
|
||||
}, nil, requiredExperiments, aiseats.Noop{})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tc.recordFn(srv, ctx, interceptionID)
|
||||
@@ -1351,7 +1352,7 @@ func TestInferredThreadsByToolCalls(t *testing.T) {
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments)
|
||||
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{})
|
||||
require.NoError(t, err)
|
||||
|
||||
aID := uuid.New()
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package aiseats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
agplaiseats "github.com/coder/coder/v2/coderd/aiseats"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type store interface {
|
||||
UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) error
|
||||
}
|
||||
|
||||
// throttleInterval is the minimum time between DB writes for the same user. This
|
||||
// is to prevent ai seat tracking from consuming more db resources.
|
||||
//
|
||||
// These events are not critical to be recorded in real time, so we can afford to
|
||||
// skip almost all of them. The first write is the most important, as it
|
||||
// indicates a seat is consumed. Subsequent writes are purely informative and has
|
||||
// no functional impact.
|
||||
const (
|
||||
throttleInterval = 6 * time.Hour
|
||||
// failedThrottleInterval exists to prevent a transient error from causing no
|
||||
// usage to be recorded. Still debounce.
|
||||
failedThrottleInterval = 30 * time.Minute
|
||||
)
|
||||
|
||||
// SeatTracker records current AI seat state for users.
|
||||
type SeatTracker struct {
|
||||
db store
|
||||
logger slog.Logger
|
||||
clock quartz.Clock
|
||||
|
||||
mu sync.RWMutex
|
||||
retryAfter map[uuid.UUID]time.Time
|
||||
}
|
||||
|
||||
func New(db store, logger slog.Logger, clock quartz.Clock) *SeatTracker {
|
||||
if clock == nil {
|
||||
clock = quartz.NewReal()
|
||||
}
|
||||
return &SeatTracker{db: db, logger: logger, clock: clock, retryAfter: make(map[uuid.UUID]time.Time)}
|
||||
}
|
||||
|
||||
// skipRecord returns true when the user is still in the retry cooldown
|
||||
// window and we should skip a DB write attempt.
|
||||
func (t *SeatTracker) skipRecord(userID uuid.UUID, now time.Time) bool {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
retryAfter, ok := t.retryAfter[userID]
|
||||
return ok && now.Before(retryAfter)
|
||||
}
|
||||
|
||||
// recordThrottle sets the next time when DB writes for this user are allowed.
|
||||
func (t *SeatTracker) recordThrottle(userID uuid.UUID, now time.Time, d time.Duration) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.retryAfter[userID] = now.Add(d)
|
||||
}
|
||||
|
||||
// RecordUsage will record the AI seat usage for the user. There is a race condition between
|
||||
// checking if the user should be recorded or throttled and actually recording. This is fine, as
|
||||
// it just means we record the usage twice.
|
||||
// The throttle just exists to prevent excessive database queries.
|
||||
func (t *SeatTracker) RecordUsage(ctx context.Context, userID uuid.UUID, reason agplaiseats.Reason) {
|
||||
now := t.clock.Now()
|
||||
if t.skipRecord(userID, now) {
|
||||
return
|
||||
}
|
||||
|
||||
err := t.db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
|
||||
UserID: userID,
|
||||
FirstUsedAt: now,
|
||||
LastEventType: reason.EventType,
|
||||
LastEventDescription: reason.Description,
|
||||
})
|
||||
if err != nil {
|
||||
t.logger.Warn(ctx, "upsert AI seat state", slog.Error(err), slog.F("user_id", userID), slog.F("event_type", reason.EventType))
|
||||
t.recordThrottle(userID, now, failedThrottleInterval)
|
||||
return
|
||||
}
|
||||
|
||||
t.recordThrottle(userID, now, throttleInterval)
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package aiseats_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agplaiseats "github.com/coder/coder/v2/coderd/aiseats"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
enterpriseaiseats "github.com/coder/coder/v2/enterprise/aiseats"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestSeatTrackerDB(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ActiveUserRecorded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
tracker := enterpriseaiseats.New(db, testutil.Logger(t), clock)
|
||||
|
||||
user := dbgen.User(t, db, database.User{Status: database.UserStatusActive})
|
||||
tracker.RecordUsage(ctx, user.ID, agplaiseats.ReasonAIBridge("active user event"))
|
||||
|
||||
count, err := db.GetActiveAISeatCount(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("InactiveUsersExcluded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
tracker := enterpriseaiseats.New(db, testutil.Logger(t), quartz.NewMock(t))
|
||||
|
||||
dormantUser := dbgen.User(t, db, database.User{Status: database.UserStatusDormant})
|
||||
tracker.RecordUsage(ctx, dormantUser.ID, agplaiseats.ReasonTask("dormant user event"))
|
||||
|
||||
suspendedUser := dbgen.User(t, db, database.User{Status: database.UserStatusSuspended})
|
||||
tracker.RecordUsage(ctx, suspendedUser.ID, agplaiseats.ReasonTask("suspended user event"))
|
||||
|
||||
count, err := db.GetActiveAISeatCount(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, count)
|
||||
})
|
||||
|
||||
t.Run("StatusTransitions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
tracker := enterpriseaiseats.New(db, testutil.Logger(t), quartz.NewMock(t))
|
||||
|
||||
user := dbgen.User(t, db, database.User{Status: database.UserStatusActive})
|
||||
tracker.RecordUsage(ctx, user.ID, agplaiseats.ReasonAIBridge("status transition"))
|
||||
|
||||
count, err := db.GetActiveAISeatCount(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, count)
|
||||
|
||||
_, err = db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
|
||||
ID: user.ID,
|
||||
Status: database.UserStatusDormant,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
UserIsSeen: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err = db.GetActiveAISeatCount(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, count)
|
||||
|
||||
_, err = db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{
|
||||
ID: user.ID,
|
||||
Status: database.UserStatusActive,
|
||||
UpdatedAt: dbtime.Now().Add(time.Second),
|
||||
UserIsSeen: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err = db.GetActiveAISeatCount(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, count)
|
||||
})
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (api *API) CreateInMemoryAIBridgeServer(dialCtx context.Context) (client ai
|
||||
|
||||
mux := drpcmux.New()
|
||||
srv, err := aibridgedserver.NewServer(api.ctx, api.Database, api.Logger.Named("aibridgedserver"),
|
||||
api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.AGPL.Experiments)
|
||||
api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.AGPL.Experiments, api.aiSeatTracker)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ import (
|
||||
agplusage "github.com/coder/coder/v2/coderd/usage"
|
||||
"github.com/coder/coder/v2/coderd/wsbuilder"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/aiseats"
|
||||
entchatd "github.com/coder/coder/v2/enterprise/coderd/chatd"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/connectionlog"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/dbauthz"
|
||||
@@ -217,7 +218,10 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
},
|
||||
})
|
||||
|
||||
api.aiSeatTracker = aiseats.New(options.Database, api.Logger.Named("aiseats"), quartz.NewReal())
|
||||
|
||||
api.AGPL = coderd.New(options.Options)
|
||||
api.AGPL.AISeatTracker = api.aiSeatTracker
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = api.Close()
|
||||
@@ -785,6 +789,7 @@ type API struct {
|
||||
|
||||
aibridgedHandler http.Handler
|
||||
aibridgeproxydHandler http.Handler
|
||||
aiSeatTracker *aiseats.SeatTracker
|
||||
}
|
||||
|
||||
// writeEntitlementWarningsHeader writes the entitlement warnings to the response header
|
||||
|
||||
@@ -356,6 +356,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
provisionerdserver.Options{
|
||||
ExternalAuthConfigs: api.ExternalAuthConfigs,
|
||||
OIDCConfig: api.OIDCConfig,
|
||||
AISeatTracker: api.AGPL.AISeatTracker,
|
||||
Clock: api.Clock,
|
||||
},
|
||||
api.NotificationsEnqueuer,
|
||||
|
||||
Reference in New Issue
Block a user