feat: batch connection logs to avoid DB lock contention (#23727)

- Running 30k connections was generating a ton of lock contention in the
DB
This commit is contained in:
Jon Ayers
2026-04-03 15:47:26 -05:00
committed by GitHub
parent 333503f74e
commit a1d51f0dab
21 changed files with 2168 additions and 426 deletions
+1 -1
View File
@@ -85,7 +85,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
AgentName: a.AgentName,
Type: connectionType,
Code: code,
Ip: logIP,
IP: logIP,
ConnectionID: uuid.NullUUID{
UUID: connectionID,
Valid: true,
+1 -1
View File
@@ -152,7 +152,7 @@ func TestConnectionLog(t *testing.T) {
Int32: tt.status,
Valid: *tt.action == agentproto.Connection_DISCONNECT,
},
Ip: expectedIP,
IP: expectedIP,
Type: agentProtoConnectionTypeToConnectionLog(t, *tt.typ),
DisconnectReason: sql.NullString{
String: tt.reason,
+2 -2
View File
@@ -90,8 +90,8 @@ func (m *FakeConnectionLogger) Contains(t testing.TB, expected database.UpsertCo
t.Logf("connection log %d: expected Code %d, got %d", idx+1, expected.Code.Int32, cl.Code.Int32)
continue
}
if expected.Ip.Valid && cl.Ip.IPNet.String() != expected.Ip.IPNet.String() {
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.Ip.IPNet, cl.Ip.IPNet)
if expected.IP.Valid && cl.IP.IPNet.String() != expected.IP.IPNet.String() {
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.IP.IPNet, cl.IP.IPNet)
continue
}
if expected.UserAgent.Valid && cl.UserAgent.String != expected.UserAgent.String {
+7 -7
View File
@@ -1627,6 +1627,13 @@ func (q *querier) BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg datab
return q.db.BatchUpdateWorkspaceNextStartAt(ctx, arg)
}
func (q *querier) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return err
}
return q.db.BatchUpsertConnectionLogs(ctx, arg)
}
func (q *querier) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceNotificationMessage); err != nil {
return 0, err
@@ -7065,13 +7072,6 @@ func (q *querier) UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl strin
return q.db.UpsertChatWorkspaceTTL(ctx, workspaceTtl)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
}
return q.db.UpsertConnectionLog(ctx, arg)
}
func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
+3 -4
View File
@@ -338,10 +338,9 @@ func (s *MethodTestSuite) TestAuditLogs() {
}
func (s *MethodTestSuite) TestConnectionLogs() {
s.Run("UpsertConnectionLog", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.WorkspaceTable{})
arg := database.UpsertConnectionLogParams{Ip: defaultIPAddress(), Type: database.ConnectionTypeSsh, WorkspaceID: ws.ID, OrganizationID: ws.OrganizationID, ConnectionStatus: database.ConnectionStatusConnected, WorkspaceOwnerID: ws.OwnerID}
dbm.EXPECT().UpsertConnectionLog(gomock.Any(), arg).Return(database.ConnectionLog{}, nil).AnyTimes()
s.Run("BatchUpsertConnectionLogs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.BatchUpsertConnectionLogsParams{}
dbm.EXPECT().BatchUpsertConnectionLogs(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceConnectionLog, policy.ActionUpdate)
}))
s.Run("GetConnectionLogsOffset", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
+47 -3
View File
@@ -76,7 +76,7 @@ func AuditLog(t testing.TB, db database.Store, seed database.AuditLog) database.
}
func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog {
log, err := db.UpsertConnectionLog(genCtx, database.UpsertConnectionLogParams{
arg := database.UpsertConnectionLogParams{
ID: takeFirst(seed.ID, uuid.New()),
Time: takeFirst(seed.Time, dbtime.Now()),
OrganizationID: takeFirst(seed.OrganizationID, uuid.New()),
@@ -89,7 +89,7 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
Int32: takeFirst(seed.Code.Int32, 0),
Valid: takeFirst(seed.Code.Valid, false),
},
Ip: pqtype.Inet{
IP: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
@@ -117,9 +117,53 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
Valid: takeFirst(seed.DisconnectReason.Valid, false),
},
ConnectionStatus: takeFirst(seed.ConnectionStatus, database.ConnectionStatusConnected),
}
var disconnectTime sql.NullTime
if arg.ConnectionStatus == database.ConnectionStatusDisconnected {
disconnectTime = sql.NullTime{Time: arg.Time, Valid: true}
}
err := db.BatchUpsertConnectionLogs(genCtx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{arg.ID},
ConnectTime: []time.Time{arg.Time},
OrganizationID: []uuid.UUID{arg.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{arg.WorkspaceOwnerID},
WorkspaceID: []uuid.UUID{arg.WorkspaceID},
WorkspaceName: []string{arg.WorkspaceName},
AgentName: []string{arg.AgentName},
Type: []database.ConnectionType{arg.Type},
Code: []int32{arg.Code.Int32},
CodeValid: []bool{arg.Code.Valid},
Ip: []pqtype.Inet{arg.IP},
UserAgent: []string{arg.UserAgent.String},
UserID: []uuid.UUID{arg.UserID.UUID},
SlugOrPort: []string{arg.SlugOrPort.String},
ConnectionID: []uuid.UUID{arg.ConnectionID.UUID},
DisconnectReason: []string{arg.DisconnectReason.String},
DisconnectTime: []time.Time{disconnectTime.Time},
})
require.NoError(t, err, "insert connection log")
return log
// Query back the actual row from the database. On upsert
// conflict the DB keeps the original row's ID, so we can't
// rely on arg.ID. Match on the conflict key for rows with a
// connection_id, or by primary key for NULL connection_id.
rows, err := db.GetConnectionLogsOffset(genCtx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err, "query connection logs")
for _, row := range rows {
if arg.ConnectionID.Valid {
if row.ConnectionLog.ConnectionID == arg.ConnectionID &&
row.ConnectionLog.WorkspaceID == arg.WorkspaceID &&
row.ConnectionLog.AgentName == arg.AgentName {
return row.ConnectionLog
}
} else if row.ConnectionLog.ID == arg.ID {
return row.ConnectionLog
}
}
require.Failf(t, "connection log not found", "id=%s", arg.ID)
return database.ConnectionLog{} // unreachable
}
func Template(t testing.TB, db database.Store, seed database.Template) database.Template {
+8 -8
View File
@@ -208,6 +208,14 @@ func (m queryMetricsStore) BatchUpdateWorkspaceNextStartAt(ctx context.Context,
return r0
}
func (m queryMetricsStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
start := time.Now()
r0 := m.s.BatchUpsertConnectionLogs(ctx, arg)
m.queryLatencies.WithLabelValues("BatchUpsertConnectionLogs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpsertConnectionLogs").Inc()
return r0
}
func (m queryMetricsStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.BulkMarkNotificationMessagesFailed(ctx, arg)
@@ -5024,14 +5032,6 @@ func (m queryMetricsStore) UpsertChatWorkspaceTTL(ctx context.Context, workspace
return r0
}
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
start := time.Now()
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertConnectionLog").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertConnectionLog").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
start := time.Now()
r0 := m.s.UpsertDefaultProxy(ctx, arg)
+14 -15
View File
@@ -233,6 +233,20 @@ func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceNextStartAt(ctx, arg any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceNextStartAt", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceNextStartAt), ctx, arg)
}
// BatchUpsertConnectionLogs mocks base method.
func (m *MockStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchUpsertConnectionLogs", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// BatchUpsertConnectionLogs indicates an expected call of BatchUpsertConnectionLogs.
func (mr *MockStoreMockRecorder) BatchUpsertConnectionLogs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpsertConnectionLogs", reflect.TypeOf((*MockStore)(nil).BatchUpsertConnectionLogs), ctx, arg)
}
// BulkMarkNotificationMessagesFailed mocks base method.
func (m *MockStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
m.ctrl.T.Helper()
@@ -9442,21 +9456,6 @@ func (mr *MockStoreMockRecorder) UpsertChatWorkspaceTTL(ctx, workspaceTtl any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatWorkspaceTTL", reflect.TypeOf((*MockStore)(nil).UpsertChatWorkspaceTTL), ctx, workspaceTtl)
}
// UpsertConnectionLog mocks base method.
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertConnectionLog", ctx, arg)
ret0, _ := ret[0].(database.ConnectionLog)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertConnectionLog indicates an expected call of UpsertConnectionLog.
func (mr *MockStoreMockRecorder) UpsertConnectionLog(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertConnectionLog", reflect.TypeOf((*MockStore)(nil).UpsertConnectionLog), ctx, arg)
}
// UpsertDefaultProxy mocks base method.
func (m *MockStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
m.ctrl.T.Helper()
+26
View File
@@ -10,6 +10,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"golang.org/x/exp/maps"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
@@ -923,3 +924,28 @@ func WorkspaceIdentityFromWorkspace(w Workspace) WorkspaceIdentity {
func (r GetWorkspaceAgentAndWorkspaceByIDRow) RBACObject() rbac.Object {
return r.WorkspaceTable.RBACObject()
}
// UpsertConnectionLogParams contains the parameters for upserting a
// connection log entry. This struct is hand-maintained (not generated
// by sqlc) because the single-row UpsertConnectionLog query was
// removed in favor of BatchUpsertConnectionLogs, but the struct is
// still used as the canonical connection log event type throughout
// the codebase.
type UpsertConnectionLogParams struct {
ID uuid.UUID `db:"id" json:"id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
AgentName string `db:"agent_name" json:"agent_name"`
Type ConnectionType `db:"type" json:"type"`
Code sql.NullInt32 `db:"code" json:"code"`
IP pqtype.Inet `db:"ip" json:"ip"`
UserAgent sql.NullString `db:"user_agent" json:"user_agent"`
UserID uuid.NullUUID `db:"user_id" json:"user_id"`
SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"`
ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"`
DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"`
Time time.Time `db:"time" json:"time"`
ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"`
}
+1 -1
View File
@@ -65,6 +65,7 @@ type sqlcQuerier interface {
BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
// Calculates the telemetry summary for a given provider, model, and client
@@ -991,7 +992,6 @@ type sqlcQuerier interface {
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl string) error
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
// The default proxy is implied and not actually stored in the database.
// So we need to store it's configuration here for display purposes.
// The functional values are immutable and controlled implicitly.
+482 -197
View File
@@ -3566,9 +3566,11 @@ func connectionOnlyIDs[T database.ConnectionLog | database.GetConnectionLogsOffs
return ids
}
func TestUpsertConnectionLog(t *testing.T) {
func TestBatchUpsertConnectionLogs(t *testing.T) {
t.Parallel()
createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable {
t.Helper()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
@@ -3584,253 +3586,536 @@ func TestUpsertConnectionLog(t *testing.T) {
})
}
// zeroTime is the sentinel value that the SQL treats as "no
// connect/disconnect time provided".
zeroTime := time.Time{}
defaultIP := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
t.Run("SingleConnect", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
connectTime := dbtime.Now()
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{connectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime))
require.False(t, rows[0].ConnectionLog.DisconnectTime.Valid,
"disconnect_time should be NULL for a connect-only event")
})
t.Run("ConnectThenDisconnect", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connectionID := uuid.New()
agentName := "test-agent"
// 1. Insert a 'connect' event.
connID := uuid.New()
connectTime := dbtime.Now()
connectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
}
log1, err := db.UpsertConnectionLog(ctx, connectParams)
// Insert connect.
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{connectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
// Insert disconnect for same connection.
disconnectTime := connectTime.Add(time.Second)
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{zeroTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{1},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{"test disconnect"},
DisconnectTime: []time.Time{disconnectTime},
})
require.NoError(t, err)
require.Equal(t, connectParams.ID, log1.ID)
require.False(t, log1.DisconnectTime.Valid, "DisconnectTime should not be set on connect")
// Check that one row exists.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
// 2. Insert a 'disconnected' event for the same connection.
disconnectTime := connectTime.Add(time.Second)
disconnectParams := database.UpsertConnectionLogParams{
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
WorkspaceID: ws.ID,
AgentName: agentName,
ConnectionStatus: database.ConnectionStatusDisconnected,
// Updated to:
Time: disconnectTime,
DisconnectReason: sql.NullString{String: "test disconnect", Valid: true},
Code: sql.NullInt32{Int32: 1, Valid: true},
// Ignored
ID: uuid.New(),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceName: ws.Name,
Type: database.ConnectionTypeSsh,
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 254),
},
Valid: true,
},
}
log2, err := db.UpsertConnectionLog(ctx, disconnectParams)
require.NoError(t, err)
// Updated
require.Equal(t, log1.ID, log2.ID)
require.True(t, log2.DisconnectTime.Valid)
require.True(t, disconnectTime.Equal(log2.DisconnectTime.Time))
require.Equal(t, disconnectParams.DisconnectReason.String, log2.DisconnectReason.String)
rows, err = db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, rows, 1)
row := rows[0].ConnectionLog
require.True(t, connectTime.Equal(row.ConnectTime))
require.True(t, row.DisconnectTime.Valid)
require.True(t, disconnectTime.Equal(row.DisconnectTime.Time))
require.Equal(t, "test disconnect", row.DisconnectReason.String)
require.Equal(t, int32(1), row.Code.Int32)
})
t.Run("ConnectDoesNotUpdate", func(t *testing.T) {
t.Run("DuplicateConnectIsNoOp", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connectionID := uuid.New()
agentName := "test-agent"
// 1. Insert a 'connect' event.
connID := uuid.New()
connectTime := dbtime.Now()
connectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
mkParams := func(ct time.Time, ip pqtype.Inet) database.BatchUpsertConnectionLogsParams {
return database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{ct},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{ip},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
}
}
log, err := db.UpsertConnectionLog(ctx, connectParams)
err := db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime, defaultIP))
require.NoError(t, err)
// 2. Insert another 'connect' event for the same connection.
connectTime2 := connectTime.Add(time.Second)
connectParams2 := database.UpsertConnectionLogParams{
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
WorkspaceID: ws.ID,
AgentName: agentName,
ConnectionStatus: database.ConnectionStatusConnected,
rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows1, 1)
// Ignored
ID: uuid.New(),
Time: connectTime2,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceName: ws.Name,
Type: database.ConnectionTypeSsh,
Code: sql.NullInt32{Int32: 0, Valid: false},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 254),
},
Valid: true,
// Second connect with later time and different IP.
otherIP := pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(10, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
origLog, err := db.UpsertConnectionLog(ctx, connectParams2)
err = db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime.Add(time.Second), otherIP))
require.NoError(t, err)
require.Equal(t, log, origLog, "connect update should be a no-op")
// Check that still only one row exists.
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
require.Equal(t, log, rows[0].ConnectionLog)
require.Len(t, rows2, 1)
// The LEAST logic should pick the earlier connect_time; IP and
// other fields are not updated on conflict.
require.True(t, connectTime.Equal(rows2[0].ConnectionLog.ConnectTime),
"connect_time should remain the original (earlier) value")
})
t.Run("DisconnectThenConnect", func(t *testing.T) {
t.Run("OrderIndependentConnectTime", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connectionID := uuid.New()
agentName := "test-agent"
// Insert just a 'disconect' event
connID := uuid.New()
disconnectTime := dbtime.Now()
disconnectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: disconnectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
DisconnectReason: sql.NullString{String: "server shutting down", Valid: true},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
connectTime := disconnectTime.Add(-5 * time.Second)
// Disconnect arrives first.
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{disconnectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{"bye"},
DisconnectTime: []time.Time{disconnectTime},
})
require.NoError(t, err)
// Connect arrives second with the real (earlier) connect_time.
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{connectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime),
"LEAST should pick the earlier connect_time")
})
t.Run("DisconnectFieldsAreWriteOnce", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
disconnectTime := dbtime.Now()
mkDisconnect := func(reason string, code int32) database.BatchUpsertConnectionLogsParams {
return database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{disconnectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{code},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{reason},
DisconnectTime: []time.Time{disconnectTime},
}
}
_, err := db.UpsertConnectionLog(ctx, disconnectParams)
err := db.BatchUpsertConnectionLogs(ctx, mkDisconnect("first reason", 1))
require.NoError(t, err)
firstRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
// Second disconnect with different reason and code.
err = db.BatchUpsertConnectionLogs(ctx, mkDisconnect("second reason", 2))
require.NoError(t, err)
require.Len(t, firstRows, 1)
// We expect the connection event to be marked as closed with the start
// and close time being the same.
require.True(t, firstRows[0].ConnectionLog.DisconnectTime.Valid)
require.Equal(t, disconnectTime, firstRows[0].ConnectionLog.DisconnectTime.Time.UTC())
require.Equal(t, firstRows[0].ConnectionLog.ConnectTime.UTC(), firstRows[0].ConnectionLog.DisconnectTime.Time.UTC())
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
row := rows[0].ConnectionLog
require.Equal(t, "first reason", row.DisconnectReason.String,
"disconnect_reason should not be overwritten")
require.Equal(t, int32(1), row.Code.Int32,
"code should not be overwritten")
})
// Now insert a 'connect' event for the same connection.
// This should be a no op
connectTime := disconnectTime.Add(time.Second)
connectParams := database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
DisconnectReason: sql.NullString{String: "reconnected", Valid: true},
Code: sql.NullInt32{Int32: 0, Valid: false},
Ip: pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
},
t.Run("ConnectAfterDisconnectIsNoOp", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
disconnectTime := dbtime.Now()
// Insert disconnect first.
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{disconnectTime},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{42},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{"server shutdown"},
DisconnectTime: []time.Time{disconnectTime},
})
require.NoError(t, err)
rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows1, 1)
require.True(t, rows1[0].ConnectionLog.DisconnectTime.Valid)
require.Equal(t, "server shutdown", rows1[0].ConnectionLog.DisconnectReason.String)
require.Equal(t, int32(42), rows1[0].ConnectionLog.Code.Int32)
// Insert connect for same connection_id.
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{disconnectTime.Add(time.Second)},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows2, 1)
row := rows2[0].ConnectionLog
require.True(t, row.DisconnectTime.Valid,
"disconnect_time should not be cleared by a later connect")
require.Equal(t, "server shutdown", row.DisconnectReason.String,
"disconnect_reason should not be cleared")
require.Equal(t, int32(42), row.Code.Int32,
"code should not be cleared")
})
t.Run("CodeZeroPreserved", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
now := dbtime.Now()
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{now},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{0},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{"normal"},
DisconnectTime: []time.Time{now},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, rows[0].ConnectionLog.Code.Valid, "code should be non-NULL")
require.Equal(t, int32(0), rows[0].ConnectionLog.Code.Int32,
"code=0 should be preserved, not treated as NULL")
})
t.Run("CodeNullWhenInvalid", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
connID := uuid.New()
now := dbtime.Now()
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{now},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{99},
CodeValid: []bool{false},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{""},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{""},
ConnectionID: []uuid.UUID{connID},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 1)
require.False(t, rows[0].ConnectionLog.Code.Valid,
"code should be NULL when code_valid is false")
})
t.Run("NullConnectionIDEvents", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
now := dbtime.Now()
// Insert two web events with NULL connection_id (uuid.Nil →
// NULL via NULLIF) for the same workspace/agent.
for i := range 2 {
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: []uuid.UUID{uuid.New()},
ConnectTime: []time.Time{now.Add(time.Duration(i) * time.Second)},
OrganizationID: []uuid.UUID{ws.OrganizationID},
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
WorkspaceID: []uuid.UUID{ws.ID},
WorkspaceName: []string{ws.Name},
AgentName: []string{"agent"},
Type: []database.ConnectionType{database.ConnectionTypeSsh},
Code: []int32{200},
CodeValid: []bool{true},
Ip: []pqtype.Inet{defaultIP},
UserAgent: []string{"Mozilla/5.0"},
UserID: []uuid.UUID{uuid.Nil},
SlugOrPort: []string{"web-terminal"},
ConnectionID: []uuid.UUID{uuid.Nil},
DisconnectReason: []string{""},
DisconnectTime: []time.Time{zeroTime},
})
require.NoError(t, err)
}
_, err = db.UpsertConnectionLog(ctx, connectParams)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, rows, 2,
"NULL connection_id rows should not conflict with each other")
})
secondRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
require.NoError(t, err)
require.Len(t, secondRows, 1)
require.Equal(t, firstRows, secondRows)
t.Run("MultipleIndependentConnections", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := context.Background()
ws := createWorkspace(t, db)
now := dbtime.Now()
// Upsert a disconnection, which should also be a no op
disconnectParams.DisconnectReason = sql.NullString{
String: "updated close reason",
Valid: true,
n := 5
ids := make([]uuid.UUID, n)
connectTimes := make([]time.Time, n)
orgIDs := make([]uuid.UUID, n)
ownerIDs := make([]uuid.UUID, n)
wsIDs := make([]uuid.UUID, n)
wsNames := make([]string, n)
agentNames := make([]string, n)
types := make([]database.ConnectionType, n)
codes := make([]int32, n)
codeValids := make([]bool, n)
ips := make([]pqtype.Inet, n)
userAgents := make([]string, n)
userIDs := make([]uuid.UUID, n)
slugOrPorts := make([]string, n)
connIDs := make([]uuid.UUID, n)
disconnectReasons := make([]string, n)
disconnectTimes := make([]time.Time, n)
for i := range n {
ids[i] = uuid.New()
connectTimes[i] = now.Add(time.Duration(i) * time.Second)
orgIDs[i] = ws.OrganizationID
ownerIDs[i] = ws.OwnerID
wsIDs[i] = ws.ID
wsNames[i] = ws.Name
agentNames[i] = "agent"
types[i] = database.ConnectionTypeSsh
codes[i] = 0
codeValids[i] = false
ips[i] = defaultIP
userAgents[i] = ""
userIDs[i] = uuid.Nil
slugOrPorts[i] = ""
connIDs[i] = uuid.New()
disconnectReasons[i] = ""
disconnectTimes[i] = zeroTime
}
_, err = db.UpsertConnectionLog(ctx, disconnectParams)
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
ID: ids,
ConnectTime: connectTimes,
OrganizationID: orgIDs,
WorkspaceOwnerID: ownerIDs,
WorkspaceID: wsIDs,
WorkspaceName: wsNames,
AgentName: agentNames,
Type: types,
Code: codes,
CodeValid: codeValids,
Ip: ips,
UserAgent: userAgents,
UserID: userIDs,
SlugOrPort: slugOrPorts,
ConnectionID: connIDs,
DisconnectReason: disconnectReasons,
DisconnectTime: disconnectTimes,
})
require.NoError(t, err)
thirdRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
require.NoError(t, err)
require.Len(t, secondRows, 1)
// The close reason shouldn't be updated
require.Equal(t, secondRows, thirdRows)
require.Len(t, rows, n, "each unique connection_id should produce its own row")
})
}
+117 -114
View File
@@ -7338,6 +7338,123 @@ func (q *sqlQuerier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg U
return i, err
}
const batchUpsertConnectionLogs = `-- name: BatchUpsertConnectionLogs :exec
INSERT INTO connection_logs (
id, connect_time, organization_id, workspace_owner_id, workspace_id,
workspace_name, agent_name, type, code, ip, user_agent, user_id,
slug_or_port, connection_id, disconnect_reason, disconnect_time
)
SELECT
u.id,
u.connect_time,
u.organization_id,
u.workspace_owner_id,
u.workspace_id,
u.workspace_name,
u.agent_name,
u.type,
-- Use the validity flag to distinguish "no code" (NULL) from a
-- legitimate zero exit code.
CASE WHEN u.code_valid THEN u.code ELSE NULL END,
u.ip,
NULLIF(u.user_agent, ''),
NULLIF(u.user_id, '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(u.slug_or_port, ''),
NULLIF(u.connection_id, '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(u.disconnect_reason, ''),
NULLIF(u.disconnect_time, '0001-01-01 00:00:00Z'::timestamptz)
FROM (
SELECT
unnest($1::uuid[]) AS id,
unnest($2::timestamptz[]) AS connect_time,
unnest($3::uuid[]) AS organization_id,
unnest($4::uuid[]) AS workspace_owner_id,
unnest($5::uuid[]) AS workspace_id,
unnest($6::text[]) AS workspace_name,
unnest($7::text[]) AS agent_name,
unnest($8::connection_type[]) AS type,
unnest($9::int4[]) AS code,
unnest($10::bool[]) AS code_valid,
unnest($11::inet[]) AS ip,
unnest($12::text[]) AS user_agent,
unnest($13::uuid[]) AS user_id,
unnest($14::text[]) AS slug_or_port,
unnest($15::uuid[]) AS connection_id,
unnest($16::text[]) AS disconnect_reason,
unnest($17::timestamptz[]) AS disconnect_time
) AS u
ON CONFLICT (connection_id, workspace_id, agent_name)
DO UPDATE SET
-- Pick the earliest real connect_time. The zero sentinel
-- ('0001-01-01') means the batch didn't know the connect_time
-- (e.g. a pure disconnect event), so we keep the existing value.
connect_time = CASE
WHEN EXCLUDED.connect_time = '0001-01-01 00:00:00Z'::timestamptz
THEN connection_logs.connect_time
WHEN connection_logs.connect_time = '0001-01-01 00:00:00Z'::timestamptz
THEN EXCLUDED.connect_time
ELSE LEAST(connection_logs.connect_time, EXCLUDED.connect_time)
END,
disconnect_time = CASE
WHEN connection_logs.disconnect_time IS NULL
THEN EXCLUDED.disconnect_time
ELSE connection_logs.disconnect_time
END,
disconnect_reason = CASE
WHEN connection_logs.disconnect_reason IS NULL
THEN EXCLUDED.disconnect_reason
ELSE connection_logs.disconnect_reason
END,
code = CASE
WHEN connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END
`
type BatchUpsertConnectionLogsParams struct {
ID []uuid.UUID `db:"id" json:"id"`
ConnectTime []time.Time `db:"connect_time" json:"connect_time"`
OrganizationID []uuid.UUID `db:"organization_id" json:"organization_id"`
WorkspaceOwnerID []uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
WorkspaceID []uuid.UUID `db:"workspace_id" json:"workspace_id"`
WorkspaceName []string `db:"workspace_name" json:"workspace_name"`
AgentName []string `db:"agent_name" json:"agent_name"`
Type []ConnectionType `db:"type" json:"type"`
Code []int32 `db:"code" json:"code"`
CodeValid []bool `db:"code_valid" json:"code_valid"`
Ip []pqtype.Inet `db:"ip" json:"ip"`
UserAgent []string `db:"user_agent" json:"user_agent"`
UserID []uuid.UUID `db:"user_id" json:"user_id"`
SlugOrPort []string `db:"slug_or_port" json:"slug_or_port"`
ConnectionID []uuid.UUID `db:"connection_id" json:"connection_id"`
DisconnectReason []string `db:"disconnect_reason" json:"disconnect_reason"`
DisconnectTime []time.Time `db:"disconnect_time" json:"disconnect_time"`
}
func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error {
_, err := q.db.ExecContext(ctx, batchUpsertConnectionLogs,
pq.Array(arg.ID),
pq.Array(arg.ConnectTime),
pq.Array(arg.OrganizationID),
pq.Array(arg.WorkspaceOwnerID),
pq.Array(arg.WorkspaceID),
pq.Array(arg.WorkspaceName),
pq.Array(arg.AgentName),
pq.Array(arg.Type),
pq.Array(arg.Code),
pq.Array(arg.CodeValid),
pq.Array(arg.Ip),
pq.Array(arg.UserAgent),
pq.Array(arg.UserID),
pq.Array(arg.SlugOrPort),
pq.Array(arg.ConnectionID),
pq.Array(arg.DisconnectReason),
pq.Array(arg.DisconnectTime),
)
return err
}
const countConnectionLogs = `-- name: CountConnectionLogs :one
SELECT
COUNT(*) AS count
@@ -7753,120 +7870,6 @@ func (q *sqlQuerier) GetConnectionLogsOffset(ctx context.Context, arg GetConnect
return items, nil
}
const upsertConnectionLog = `-- name: UpsertConnectionLog :one
INSERT INTO connection_logs (
id,
connect_time,
organization_id,
workspace_owner_id,
workspace_id,
workspace_name,
agent_name,
type,
code,
ip,
user_agent,
user_id,
slug_or_port,
connection_id,
disconnect_reason,
disconnect_time
) VALUES
($1, $15, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
-- If we've only received a disconnect event, mark the event as immediately
-- closed.
CASE
WHEN $16::connection_status = 'disconnected'
THEN $15 :: timestamp with time zone
ELSE NULL
END)
ON CONFLICT (connection_id, workspace_id, agent_name)
DO UPDATE SET
-- No-op if the connection is still open.
disconnect_time = CASE
WHEN $16::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_time IS NULL
THEN EXCLUDED.connect_time
ELSE connection_logs.disconnect_time
END,
disconnect_reason = CASE
WHEN $16::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_reason IS NULL
THEN EXCLUDED.disconnect_reason
ELSE connection_logs.disconnect_reason
END,
code = CASE
WHEN $16::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END
RETURNING id, connect_time, organization_id, workspace_owner_id, workspace_id, workspace_name, agent_name, type, ip, code, user_agent, user_id, slug_or_port, connection_id, disconnect_time, disconnect_reason
`
type UpsertConnectionLogParams struct {
ID uuid.UUID `db:"id" json:"id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
AgentName string `db:"agent_name" json:"agent_name"`
Type ConnectionType `db:"type" json:"type"`
Code sql.NullInt32 `db:"code" json:"code"`
Ip pqtype.Inet `db:"ip" json:"ip"`
UserAgent sql.NullString `db:"user_agent" json:"user_agent"`
UserID uuid.NullUUID `db:"user_id" json:"user_id"`
SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"`
ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"`
DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"`
Time time.Time `db:"time" json:"time"`
ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"`
}
func (q *sqlQuerier) UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error) {
row := q.db.QueryRowContext(ctx, upsertConnectionLog,
arg.ID,
arg.OrganizationID,
arg.WorkspaceOwnerID,
arg.WorkspaceID,
arg.WorkspaceName,
arg.AgentName,
arg.Type,
arg.Code,
arg.Ip,
arg.UserAgent,
arg.UserID,
arg.SlugOrPort,
arg.ConnectionID,
arg.DisconnectReason,
arg.Time,
arg.ConnectionStatus,
)
var i ConnectionLog
err := row.Scan(
&i.ID,
&i.ConnectTime,
&i.OrganizationID,
&i.WorkspaceOwnerID,
&i.WorkspaceID,
&i.WorkspaceName,
&i.AgentName,
&i.Type,
&i.Ip,
&i.Code,
&i.UserAgent,
&i.UserID,
&i.SlugOrPort,
&i.ConnectionID,
&i.DisconnectTime,
&i.DisconnectReason,
)
return i, err
}
const deleteCryptoKey = `-- name: DeleteCryptoKey :one
UPDATE crypto_keys
SET secret = NULL, secret_key_id = NULL
+69 -49
View File
@@ -251,55 +251,75 @@ DELETE FROM connection_logs
USING old_logs
WHERE connection_logs.id = old_logs.id;
-- name: UpsertConnectionLog :one
-- name: BatchUpsertConnectionLogs :exec
INSERT INTO connection_logs (
id,
connect_time,
organization_id,
workspace_owner_id,
workspace_id,
workspace_name,
agent_name,
type,
code,
ip,
user_agent,
user_id,
slug_or_port,
connection_id,
disconnect_reason,
disconnect_time
) VALUES
($1, @time, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
-- If we've only received a disconnect event, mark the event as immediately
-- closed.
CASE
WHEN @connection_status::connection_status = 'disconnected'
THEN @time :: timestamp with time zone
ELSE NULL
END)
id, connect_time, organization_id, workspace_owner_id, workspace_id,
workspace_name, agent_name, type, code, ip, user_agent, user_id,
slug_or_port, connection_id, disconnect_reason, disconnect_time
)
SELECT
u.id,
u.connect_time,
u.organization_id,
u.workspace_owner_id,
u.workspace_id,
u.workspace_name,
u.agent_name,
u.type,
-- Use the validity flag to distinguish "no code" (NULL) from a
-- legitimate zero exit code.
CASE WHEN u.code_valid THEN u.code ELSE NULL END,
u.ip,
NULLIF(u.user_agent, ''),
NULLIF(u.user_id, '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(u.slug_or_port, ''),
NULLIF(u.connection_id, '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(u.disconnect_reason, ''),
NULLIF(u.disconnect_time, '0001-01-01 00:00:00Z'::timestamptz)
FROM (
SELECT
unnest(sqlc.arg('id')::uuid[]) AS id,
unnest(sqlc.arg('connect_time')::timestamptz[]) AS connect_time,
unnest(sqlc.arg('organization_id')::uuid[]) AS organization_id,
unnest(sqlc.arg('workspace_owner_id')::uuid[]) AS workspace_owner_id,
unnest(sqlc.arg('workspace_id')::uuid[]) AS workspace_id,
unnest(sqlc.arg('workspace_name')::text[]) AS workspace_name,
unnest(sqlc.arg('agent_name')::text[]) AS agent_name,
unnest(sqlc.arg('type')::connection_type[]) AS type,
unnest(sqlc.arg('code')::int4[]) AS code,
unnest(sqlc.arg('code_valid')::bool[]) AS code_valid,
unnest(sqlc.arg('ip')::inet[]) AS ip,
unnest(sqlc.arg('user_agent')::text[]) AS user_agent,
unnest(sqlc.arg('user_id')::uuid[]) AS user_id,
unnest(sqlc.arg('slug_or_port')::text[]) AS slug_or_port,
unnest(sqlc.arg('connection_id')::uuid[]) AS connection_id,
unnest(sqlc.arg('disconnect_reason')::text[]) AS disconnect_reason,
unnest(sqlc.arg('disconnect_time')::timestamptz[]) AS disconnect_time
) AS u
ON CONFLICT (connection_id, workspace_id, agent_name)
DO UPDATE SET
-- No-op if the connection is still open.
disconnect_time = CASE
WHEN @connection_status::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_time IS NULL
THEN EXCLUDED.connect_time
ELSE connection_logs.disconnect_time
END,
disconnect_reason = CASE
WHEN @connection_status::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.disconnect_reason IS NULL
THEN EXCLUDED.disconnect_reason
ELSE connection_logs.disconnect_reason
END,
code = CASE
WHEN @connection_status::connection_status = 'disconnected'
-- Can only be set once
AND connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END
RETURNING *;
-- Pick the earliest real connect_time. The zero sentinel
-- ('0001-01-01') means the batch didn't know the connect_time
-- (e.g. a pure disconnect event), so we keep the existing value.
connect_time = CASE
WHEN EXCLUDED.connect_time = '0001-01-01 00:00:00Z'::timestamptz
THEN connection_logs.connect_time
WHEN connection_logs.connect_time = '0001-01-01 00:00:00Z'::timestamptz
THEN EXCLUDED.connect_time
ELSE LEAST(connection_logs.connect_time, EXCLUDED.connect_time)
END,
disconnect_time = CASE
WHEN connection_logs.disconnect_time IS NULL
THEN EXCLUDED.disconnect_time
ELSE connection_logs.disconnect_time
END,
disconnect_reason = CASE
WHEN connection_logs.disconnect_reason IS NULL
THEN EXCLUDED.disconnect_reason
ELSE connection_logs.disconnect_reason
END,
code = CASE
WHEN connection_logs.code IS NULL
THEN EXCLUDED.code
ELSE connection_logs.code
END;
+1 -1
View File
@@ -535,7 +535,7 @@ func (p *DBTokenProvider) connLogInitRequest(w http.ResponseWriter, r *http.Requ
Int32: statusCode,
Valid: true,
},
Ip: database.ParseIP(ip),
IP: database.ParseIP(ip),
UserAgent: sql.NullString{Valid: userAgent != "", String: userAgent},
UserID: uuid.NullUUID{
UUID: userID,
+1 -1
View File
@@ -1281,7 +1281,7 @@ func assertConnLogContains(t *testing.T, rr *httptest.ResponseRecorder, r *http.
WorkspaceName: workspace.Name,
AgentName: agentName,
Type: typ,
Ip: database.ParseIP(r.RemoteAddr),
IP: database.ParseIP(r.RemoteAddr),
UserAgent: sql.NullString{Valid: r.UserAgent() != "", String: r.UserAgent()},
Code: sql.NullInt32{
Int32: int32(resp.StatusCode), // nolint:gosec
+10 -2
View File
@@ -5,6 +5,7 @@ import (
"crypto/ed25519"
"crypto/tls"
"fmt"
"io"
"math"
"net/http"
"net/url"
@@ -144,10 +145,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
}
if options.ConnectionLogger == nil {
options.ConnectionLogger = connectionlog.NewConnectionLogger(
connectionlog.NewDBBackend(options.Database),
connLogger := connectionlog.New(
connectionlog.NewDBBatcher(ctx, options.Database, options.Logger),
connectionlog.NewSlogBackend(options.Logger),
)
options.ConnectionLogger = connLogger
}
meshTLSConfig, err := replicasync.CreateDERPMeshTLSConfig(options.AccessURL.Hostname(), options.TLSCertificates)
@@ -822,6 +824,12 @@ func (api *API) Close() error {
api.Options.CheckInactiveUsersCancelFunc()
}
// Close the connection logger to flush any remaining batched
// entries before shutting down the database connection.
if cl, ok := api.Options.ConnectionLogger.(io.Closer); ok {
_ = cl.Close()
}
return api.AGPL.Close()
}
+474 -16
View File
@@ -2,31 +2,70 @@ package connectionlog
import (
"context"
"io"
"sync"
"time"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
"github.com/sqlc-dev/pqtype"
"cdr.dev/slog/v3"
agpl "github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
auditbackends "github.com/coder/coder/v2/enterprise/audit/backends"
"github.com/coder/quartz"
)
const (
// defaultBatchSize is the maximum number of connection log entries
// to batch before forcing a flush.
defaultBatchSize = 1000
// defaultFlushInterval is how frequently to flush batched connection
// log entries to the database. Five seconds balances near-real-time
// audit visibility with write efficiency.
defaultFlushInterval = 5 * time.Second
// retryQueueSize is the capacity of the bounded retry channel.
// Failed batches beyond this limit are dropped.
retryQueueSize = 10
// shutdownWriteTimeout bounds how long a final write attempt
// can take during shutdown when the batcher context is already
// canceled.
shutdownWriteTimeout = 10 * time.Second
// maxRetries is the number of times to retry a failed batch
// write before dropping it and moving on.
maxRetries = 3
// retryInterval is the fixed delay between retry attempts.
retryInterval = time.Second
)
// Backend is a destination for connection log events. Backends that
// also implement io.Closer will be closed when the ConnectionLogger
// is closed.
type Backend interface {
Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error
}
func NewConnectionLogger(backends ...Backend) agpl.ConnectionLogger {
return &connectionLogger{
// ConnectionLogger fans out each connection log event to every
// registered backend.
type ConnectionLogger struct {
backends []Backend
}
// New creates a ConnectionLogger that dispatches to the given
// backends.
func New(backends ...Backend) *ConnectionLogger {
return &ConnectionLogger{
backends: backends,
}
}
type connectionLogger struct {
backends []Backend
}
func (c *connectionLogger) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error {
func (c *ConnectionLogger) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error {
var errs error
for _, backend := range c.backends {
err := backend.Upsert(ctx, clog)
@@ -37,24 +76,443 @@ func (c *connectionLogger) Upsert(ctx context.Context, clog database.UpsertConne
return errs
}
type dbBackend struct {
db database.Store
// Close closes all backends that implement io.Closer.
func (c *ConnectionLogger) Close() error {
var errs error
for _, backend := range c.backends {
if closer, ok := backend.(io.Closer); ok {
if err := closer.Close(); err != nil {
errs = multierror.Append(errs, err)
}
}
}
return errs
}
func NewDBBackend(db database.Store) Backend {
return &dbBackend{db: db}
// DBBatcherOption is a functional option for configuring a DBBatcher.
type DBBatcherOption func(b *DBBatcher)
// WithBatchSize sets the maximum number of entries to accumulate
// before forcing a flush.
func WithBatchSize(size int) DBBatcherOption {
return func(b *DBBatcher) {
b.maxBatchSize = size
}
}
func (b *dbBackend) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error {
//nolint:gocritic // This is the Connection Logger
_, err := b.db.UpsertConnectionLog(dbauthz.AsConnectionLogger(ctx), clog)
return err
// WithFlushInterval sets how frequently the batcher flushes to the
// database.
func WithFlushInterval(d time.Duration) DBBatcherOption {
return func(b *DBBatcher) {
b.interval = d
}
}
// WithClock sets the clock, useful for testing.
func WithClock(clock quartz.Clock) DBBatcherOption {
return func(b *DBBatcher) {
b.clock = clock
}
}
// DBBatcher batches connection log upserts and periodically flushes
// them to the database to reduce per-event write pressure.
type DBBatcher struct {
store database.Store
log slog.Logger
itemCh chan database.UpsertConnectionLogParams
// dedupedBatch holds entries keyed by connection ID so that
// PostgreSQL never sees the same row twice in one INSERT …
// ON CONFLICT DO UPDATE. Connection IDs are globally unique
// (each new session gets a fresh UUID). Entries with a NULL
// connection_id (web events) go into nullConnIDBatch instead
// because NULL != NULL in SQL unique constraints.
dedupedBatch map[uuid.UUID]batchEntry
nullConnIDBatch []batchEntry
maxBatchSize int
// retryCh is a bounded channel of failed batches awaiting
// retry. A single retry worker goroutine processes this
// channel, retrying each batch up to maxRetries times before
// dropping it. If the channel is full, new failures are
// dropped immediately.
retryCh chan database.BatchUpsertConnectionLogsParams
clock quartz.Clock
timer *quartz.Timer
interval time.Duration
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewDBBatcher creates a DBBatcher that batches writes to the database
// and starts its background processing loop. Close must be called to
// flush remaining entries on shutdown.
func NewDBBatcher(ctx context.Context, store database.Store, log slog.Logger, opts ...DBBatcherOption) *DBBatcher {
b := &DBBatcher{
store: store,
log: log,
clock: quartz.NewReal(),
}
for _, opt := range opts {
opt(b)
}
if b.interval == 0 {
b.interval = defaultFlushInterval
}
if b.maxBatchSize == 0 {
b.maxBatchSize = defaultBatchSize
}
b.timer = b.clock.NewTimer(b.interval)
b.itemCh = make(chan database.UpsertConnectionLogParams, b.maxBatchSize)
b.dedupedBatch = make(map[uuid.UUID]batchEntry, b.maxBatchSize)
b.retryCh = make(chan database.BatchUpsertConnectionLogsParams, retryQueueSize)
b.ctx, b.cancel = context.WithCancel(ctx)
b.wg.Add(2)
go func() {
defer b.wg.Done()
b.run(b.ctx)
}()
go func() {
defer b.wg.Done()
b.retryLoop()
}()
return b
}
// Upsert enqueues a connection log entry for batched writing. It
// blocks if the internal buffer is full, ensuring no logs are dropped.
// It returns an error if the batcher or caller context is canceled.
func (b *DBBatcher) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error {
if b.ctx.Err() != nil {
return b.ctx.Err()
}
select {
case b.itemCh <- clog:
return nil
case <-b.ctx.Done():
return b.ctx.Err()
case <-ctx.Done():
return ctx.Err()
}
}
// Close cancels the batcher context, waits for the run loop and
// retry worker to exit.
func (b *DBBatcher) Close() error {
b.cancel()
if b.timer != nil {
b.timer.Stop()
}
b.wg.Wait()
return nil
}
// addToBatch inserts an item into the batch, deduplicating by conflict
// key on the fly. For entries with the same key, disconnect events are
// preferred over connect events, and later events are preferred over
// earlier ones.
//
// This is safe because each new connection gets a fresh UUID (see
// agent/agent.go and agent/agentssh), so the only duplicate for the
// same (connection_id, workspace_id, agent_name) is a connect/disconnect
// pair for the same session. A "reconnect" always uses a new ID.
func (b *DBBatcher) addToBatch(item database.UpsertConnectionLogParams) {
entry := batchEntry{
UpsertConnectionLogParams: item,
}
if item.ConnectionStatus == database.ConnectionStatusDisconnected {
// For standalone disconnect events, use the disconnect
// time as both connect and disconnect time. This matches
// the single-row UpsertConnectionLog behavior which uses
// @time for connect_time regardless of status. The SQL
// LEAST logic will correct connect_time if the real
// connect event arrives in a later batch.
entry.connectTime = item.Time
entry.disconnectTime = item.Time
} else {
entry.connectTime = item.Time
}
if !item.ConnectionID.Valid {
b.nullConnIDBatch = append(b.nullConnIDBatch, entry)
return
}
connID := item.ConnectionID.UUID
existing, ok := b.dedupedBatch[connID]
if !ok {
b.dedupedBatch[connID] = entry
return
}
// When merging entries for the same connection, always preserve
// the earliest non-zero connect_time and latest disconnect_time
// so the row records the full session span.
if !existing.connectTime.IsZero() && existing.connectTime.Before(entry.connectTime) {
entry.connectTime = existing.connectTime
}
if existing.disconnectTime.After(entry.disconnectTime) {
entry.disconnectTime = existing.disconnectTime
}
// Prefer disconnect over connect (superset of info).
// If same status, prefer the later event.
if item.ConnectionStatus == database.ConnectionStatusDisconnected &&
existing.ConnectionStatus != database.ConnectionStatusDisconnected {
b.dedupedBatch[connID] = entry
} else if item.Time.After(existing.Time) {
b.dedupedBatch[connID] = entry
}
}
// batchLen returns the total number of entries currently buffered.
func (b *DBBatcher) batchLen() int {
return len(b.dedupedBatch) + len(b.nullConnIDBatch)
}
func (b *DBBatcher) run(ctx context.Context) {
//nolint:gocritic // System-level batch operation for connection logs.
authCtx := dbauthz.AsConnectionLogger(ctx)
for ctx.Err() == nil {
select {
case item := <-b.itemCh:
b.addToBatch(item)
if b.batchLen() >= b.maxBatchSize {
b.flush(authCtx)
b.timer.Reset(b.interval, "connectionLogBatcher", "capacityFlush")
}
case <-b.timer.C:
b.flush(authCtx)
b.timer.Reset(b.interval, "connectionLogBatcher", "scheduledFlush")
case <-ctx.Done():
}
}
b.log.Debug(ctx, "context done, flushing before exit")
// Drain any remaining items from the channel.
for {
select {
case item := <-b.itemCh:
b.addToBatch(item)
default:
if b.batchLen() > 0 {
b.shutdownBatch(b.buildParams())
}
// Signal the retry worker to skip delays and close
// the channel so it exits after processing any
// remaining items.
// Mark the batcher as closed so that any subsequent
// Upsert calls fail immediately instead of sending
// into itemCh after the run loop has exited.
close(b.retryCh)
return
}
}
}
// batchEntry wraps a connection log event with explicit connect and
// disconnect times. When a connect and disconnect for the same session
// are merged into one entry, connectTime preserves the original
// session start while disconnectTime records when it ended.
type batchEntry struct {
database.UpsertConnectionLogParams
connectTime time.Time
disconnectTime time.Time
}
// flush builds the batch params, clears the in-memory batch, and
// writes to the database. On failure, the batch is queued for retry
// by the single retry worker goroutine. If the retry queue is full,
// the batch is dropped.
func (b *DBBatcher) flush(ctx context.Context) {
count := b.batchLen()
if count == 0 {
return
}
params := b.buildParams()
// Clear the batch before writing so the run loop can start
// accumulating new entries.
b.dedupedBatch = make(map[uuid.UUID]batchEntry, b.maxBatchSize)
b.nullConnIDBatch = nil
// Use the batcher's context for normal operation so Close()
// can cancel hung writes. During shutdown (ctx already canceled),
// fall back to a bounded timeout.
writeCtx := b.ctx
if writeCtx.Err() != nil {
var cancel context.CancelFunc
writeCtx, cancel = context.WithTimeout(context.Background(), shutdownWriteTimeout)
defer cancel()
}
//nolint:gocritic // System-level batch operation for connection logs.
err := b.store.BatchUpsertConnectionLogs(dbauthz.AsConnectionLogger(writeCtx), params)
if err == nil {
return
}
b.log.Error(ctx, "batch upsert failed, queueing for retry",
slog.Error(err), slog.F("count", count))
// Don't retry on shutdown.
if ctx.Err() != nil {
return
}
select {
case b.retryCh <- params:
default:
b.log.Error(ctx, "retry queue full, dropping batch",
slog.F("dropped", count))
}
}
func (b *DBBatcher) buildParams() database.BatchUpsertConnectionLogsParams {
count := b.batchLen()
var (
ids = make([]uuid.UUID, 0, count)
connectTime = make([]time.Time, 0, count)
organizationID = make([]uuid.UUID, 0, count)
workspaceOwnerID = make([]uuid.UUID, 0, count)
workspaceID = make([]uuid.UUID, 0, count)
workspaceName = make([]string, 0, count)
agentName = make([]string, 0, count)
connType = make([]database.ConnectionType, 0, count)
code = make([]int32, 0, count)
codeValid = make([]bool, 0, count)
ip = make([]pqtype.Inet, 0, count)
userAgent = make([]string, 0, count)
userID = make([]uuid.UUID, 0, count)
slugOrPort = make([]string, 0, count)
connectionID = make([]uuid.UUID, 0, count)
disconnectReason = make([]string, 0, count)
disconnectTime = make([]time.Time, 0, count)
)
appendEntry := func(e batchEntry) {
ids = append(ids, e.ID)
connectTime = append(connectTime, e.connectTime)
organizationID = append(organizationID, e.OrganizationID)
workspaceOwnerID = append(workspaceOwnerID, e.WorkspaceOwnerID)
workspaceID = append(workspaceID, e.WorkspaceID)
workspaceName = append(workspaceName, e.WorkspaceName)
agentName = append(agentName, e.AgentName)
connType = append(connType, e.Type)
code = append(code, e.Code.Int32)
codeValid = append(codeValid, e.Code.Valid)
ip = append(ip, e.IP)
userAgent = append(userAgent, e.UserAgent.String)
userID = append(userID, e.UserID.UUID)
slugOrPort = append(slugOrPort, e.SlugOrPort.String)
connectionID = append(connectionID, e.ConnectionID.UUID)
disconnectReason = append(disconnectReason, e.DisconnectReason.String)
disconnectTime = append(disconnectTime, e.disconnectTime)
}
for _, entry := range b.dedupedBatch {
appendEntry(entry)
}
for _, entry := range b.nullConnIDBatch {
appendEntry(entry)
}
return database.BatchUpsertConnectionLogsParams{
ID: ids,
ConnectTime: connectTime,
OrganizationID: organizationID,
WorkspaceOwnerID: workspaceOwnerID,
WorkspaceID: workspaceID,
WorkspaceName: workspaceName,
AgentName: agentName,
Type: connType,
Code: code,
CodeValid: codeValid,
Ip: ip,
UserAgent: userAgent,
UserID: userID,
SlugOrPort: slugOrPort,
ConnectionID: connectionID,
DisconnectReason: disconnectReason,
DisconnectTime: disconnectTime,
}
}
// retryLoop is a single background goroutine that processes failed
// batches from retryCh. Each batch is retried up to maxRetries times
// with a fixed delay between attempts. When draining is set (shutdown),
// batches get a single immediate write attempt instead. The loop exits
// when retryCh is closed by the run goroutine.
func (b *DBBatcher) retryLoop() {
for params := range b.retryCh {
b.retryBatch(params)
}
}
// retryBatch retries writing a batch up to maxRetries times with a
// fixed delay between attempts. If the batcher context is canceled
// during a wait, one final attempt is made before returning.
func (b *DBBatcher) retryBatch(params database.BatchUpsertConnectionLogsParams) {
count := len(params.ID)
for attempt := range maxRetries {
t := time.NewTimer(retryInterval)
select {
case <-b.ctx.Done():
b.shutdownBatch(params)
return
case <-t.C:
}
//nolint:gocritic // System-level batch operation for connection logs.
err := b.store.BatchUpsertConnectionLogs(dbauthz.AsConnectionLogger(b.ctx), params)
if err == nil {
return
}
b.log.Warn(b.ctx, "batch retry failed",
slog.Error(err),
slog.F("count", count),
slog.F("attempt", attempt+1),
slog.F("max_attempts", maxRetries),
)
}
b.log.Error(b.ctx, "batch retries exhausted, dropping batch",
slog.F("dropped", count))
}
// shutdownBatch makes a single write attempt during shutdown with a
// bounded timeout so it can't hang indefinitely.
func (b *DBBatcher) shutdownBatch(params database.BatchUpsertConnectionLogsParams) {
ctx, cancel := context.WithTimeout(context.Background(), shutdownWriteTimeout)
defer cancel()
//nolint:gocritic // System-level batch operation for connection logs.
err := b.store.BatchUpsertConnectionLogs(dbauthz.AsConnectionLogger(ctx), params)
if err != nil {
b.log.Error(b.ctx, "batch write failed on shutdown, dropping batch",
slog.Error(err), slog.F("dropped", len(params.ID)))
}
}
type connectionSlogBackend struct {
exporter *auditbackends.SlogExporter
}
// NewSlogBackend returns a Backend that logs connection events via
// the structured logger.
func NewSlogBackend(logger slog.Logger) Backend {
return &connectionSlogBackend{
exporter: auditbackends.NewSlogExporter(logger),
@@ -0,0 +1,529 @@
package connectionlog
import (
"context"
"database/sql"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func Test_addToBatch(t *testing.T) {
t.Parallel()
t.Run("ConnectThenDisconnect", func(t *testing.T) {
t.Parallel()
b := &DBBatcher{
maxBatchSize: 100,
dedupedBatch: make(map[uuid.UUID]batchEntry),
}
wsID := uuid.New()
connID := uuid.New()
connect := fakeConnectEvent(wsID, "agent1", connID)
disconnect := fakeDisconnectEvent(wsID, "agent1", connID)
b.addToBatch(connect)
b.addToBatch(disconnect)
require.Equal(t, 1, b.batchLen())
key := connID
got := b.dedupedBatch[key]
require.Equal(t, disconnect.ID, got.ID)
require.Equal(t, database.ConnectionStatusDisconnected, got.ConnectionStatus)
// The connect_time should be preserved from the original
// connect event, not overwritten by the disconnect's
// timestamp.
require.Equal(t, connect.Time, got.connectTime)
require.Equal(t, disconnect.Time, got.disconnectTime)
})
t.Run("DisconnectThenLaterConnect", func(t *testing.T) {
t.Parallel()
b := &DBBatcher{
maxBatchSize: 100,
dedupedBatch: make(map[uuid.UUID]batchEntry),
}
wsID := uuid.New()
connID := uuid.New()
disconnect := fakeDisconnectEvent(wsID, "agent1", connID)
connect := fakeConnectEvent(wsID, "agent1", connID)
connect.Time = disconnect.Time.Add(time.Second)
b.addToBatch(disconnect)
b.addToBatch(connect)
require.Equal(t, 1, b.batchLen())
key := connID
// The later event wins when the incoming item is not a
// disconnect. In practice, this case doesn't occur because
// connection IDs are never reused.
got := b.dedupedBatch[key]
require.Equal(t, connect.ID, got.ID)
// The disconnect's time should be preserved even though
// the connect event replaced it.
require.Equal(t, disconnect.Time, got.disconnectTime)
})
t.Run("DisconnectThenEarlierConnect", func(t *testing.T) {
t.Parallel()
b := &DBBatcher{
maxBatchSize: 100,
dedupedBatch: make(map[uuid.UUID]batchEntry),
}
wsID := uuid.New()
connID := uuid.New()
disconnect := fakeDisconnectEvent(wsID, "agent1", connID)
connect := fakeConnectEvent(wsID, "agent1", connID)
connect.Time = disconnect.Time.Add(-time.Second)
b.addToBatch(disconnect)
b.addToBatch(connect)
require.Equal(t, 1, b.batchLen())
key := connID
require.Equal(t, disconnect.ID, b.dedupedBatch[key].ID)
})
t.Run("SameStatusKeepsLater", func(t *testing.T) {
t.Parallel()
b := &DBBatcher{
maxBatchSize: 100,
dedupedBatch: make(map[uuid.UUID]batchEntry),
}
wsID := uuid.New()
connID := uuid.New()
early := fakeConnectEvent(wsID, "agent1", connID)
early.Time = time.Now()
late := fakeConnectEvent(wsID, "agent1", connID)
late.Time = early.Time.Add(time.Second)
b.addToBatch(early)
b.addToBatch(late)
require.Equal(t, 1, b.batchLen())
key := connID
require.Equal(t, late.ID, b.dedupedBatch[key].ID)
})
t.Run("NullConnIDsNeverDedup", func(t *testing.T) {
t.Parallel()
b := &DBBatcher{
maxBatchSize: 100,
dedupedBatch: make(map[uuid.UUID]batchEntry),
}
evt1 := fakeNullConnIDEvent()
evt2 := fakeNullConnIDEvent()
evt2.WorkspaceID = evt1.WorkspaceID
evt2.AgentName = evt1.AgentName
b.addToBatch(evt1)
b.addToBatch(evt2)
require.Equal(t, 2, b.batchLen())
require.Len(t, b.nullConnIDBatch, 2)
require.Empty(t, b.dedupedBatch)
})
t.Run("MixedNullAndNonNull", func(t *testing.T) {
t.Parallel()
b := &DBBatcher{
maxBatchSize: 100,
dedupedBatch: make(map[uuid.UUID]batchEntry),
}
wsID := uuid.New()
regular := fakeConnectEvent(wsID, "agent1", uuid.New())
nullEvt := fakeNullConnIDEvent()
nullEvt.WorkspaceID = wsID
nullEvt.AgentName = "agent1"
b.addToBatch(regular)
b.addToBatch(nullEvt)
require.Equal(t, 2, b.batchLen())
require.Len(t, b.dedupedBatch, 1)
require.Len(t, b.nullConnIDBatch, 1)
})
t.Run("StandaloneDisconnectUsesTimeAsConnectTime", func(t *testing.T) {
t.Parallel()
b := &DBBatcher{
maxBatchSize: 100,
dedupedBatch: make(map[uuid.UUID]batchEntry),
}
connID := uuid.New()
disconnect := fakeDisconnectEvent(uuid.New(), "agent1", connID)
b.addToBatch(disconnect)
got := b.dedupedBatch[connID]
// A standalone disconnect must not leave connectTime as
// zero — that would insert a year-0001 connect_time in
// the DB. It should use the disconnect's own timestamp,
// matching the single-row UpsertConnectionLog behavior.
require.False(t, got.connectTime.IsZero(),
"standalone disconnect must have non-zero connectTime")
require.Equal(t, disconnect.Time, got.connectTime)
require.Equal(t, disconnect.Time, got.disconnectTime)
})
t.Run("DuplicateDisconnectsPreserveConnectTime", func(t *testing.T) {
t.Parallel()
b := &DBBatcher{
maxBatchSize: 100,
dedupedBatch: make(map[uuid.UUID]batchEntry),
}
wsID := uuid.New()
connID := uuid.New()
connect := fakeConnectEvent(wsID, "agent1", connID)
disconnect1 := fakeDisconnectEvent(wsID, "agent1", connID)
disconnect2 := fakeDisconnectEvent(wsID, "agent1", connID)
disconnect2.Time = disconnect1.Time.Add(time.Second)
b.addToBatch(connect)
b.addToBatch(disconnect1)
b.addToBatch(disconnect2)
require.Equal(t, 1, b.batchLen())
got := b.dedupedBatch[connID]
// The second disconnect should win (later event) but the
// original connect_time from the connect event must be
// preserved, not regressed to the disconnect's timestamp.
require.Equal(t, disconnect2.ID, got.ID)
require.Equal(t, connect.Time, got.connectTime,
"connect_time must not regress to disconnect timestamp")
require.Equal(t, disconnect2.Time, got.disconnectTime)
})
}
func Test_batcherFlush(t *testing.T) {
t.Parallel()
t.Run("DeduplicatesConnectDisconnect", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
wsID := uuid.New()
connID := uuid.New()
connect := fakeConnectEvent(wsID, "agent1", connID)
disconnect := fakeDisconnectEvent(wsID, "agent1", connID)
// Expect a single batch with only the disconnect event.
store.EXPECT().
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
expectedCount: 1,
mustContainIDs: []uuid.UUID{disconnect.ID},
mustNotContainIDs: []uuid.UUID{connect.ID},
}).
Return(nil).
Times(1)
require.NoError(t, b.Upsert(ctx, connect))
require.NoError(t, b.Upsert(ctx, disconnect))
require.NoError(t, b.Close())
})
t.Run("DoesNotDeduplicateNullConnIDs", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
evt1 := fakeNullConnIDEvent()
evt2 := fakeNullConnIDEvent()
evt2.WorkspaceID = evt1.WorkspaceID
evt2.AgentName = evt1.AgentName
store.EXPECT().
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
expectedCount: 2,
mustContainIDs: []uuid.UUID{evt1.ID, evt2.ID},
}).
Return(nil).
Times(1)
require.NoError(t, b.Upsert(ctx, evt1))
require.NoError(t, b.Upsert(ctx, evt2))
require.NoError(t, b.Close())
})
t.Run("DoesNotDeduplicateDifferentConnectionIDs", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
wsID := uuid.New()
evt1 := fakeConnectEvent(wsID, "agent1", uuid.New())
evt2 := fakeConnectEvent(wsID, "agent1", uuid.New())
store.EXPECT().
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
expectedCount: 2,
mustContainIDs: []uuid.UUID{evt1.ID, evt2.ID},
}).
Return(nil).
Times(1)
require.NoError(t, b.Upsert(ctx, evt1))
require.NoError(t, b.Upsert(ctx, evt2))
require.NoError(t, b.Close())
})
t.Run("CloseFlushesMultipleEvents", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
evt1 := fakeConnectEvent(uuid.New(), "agent1", uuid.New())
evt2 := fakeConnectEvent(uuid.New(), "agent2", uuid.New())
store.EXPECT().
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
expectedCount: 2,
mustContainIDs: []uuid.UUID{evt1.ID, evt2.ID},
}).
Return(nil).
Times(1)
require.NoError(t, b.Upsert(ctx, evt1))
require.NoError(t, b.Upsert(ctx, evt2))
require.NoError(t, b.Close())
})
t.Run("RetriesOnTransientFailure", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
scheduledTrap := clock.Trap().TimerReset("connectionLogBatcher", "scheduledFlush")
defer scheduledTrap.Close()
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
evt := fakeConnectEvent(uuid.New(), "agent1", uuid.New())
// First call (synchronous in flush) fails, then the
// retry worker retries after the backoff and succeeds.
gomock.InOrder(
store.EXPECT().
BatchUpsertConnectionLogs(gomock.Any(), gomock.Any()).
Return(xerrors.New("transient error")).
Times(1),
store.EXPECT().
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
expectedCount: 1,
mustContainIDs: []uuid.UUID{evt.ID},
}).
Return(nil).
Times(1),
)
require.NoError(t, b.Upsert(ctx, evt))
// Trigger a scheduled flush while the batcher is still
// running. The synchronous write fails and queues to
// retryCh. The retry worker picks it up after a real-
// time 1s delay and succeeds.
clock.Advance(defaultFlushInterval).MustWait(ctx)
scheduledTrap.MustWait(ctx).MustRelease(ctx)
// Wait for the retry to complete (real-time 1s delay).
require.NoError(t, b.Close())
})
t.Run("ShutdownDrainsRetryQueue", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
scheduledTrap := clock.Trap().TimerReset("connectionLogBatcher", "scheduledFlush")
defer scheduledTrap.Close()
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
evt := fakeConnectEvent(uuid.New(), "agent1", uuid.New())
// Track all successfully written IDs.
var writtenIDs []uuid.UUID
var mu sync.Mutex
firstCall := true
store.EXPECT().
BatchUpsertConnectionLogs(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, p database.BatchUpsertConnectionLogsParams) error {
mu.Lock()
defer mu.Unlock()
// First call (synchronous flush) fails, queueing
// the batch for retry.
if firstCall {
firstCall = false
return xerrors.New("transient error")
}
// Drain/retry attempts succeed.
writtenIDs = append(writtenIDs, p.ID...)
return nil
}).
AnyTimes()
// Send event and trigger flush — fails, queues.
require.NoError(t, b.Upsert(ctx, evt))
clock.Advance(defaultFlushInterval).MustWait(ctx)
scheduledTrap.MustWait(ctx).MustRelease(ctx)
// Close triggers shutdown. The retry worker drains
// retryCh and writes the batch via writeBatch.
require.NoError(t, b.Close())
mu.Lock()
defer mu.Unlock()
require.Contains(t, writtenIDs, evt.ID,
"event should be written during shutdown drain")
})
}
// batchParamsMatcher validates BatchUpsertConnectionLogsParams by
// checking count and specific IDs.
type batchParamsMatcher struct {
expectedCount int
mustContainIDs []uuid.UUID
mustNotContainIDs []uuid.UUID
}
func (m batchParamsMatcher) Matches(x interface{}) bool {
params, ok := x.(database.BatchUpsertConnectionLogsParams)
if !ok {
return false
}
if m.expectedCount > 0 && len(params.ID) != m.expectedCount {
return false
}
idSet := make(map[uuid.UUID]struct{}, len(params.ID))
for _, id := range params.ID {
idSet[id] = struct{}{}
}
for _, id := range m.mustContainIDs {
if _, ok := idSet[id]; !ok {
return false
}
}
for _, id := range m.mustNotContainIDs {
if _, ok := idSet[id]; ok {
return false
}
}
return true
}
func (batchParamsMatcher) String() string {
return "batch upsert params matcher"
}
func fakeConnectEvent(workspaceID uuid.UUID, agentName string, connectionID uuid.UUID) database.UpsertConnectionLogParams {
return database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: time.Now(),
OrganizationID: uuid.New(),
WorkspaceOwnerID: uuid.New(),
WorkspaceID: workspaceID,
WorkspaceName: "test-workspace",
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
}
}
func fakeDisconnectEvent(workspaceID uuid.UUID, agentName string, connectionID uuid.UUID) database.UpsertConnectionLogParams {
return database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: time.Now().Add(time.Second),
OrganizationID: uuid.New(),
WorkspaceOwnerID: uuid.New(),
WorkspaceID: workspaceID,
WorkspaceName: "test-workspace",
AgentName: agentName,
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
Code: sql.NullInt32{Int32: 0, Valid: true},
DisconnectReason: sql.NullString{String: "normal", Valid: true},
}
}
func fakeNullConnIDEvent() database.UpsertConnectionLogParams {
return database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: time.Now(),
OrganizationID: uuid.New(),
WorkspaceOwnerID: uuid.New(),
WorkspaceID: uuid.New(),
WorkspaceName: "test-workspace",
AgentName: "test-agent",
Type: database.ConnectionTypeWorkspaceApp,
ConnectionID: uuid.NullUUID{},
ConnectionStatus: database.ConnectionStatusConnected,
}
}
@@ -0,0 +1,371 @@
package connectionlog_test
import (
"database/sql"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/enterprise/coderd/connectionlog"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func createWorkspace(t *testing.T, db database.Store) database.WorkspaceTable {
t.Helper()
u := dbgen.User(t, db, database.User{})
o := dbgen.Organization(t, db, database.Organization{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
return dbgen.Workspace(t, db, database.WorkspaceTable{
ID: uuid.New(),
OwnerID: u.ID,
OrganizationID: o.ID,
AutomaticUpdates: database.AutomaticUpdatesNever,
TemplateID: tpl.ID,
})
}
func testIP() pqtype.Inet {
return pqtype.Inet{
IPNet: net.IPNet{
IP: net.IPv4(127, 0, 0, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
Valid: true,
}
}
func TestDBBackendIntegration(t *testing.T) {
t.Parallel()
t.Run("SingleConnect", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
ws := createWorkspace(t, db)
//nolint:gocritic // Test needs system context for the batcher.
backend := connectionlog.NewDBBatcher(
dbauthz.AsConnectionLogger(ctx), db, log,
connectionlog.WithClock(clock),
connectionlog.WithBatchSize(100),
)
connID := uuid.New()
connectTime := dbtime.Now()
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "main",
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
IP: testIP(),
})
require.NoError(t, err)
err = backend.Close()
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
LimitOpt: 10,
})
require.NoError(t, err)
require.Len(t, rows, 1)
require.Equal(t, connID, rows[0].ConnectionLog.ConnectionID.UUID)
require.False(t, rows[0].ConnectionLog.DisconnectTime.Valid)
})
t.Run("ConnectThenDisconnectSeparateBatches", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
ws := createWorkspace(t, db)
connID := uuid.New()
connectTime := dbtime.Now()
// First batcher: insert connect, close to flush.
//nolint:gocritic // Test needs system context for the batcher.
b1 := connectionlog.NewDBBatcher(
dbauthz.AsConnectionLogger(ctx), db, log,
connectionlog.WithClock(clock),
connectionlog.WithBatchSize(100),
)
err := b1.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "main",
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
IP: testIP(),
})
require.NoError(t, err)
require.NoError(t, b1.Close())
// Second batcher: insert disconnect, close to flush.
//nolint:gocritic // Test needs system context for the batcher.
b2 := connectionlog.NewDBBatcher(
dbauthz.AsConnectionLogger(ctx), db, log,
connectionlog.WithClock(clock),
connectionlog.WithBatchSize(100),
)
disconnectTime := connectTime.Add(5 * time.Second)
err = b2.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: disconnectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "main",
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
Code: sql.NullInt32{Int32: 0, Valid: true},
DisconnectReason: sql.NullString{String: "client left", Valid: true},
IP: testIP(),
})
require.NoError(t, err)
require.NoError(t, b2.Close())
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
LimitOpt: 10,
})
require.NoError(t, err)
require.Len(t, rows, 1, "connect+disconnect should produce one row")
require.True(t, rows[0].ConnectionLog.DisconnectTime.Valid)
require.Equal(t, "client left", rows[0].ConnectionLog.DisconnectReason.String)
})
t.Run("ConnectAndDisconnectSameBatch", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
ws := createWorkspace(t, db)
//nolint:gocritic // Test needs system context for the batcher.
backend := connectionlog.NewDBBatcher(
dbauthz.AsConnectionLogger(ctx), db, log,
connectionlog.WithClock(clock),
connectionlog.WithBatchSize(100),
)
connID := uuid.New()
connectTime := dbtime.Now()
disconnectTime := connectTime.Add(time.Second)
// Both events in the same batch window.
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: connectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "main",
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
IP: testIP(),
})
require.NoError(t, err)
err = backend.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: disconnectTime,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "main",
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
ConnectionStatus: database.ConnectionStatusDisconnected,
Code: sql.NullInt32{Int32: 0, Valid: true},
DisconnectReason: sql.NullString{String: "done", Valid: true},
IP: testIP(),
})
require.NoError(t, err)
// Close drains channel and flushes — dedup keeps disconnect.
err = backend.Close()
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
LimitOpt: 10,
})
require.NoError(t, err)
require.Len(t, rows, 1)
require.True(t, rows[0].ConnectionLog.DisconnectTime.Valid)
require.Equal(t, "done", rows[0].ConnectionLog.DisconnectReason.String)
})
t.Run("MultipleIndependentConnections", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
ws := createWorkspace(t, db)
//nolint:gocritic // Test needs system context for the batcher.
backend := connectionlog.NewDBBatcher(
dbauthz.AsConnectionLogger(ctx), db, log,
connectionlog.WithClock(clock),
connectionlog.WithBatchSize(100),
)
now := dbtime.Now()
for i := 0; i < 5; i++ {
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "main",
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
IP: testIP(),
})
require.NoError(t, err)
}
err := backend.Close()
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
LimitOpt: 10,
})
require.NoError(t, err)
require.Len(t, rows, 5)
})
t.Run("NullConnectionIDWebEvents", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
ws := createWorkspace(t, db)
//nolint:gocritic // Test needs system context for the batcher.
backend := connectionlog.NewDBBatcher(
dbauthz.AsConnectionLogger(ctx), db, log,
connectionlog.WithClock(clock),
connectionlog.WithBatchSize(100),
)
now := dbtime.Now()
for i := 0; i < 2; i++ {
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: now,
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "main",
Type: database.ConnectionTypeWorkspaceApp,
ConnectionID: uuid.NullUUID{},
ConnectionStatus: database.ConnectionStatusConnected,
IP: testIP(),
})
require.NoError(t, err)
}
err := backend.Close()
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
LimitOpt: 10,
})
require.NoError(t, err)
require.Len(t, rows, 2, "null connection_id events should not be deduplicated")
})
t.Run("CloseFlushesToDB", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
ws := createWorkspace(t, db)
//nolint:gocritic // Test needs system context for the batcher.
backend := connectionlog.NewDBBatcher(
dbauthz.AsConnectionLogger(ctx), db, log,
connectionlog.WithClock(clock),
connectionlog.WithBatchSize(100),
)
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
ID: uuid.New(),
Time: dbtime.Now(),
OrganizationID: ws.OrganizationID,
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: "main",
Type: database.ConnectionTypeSsh,
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
ConnectionStatus: database.ConnectionStatusConnected,
IP: testIP(),
})
require.NoError(t, err)
// Close without advancing clock — final flush should write.
err = backend.Close()
require.NoError(t, err)
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
LimitOpt: 10,
})
require.NoError(t, err)
require.Len(t, rows, 1)
})
}
+1 -1
View File
@@ -227,7 +227,7 @@ func TestConnectionLogs(t *testing.T) {
Int32: 0,
Valid: false,
},
Ip: pqtype.Inet{IPNet: net.IPNet{
IP: pqtype.Inet{IPNet: net.IPNet{
IP: net.ParseIP("192.168.0.1"),
Mask: net.CIDRMask(8, 32),
}, Valid: true},
+3 -3
View File
@@ -784,7 +784,7 @@ func TestIssueSignedAppToken(t *testing.T) {
require.NoError(t, err)
require.True(t, connectionLogger.Contains(t, database.UpsertConnectionLogParams{
Ip: parsedFakeClientIP,
IP: parsedFakeClientIP,
}))
})
@@ -812,7 +812,7 @@ func TestIssueSignedAppToken(t *testing.T) {
}
require.True(t, connectionLogger.Contains(t, database.UpsertConnectionLogParams{
Ip: parsedFakeClientIP,
IP: parsedFakeClientIP,
}))
})
}
@@ -1020,7 +1020,7 @@ func TestReconnectingPTYSignedToken(t *testing.T) {
// validate it here.
require.True(t, connectionLogger.Contains(t, database.UpsertConnectionLogParams{
Ip: pqtype.Inet{
IP: pqtype.Inet{
Valid: true, IPNet: net.IPNet{
IP: net.ParseIP("127.0.0.1"),
Mask: net.CIDRMask(32, 32),