feat: batch connection logs to avoid DB lock contention (#23727)
- Running 30k connections was generating a ton of lock contention in the DB
This commit is contained in:
@@ -85,7 +85,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
|
||||
AgentName: a.AgentName,
|
||||
Type: connectionType,
|
||||
Code: code,
|
||||
Ip: logIP,
|
||||
IP: logIP,
|
||||
ConnectionID: uuid.NullUUID{
|
||||
UUID: connectionID,
|
||||
Valid: true,
|
||||
|
||||
@@ -152,7 +152,7 @@ func TestConnectionLog(t *testing.T) {
|
||||
Int32: tt.status,
|
||||
Valid: *tt.action == agentproto.Connection_DISCONNECT,
|
||||
},
|
||||
Ip: expectedIP,
|
||||
IP: expectedIP,
|
||||
Type: agentProtoConnectionTypeToConnectionLog(t, *tt.typ),
|
||||
DisconnectReason: sql.NullString{
|
||||
String: tt.reason,
|
||||
|
||||
@@ -90,8 +90,8 @@ func (m *FakeConnectionLogger) Contains(t testing.TB, expected database.UpsertCo
|
||||
t.Logf("connection log %d: expected Code %d, got %d", idx+1, expected.Code.Int32, cl.Code.Int32)
|
||||
continue
|
||||
}
|
||||
if expected.Ip.Valid && cl.Ip.IPNet.String() != expected.Ip.IPNet.String() {
|
||||
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.Ip.IPNet, cl.Ip.IPNet)
|
||||
if expected.IP.Valid && cl.IP.IPNet.String() != expected.IP.IPNet.String() {
|
||||
t.Logf("connection log %d: expected IP %s, got %s", idx+1, expected.IP.IPNet, cl.IP.IPNet)
|
||||
continue
|
||||
}
|
||||
if expected.UserAgent.Valid && cl.UserAgent.String != expected.UserAgent.String {
|
||||
|
||||
@@ -1627,6 +1627,13 @@ func (q *querier) BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg datab
|
||||
return q.db.BatchUpdateWorkspaceNextStartAt(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.BatchUpsertConnectionLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceNotificationMessage); err != nil {
|
||||
return 0, err
|
||||
@@ -7065,13 +7072,6 @@ func (q *querier) UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl strin
|
||||
return q.db.UpsertChatWorkspaceTTL(ctx, workspaceTtl)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return database.ConnectionLog{}, err
|
||||
}
|
||||
return q.db.UpsertConnectionLog(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
|
||||
@@ -338,10 +338,9 @@ func (s *MethodTestSuite) TestAuditLogs() {
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestConnectionLogs() {
|
||||
s.Run("UpsertConnectionLog", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ws := testutil.Fake(s.T(), faker, database.WorkspaceTable{})
|
||||
arg := database.UpsertConnectionLogParams{Ip: defaultIPAddress(), Type: database.ConnectionTypeSsh, WorkspaceID: ws.ID, OrganizationID: ws.OrganizationID, ConnectionStatus: database.ConnectionStatusConnected, WorkspaceOwnerID: ws.OwnerID}
|
||||
dbm.EXPECT().UpsertConnectionLog(gomock.Any(), arg).Return(database.ConnectionLog{}, nil).AnyTimes()
|
||||
s.Run("BatchUpsertConnectionLogs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.BatchUpsertConnectionLogsParams{}
|
||||
dbm.EXPECT().BatchUpsertConnectionLogs(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceConnectionLog, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetConnectionLogsOffset", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
|
||||
@@ -76,7 +76,7 @@ func AuditLog(t testing.TB, db database.Store, seed database.AuditLog) database.
|
||||
}
|
||||
|
||||
func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog {
|
||||
log, err := db.UpsertConnectionLog(genCtx, database.UpsertConnectionLogParams{
|
||||
arg := database.UpsertConnectionLogParams{
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
Time: takeFirst(seed.Time, dbtime.Now()),
|
||||
OrganizationID: takeFirst(seed.OrganizationID, uuid.New()),
|
||||
@@ -89,7 +89,7 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
|
||||
Int32: takeFirst(seed.Code.Int32, 0),
|
||||
Valid: takeFirst(seed.Code.Valid, false),
|
||||
},
|
||||
Ip: pqtype.Inet{
|
||||
IP: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||
@@ -117,9 +117,53 @@ func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnecti
|
||||
Valid: takeFirst(seed.DisconnectReason.Valid, false),
|
||||
},
|
||||
ConnectionStatus: takeFirst(seed.ConnectionStatus, database.ConnectionStatusConnected),
|
||||
}
|
||||
|
||||
var disconnectTime sql.NullTime
|
||||
if arg.ConnectionStatus == database.ConnectionStatusDisconnected {
|
||||
disconnectTime = sql.NullTime{Time: arg.Time, Valid: true}
|
||||
}
|
||||
|
||||
err := db.BatchUpsertConnectionLogs(genCtx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{arg.ID},
|
||||
ConnectTime: []time.Time{arg.Time},
|
||||
OrganizationID: []uuid.UUID{arg.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{arg.WorkspaceOwnerID},
|
||||
WorkspaceID: []uuid.UUID{arg.WorkspaceID},
|
||||
WorkspaceName: []string{arg.WorkspaceName},
|
||||
AgentName: []string{arg.AgentName},
|
||||
Type: []database.ConnectionType{arg.Type},
|
||||
Code: []int32{arg.Code.Int32},
|
||||
CodeValid: []bool{arg.Code.Valid},
|
||||
Ip: []pqtype.Inet{arg.IP},
|
||||
UserAgent: []string{arg.UserAgent.String},
|
||||
UserID: []uuid.UUID{arg.UserID.UUID},
|
||||
SlugOrPort: []string{arg.SlugOrPort.String},
|
||||
ConnectionID: []uuid.UUID{arg.ConnectionID.UUID},
|
||||
DisconnectReason: []string{arg.DisconnectReason.String},
|
||||
DisconnectTime: []time.Time{disconnectTime.Time},
|
||||
})
|
||||
require.NoError(t, err, "insert connection log")
|
||||
return log
|
||||
|
||||
// Query back the actual row from the database. On upsert
|
||||
// conflict the DB keeps the original row's ID, so we can't
|
||||
// rely on arg.ID. Match on the conflict key for rows with a
|
||||
// connection_id, or by primary key for NULL connection_id.
|
||||
rows, err := db.GetConnectionLogsOffset(genCtx, database.GetConnectionLogsOffsetParams{})
|
||||
require.NoError(t, err, "query connection logs")
|
||||
for _, row := range rows {
|
||||
if arg.ConnectionID.Valid {
|
||||
if row.ConnectionLog.ConnectionID == arg.ConnectionID &&
|
||||
row.ConnectionLog.WorkspaceID == arg.WorkspaceID &&
|
||||
row.ConnectionLog.AgentName == arg.AgentName {
|
||||
return row.ConnectionLog
|
||||
}
|
||||
} else if row.ConnectionLog.ID == arg.ID {
|
||||
return row.ConnectionLog
|
||||
}
|
||||
}
|
||||
require.Failf(t, "connection log not found", "id=%s", arg.ID)
|
||||
return database.ConnectionLog{} // unreachable
|
||||
}
|
||||
|
||||
func Template(t testing.TB, db database.Store, seed database.Template) database.Template {
|
||||
|
||||
@@ -208,6 +208,14 @@ func (m queryMetricsStore) BatchUpdateWorkspaceNextStartAt(ctx context.Context,
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.BatchUpsertConnectionLogs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("BatchUpsertConnectionLogs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BatchUpsertConnectionLogs").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.BulkMarkNotificationMessagesFailed(ctx, arg)
|
||||
@@ -5024,14 +5032,6 @@ func (m queryMetricsStore) UpsertChatWorkspaceTTL(ctx context.Context, workspace
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertConnectionLog").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertConnectionLog").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertDefaultProxy(ctx, arg)
|
||||
|
||||
@@ -233,6 +233,20 @@ func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceNextStartAt(ctx, arg any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceNextStartAt", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceNextStartAt), ctx, arg)
|
||||
}
|
||||
|
||||
// BatchUpsertConnectionLogs mocks base method.
|
||||
func (m *MockStore) BatchUpsertConnectionLogs(ctx context.Context, arg database.BatchUpsertConnectionLogsParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BatchUpsertConnectionLogs", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BatchUpsertConnectionLogs indicates an expected call of BatchUpsertConnectionLogs.
|
||||
func (mr *MockStoreMockRecorder) BatchUpsertConnectionLogs(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpsertConnectionLogs", reflect.TypeOf((*MockStore)(nil).BatchUpsertConnectionLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// BulkMarkNotificationMessagesFailed mocks base method.
|
||||
func (m *MockStore) BulkMarkNotificationMessagesFailed(ctx context.Context, arg database.BulkMarkNotificationMessagesFailedParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9442,21 +9456,6 @@ func (mr *MockStoreMockRecorder) UpsertChatWorkspaceTTL(ctx, workspaceTtl any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatWorkspaceTTL", reflect.TypeOf((*MockStore)(nil).UpsertChatWorkspaceTTL), ctx, workspaceTtl)
|
||||
}
|
||||
|
||||
// UpsertConnectionLog mocks base method.
|
||||
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertConnectionLog", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ConnectionLog)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertConnectionLog indicates an expected call of UpsertConnectionLog.
|
||||
func (mr *MockStoreMockRecorder) UpsertConnectionLog(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertConnectionLog", reflect.TypeOf((*MockStore)(nil).UpsertConnectionLog), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertDefaultProxy mocks base method.
|
||||
func (m *MockStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -923,3 +924,28 @@ func WorkspaceIdentityFromWorkspace(w Workspace) WorkspaceIdentity {
|
||||
func (r GetWorkspaceAgentAndWorkspaceByIDRow) RBACObject() rbac.Object {
|
||||
return r.WorkspaceTable.RBACObject()
|
||||
}
|
||||
|
||||
// UpsertConnectionLogParams contains the parameters for upserting a
|
||||
// connection log entry. This struct is hand-maintained (not generated
|
||||
// by sqlc) because the single-row UpsertConnectionLog query was
|
||||
// removed in favor of BatchUpsertConnectionLogs, but the struct is
|
||||
// still used as the canonical connection log event type throughout
|
||||
// the codebase.
|
||||
type UpsertConnectionLogParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
|
||||
AgentName string `db:"agent_name" json:"agent_name"`
|
||||
Type ConnectionType `db:"type" json:"type"`
|
||||
Code sql.NullInt32 `db:"code" json:"code"`
|
||||
IP pqtype.Inet `db:"ip" json:"ip"`
|
||||
UserAgent sql.NullString `db:"user_agent" json:"user_agent"`
|
||||
UserID uuid.NullUUID `db:"user_id" json:"user_id"`
|
||||
SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"`
|
||||
ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"`
|
||||
DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"`
|
||||
Time time.Time `db:"time" json:"time"`
|
||||
ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"`
|
||||
}
|
||||
|
||||
@@ -65,6 +65,7 @@ type sqlcQuerier interface {
|
||||
BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
|
||||
BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
|
||||
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
|
||||
BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error
|
||||
BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error)
|
||||
BulkMarkNotificationMessagesSent(ctx context.Context, arg BulkMarkNotificationMessagesSentParams) (int64, error)
|
||||
// Calculates the telemetry summary for a given provider, model, and client
|
||||
@@ -991,7 +992,6 @@ type sqlcQuerier interface {
|
||||
UpsertChatUsageLimitGroupOverride(ctx context.Context, arg UpsertChatUsageLimitGroupOverrideParams) (UpsertChatUsageLimitGroupOverrideRow, error)
|
||||
UpsertChatUsageLimitUserOverride(ctx context.Context, arg UpsertChatUsageLimitUserOverrideParams) (UpsertChatUsageLimitUserOverrideRow, error)
|
||||
UpsertChatWorkspaceTTL(ctx context.Context, workspaceTtl string) error
|
||||
UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error)
|
||||
// The default proxy is implied and not actually stored in the database.
|
||||
// So we need to store it's configuration here for display purposes.
|
||||
// The functional values are immutable and controlled implicitly.
|
||||
|
||||
+482
-197
@@ -3566,9 +3566,11 @@ func connectionOnlyIDs[T database.ConnectionLog | database.GetConnectionLogsOffs
|
||||
return ids
|
||||
}
|
||||
|
||||
func TestUpsertConnectionLog(t *testing.T) {
|
||||
func TestBatchUpsertConnectionLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
createWorkspace := func(t *testing.T, db database.Store) database.WorkspaceTable {
|
||||
t.Helper()
|
||||
u := dbgen.User(t, db, database.User{})
|
||||
o := dbgen.Organization(t, db, database.Organization{})
|
||||
tpl := dbgen.Template(t, db, database.Template{
|
||||
@@ -3584,253 +3586,536 @@ func TestUpsertConnectionLog(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// zeroTime is the sentinel value that the SQL treats as "no
|
||||
// connect/disconnect time provided".
|
||||
zeroTime := time.Time{}
|
||||
|
||||
defaultIP := pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||
},
|
||||
Valid: true,
|
||||
}
|
||||
|
||||
t.Run("SingleConnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
ws := createWorkspace(t, db)
|
||||
connID := uuid.New()
|
||||
connectTime := dbtime.Now()
|
||||
|
||||
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{connectTime},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{0},
|
||||
CodeValid: []bool{false},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{""},
|
||||
DisconnectTime: []time.Time{zeroTime},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime))
|
||||
require.False(t, rows[0].ConnectionLog.DisconnectTime.Valid,
|
||||
"disconnect_time should be NULL for a connect-only event")
|
||||
})
|
||||
|
||||
t.Run("ConnectThenDisconnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
ws := createWorkspace(t, db)
|
||||
|
||||
connectionID := uuid.New()
|
||||
agentName := "test-agent"
|
||||
|
||||
// 1. Insert a 'connect' event.
|
||||
connID := uuid.New()
|
||||
connectTime := dbtime.Now()
|
||||
connectParams := database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: connectTime,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: agentName,
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
Ip: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
log1, err := db.UpsertConnectionLog(ctx, connectParams)
|
||||
// Insert connect.
|
||||
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{connectTime},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{0},
|
||||
CodeValid: []bool{false},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{""},
|
||||
DisconnectTime: []time.Time{zeroTime},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert disconnect for same connection.
|
||||
disconnectTime := connectTime.Add(time.Second)
|
||||
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{zeroTime},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{1},
|
||||
CodeValid: []bool{true},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{"test disconnect"},
|
||||
DisconnectTime: []time.Time{disconnectTime},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, connectParams.ID, log1.ID)
|
||||
require.False(t, log1.DisconnectTime.Valid, "DisconnectTime should not be set on connect")
|
||||
|
||||
// Check that one row exists.
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
|
||||
// 2. Insert a 'disconnected' event for the same connection.
|
||||
disconnectTime := connectTime.Add(time.Second)
|
||||
disconnectParams := database.UpsertConnectionLogParams{
|
||||
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
|
||||
WorkspaceID: ws.ID,
|
||||
AgentName: agentName,
|
||||
ConnectionStatus: database.ConnectionStatusDisconnected,
|
||||
|
||||
// Updated to:
|
||||
Time: disconnectTime,
|
||||
DisconnectReason: sql.NullString{String: "test disconnect", Valid: true},
|
||||
Code: sql.NullInt32{Int32: 1, Valid: true},
|
||||
|
||||
// Ignored
|
||||
ID: uuid.New(),
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceName: ws.Name,
|
||||
Type: database.ConnectionTypeSsh,
|
||||
Ip: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 254),
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
log2, err := db.UpsertConnectionLog(ctx, disconnectParams)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Updated
|
||||
require.Equal(t, log1.ID, log2.ID)
|
||||
require.True(t, log2.DisconnectTime.Valid)
|
||||
require.True(t, disconnectTime.Equal(log2.DisconnectTime.Time))
|
||||
require.Equal(t, disconnectParams.DisconnectReason.String, log2.DisconnectReason.String)
|
||||
|
||||
rows, err = db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
row := rows[0].ConnectionLog
|
||||
require.True(t, connectTime.Equal(row.ConnectTime))
|
||||
require.True(t, row.DisconnectTime.Valid)
|
||||
require.True(t, disconnectTime.Equal(row.DisconnectTime.Time))
|
||||
require.Equal(t, "test disconnect", row.DisconnectReason.String)
|
||||
require.Equal(t, int32(1), row.Code.Int32)
|
||||
})
|
||||
|
||||
t.Run("ConnectDoesNotUpdate", func(t *testing.T) {
|
||||
t.Run("DuplicateConnectIsNoOp", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
ws := createWorkspace(t, db)
|
||||
|
||||
connectionID := uuid.New()
|
||||
agentName := "test-agent"
|
||||
|
||||
// 1. Insert a 'connect' event.
|
||||
connID := uuid.New()
|
||||
connectTime := dbtime.Now()
|
||||
connectParams := database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: connectTime,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: agentName,
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
Ip: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
|
||||
mkParams := func(ct time.Time, ip pqtype.Inet) database.BatchUpsertConnectionLogsParams {
|
||||
return database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{ct},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{0},
|
||||
CodeValid: []bool{false},
|
||||
Ip: []pqtype.Inet{ip},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{""},
|
||||
DisconnectTime: []time.Time{zeroTime},
|
||||
}
|
||||
}
|
||||
|
||||
log, err := db.UpsertConnectionLog(ctx, connectParams)
|
||||
err := db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime, defaultIP))
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Insert another 'connect' event for the same connection.
|
||||
connectTime2 := connectTime.Add(time.Second)
|
||||
connectParams2 := database.UpsertConnectionLogParams{
|
||||
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
|
||||
WorkspaceID: ws.ID,
|
||||
AgentName: agentName,
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows1, 1)
|
||||
|
||||
// Ignored
|
||||
ID: uuid.New(),
|
||||
Time: connectTime2,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceName: ws.Name,
|
||||
Type: database.ConnectionTypeSsh,
|
||||
Code: sql.NullInt32{Int32: 0, Valid: false},
|
||||
Ip: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 254),
|
||||
},
|
||||
Valid: true,
|
||||
// Second connect with later time and different IP.
|
||||
otherIP := pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||
},
|
||||
Valid: true,
|
||||
}
|
||||
|
||||
origLog, err := db.UpsertConnectionLog(ctx, connectParams2)
|
||||
err = db.BatchUpsertConnectionLogs(ctx, mkParams(connectTime.Add(time.Second), otherIP))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, log, origLog, "connect update should be a no-op")
|
||||
|
||||
// Check that still only one row exists.
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
|
||||
rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
require.Equal(t, log, rows[0].ConnectionLog)
|
||||
require.Len(t, rows2, 1)
|
||||
|
||||
// The LEAST logic should pick the earlier connect_time; IP and
|
||||
// other fields are not updated on conflict.
|
||||
require.True(t, connectTime.Equal(rows2[0].ConnectionLog.ConnectTime),
|
||||
"connect_time should remain the original (earlier) value")
|
||||
})
|
||||
|
||||
t.Run("DisconnectThenConnect", func(t *testing.T) {
|
||||
t.Run("OrderIndependentConnectTime", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
ws := createWorkspace(t, db)
|
||||
|
||||
connectionID := uuid.New()
|
||||
agentName := "test-agent"
|
||||
|
||||
// Insert just a 'disconect' event
|
||||
connID := uuid.New()
|
||||
disconnectTime := dbtime.Now()
|
||||
disconnectParams := database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: disconnectTime,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: agentName,
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusDisconnected,
|
||||
DisconnectReason: sql.NullString{String: "server shutting down", Valid: true},
|
||||
Ip: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
connectTime := disconnectTime.Add(-5 * time.Second)
|
||||
|
||||
// Disconnect arrives first.
|
||||
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{disconnectTime},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{0},
|
||||
CodeValid: []bool{true},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{"bye"},
|
||||
DisconnectTime: []time.Time{disconnectTime},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect arrives second with the real (earlier) connect_time.
|
||||
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{connectTime},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{0},
|
||||
CodeValid: []bool{false},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{""},
|
||||
DisconnectTime: []time.Time{zeroTime},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
require.True(t, connectTime.Equal(rows[0].ConnectionLog.ConnectTime),
|
||||
"LEAST should pick the earlier connect_time")
|
||||
})
|
||||
|
||||
t.Run("DisconnectFieldsAreWriteOnce", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
ws := createWorkspace(t, db)
|
||||
connID := uuid.New()
|
||||
disconnectTime := dbtime.Now()
|
||||
|
||||
mkDisconnect := func(reason string, code int32) database.BatchUpsertConnectionLogsParams {
|
||||
return database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{disconnectTime},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{code},
|
||||
CodeValid: []bool{true},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{reason},
|
||||
DisconnectTime: []time.Time{disconnectTime},
|
||||
}
|
||||
}
|
||||
|
||||
_, err := db.UpsertConnectionLog(ctx, disconnectParams)
|
||||
err := db.BatchUpsertConnectionLogs(ctx, mkDisconnect("first reason", 1))
|
||||
require.NoError(t, err)
|
||||
|
||||
firstRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
|
||||
// Second disconnect with different reason and code.
|
||||
err = db.BatchUpsertConnectionLogs(ctx, mkDisconnect("second reason", 2))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, firstRows, 1)
|
||||
|
||||
// We expect the connection event to be marked as closed with the start
|
||||
// and close time being the same.
|
||||
require.True(t, firstRows[0].ConnectionLog.DisconnectTime.Valid)
|
||||
require.Equal(t, disconnectTime, firstRows[0].ConnectionLog.DisconnectTime.Time.UTC())
|
||||
require.Equal(t, firstRows[0].ConnectionLog.ConnectTime.UTC(), firstRows[0].ConnectionLog.DisconnectTime.Time.UTC())
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
row := rows[0].ConnectionLog
|
||||
require.Equal(t, "first reason", row.DisconnectReason.String,
|
||||
"disconnect_reason should not be overwritten")
|
||||
require.Equal(t, int32(1), row.Code.Int32,
|
||||
"code should not be overwritten")
|
||||
})
|
||||
|
||||
// Now insert a 'connect' event for the same connection.
|
||||
// This should be a no op
|
||||
connectTime := disconnectTime.Add(time.Second)
|
||||
connectParams := database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: connectTime,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: agentName,
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
DisconnectReason: sql.NullString{String: "reconnected", Valid: true},
|
||||
Code: sql.NullInt32{Int32: 0, Valid: false},
|
||||
Ip: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
t.Run("ConnectAfterDisconnectIsNoOp", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
ws := createWorkspace(t, db)
|
||||
connID := uuid.New()
|
||||
disconnectTime := dbtime.Now()
|
||||
|
||||
// Insert disconnect first.
|
||||
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{disconnectTime},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{42},
|
||||
CodeValid: []bool{true},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{"server shutdown"},
|
||||
DisconnectTime: []time.Time{disconnectTime},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows1, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows1, 1)
|
||||
require.True(t, rows1[0].ConnectionLog.DisconnectTime.Valid)
|
||||
require.Equal(t, "server shutdown", rows1[0].ConnectionLog.DisconnectReason.String)
|
||||
require.Equal(t, int32(42), rows1[0].ConnectionLog.Code.Int32)
|
||||
|
||||
// Insert connect for same connection_id.
|
||||
err = db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{disconnectTime.Add(time.Second)},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{0},
|
||||
CodeValid: []bool{false},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{""},
|
||||
DisconnectTime: []time.Time{zeroTime},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows2, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows2, 1)
|
||||
row := rows2[0].ConnectionLog
|
||||
require.True(t, row.DisconnectTime.Valid,
|
||||
"disconnect_time should not be cleared by a later connect")
|
||||
require.Equal(t, "server shutdown", row.DisconnectReason.String,
|
||||
"disconnect_reason should not be cleared")
|
||||
require.Equal(t, int32(42), row.Code.Int32,
|
||||
"code should not be cleared")
|
||||
})
|
||||
|
||||
t.Run("CodeZeroPreserved", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
ws := createWorkspace(t, db)
|
||||
connID := uuid.New()
|
||||
now := dbtime.Now()
|
||||
|
||||
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{now},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{0},
|
||||
CodeValid: []bool{true},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{"normal"},
|
||||
DisconnectTime: []time.Time{now},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
require.True(t, rows[0].ConnectionLog.Code.Valid, "code should be non-NULL")
|
||||
require.Equal(t, int32(0), rows[0].ConnectionLog.Code.Int32,
|
||||
"code=0 should be preserved, not treated as NULL")
|
||||
})
|
||||
|
||||
t.Run("CodeNullWhenInvalid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
ws := createWorkspace(t, db)
|
||||
connID := uuid.New()
|
||||
now := dbtime.Now()
|
||||
|
||||
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{now},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{99},
|
||||
CodeValid: []bool{false},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{""},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{""},
|
||||
ConnectionID: []uuid.UUID{connID},
|
||||
DisconnectReason: []string{""},
|
||||
DisconnectTime: []time.Time{zeroTime},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
require.False(t, rows[0].ConnectionLog.Code.Valid,
|
||||
"code should be NULL when code_valid is false")
|
||||
})
|
||||
|
||||
t.Run("NullConnectionIDEvents", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
ws := createWorkspace(t, db)
|
||||
now := dbtime.Now()
|
||||
|
||||
// Insert two web events with NULL connection_id (uuid.Nil →
|
||||
// NULL via NULLIF) for the same workspace/agent.
|
||||
for i := range 2 {
|
||||
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: []uuid.UUID{uuid.New()},
|
||||
ConnectTime: []time.Time{now.Add(time.Duration(i) * time.Second)},
|
||||
OrganizationID: []uuid.UUID{ws.OrganizationID},
|
||||
WorkspaceOwnerID: []uuid.UUID{ws.OwnerID},
|
||||
WorkspaceID: []uuid.UUID{ws.ID},
|
||||
WorkspaceName: []string{ws.Name},
|
||||
AgentName: []string{"agent"},
|
||||
Type: []database.ConnectionType{database.ConnectionTypeSsh},
|
||||
Code: []int32{200},
|
||||
CodeValid: []bool{true},
|
||||
Ip: []pqtype.Inet{defaultIP},
|
||||
UserAgent: []string{"Mozilla/5.0"},
|
||||
UserID: []uuid.UUID{uuid.Nil},
|
||||
SlugOrPort: []string{"web-terminal"},
|
||||
ConnectionID: []uuid.UUID{uuid.Nil},
|
||||
DisconnectReason: []string{""},
|
||||
DisconnectTime: []time.Time{zeroTime},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, err = db.UpsertConnectionLog(ctx, connectParams)
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 2,
|
||||
"NULL connection_id rows should not conflict with each other")
|
||||
})
|
||||
|
||||
secondRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secondRows, 1)
|
||||
require.Equal(t, firstRows, secondRows)
|
||||
t.Run("MultipleIndependentConnections", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := context.Background()
|
||||
ws := createWorkspace(t, db)
|
||||
now := dbtime.Now()
|
||||
|
||||
// Upsert a disconnection, which should also be a no op
|
||||
disconnectParams.DisconnectReason = sql.NullString{
|
||||
String: "updated close reason",
|
||||
Valid: true,
|
||||
n := 5
|
||||
ids := make([]uuid.UUID, n)
|
||||
connectTimes := make([]time.Time, n)
|
||||
orgIDs := make([]uuid.UUID, n)
|
||||
ownerIDs := make([]uuid.UUID, n)
|
||||
wsIDs := make([]uuid.UUID, n)
|
||||
wsNames := make([]string, n)
|
||||
agentNames := make([]string, n)
|
||||
types := make([]database.ConnectionType, n)
|
||||
codes := make([]int32, n)
|
||||
codeValids := make([]bool, n)
|
||||
ips := make([]pqtype.Inet, n)
|
||||
userAgents := make([]string, n)
|
||||
userIDs := make([]uuid.UUID, n)
|
||||
slugOrPorts := make([]string, n)
|
||||
connIDs := make([]uuid.UUID, n)
|
||||
disconnectReasons := make([]string, n)
|
||||
disconnectTimes := make([]time.Time, n)
|
||||
|
||||
for i := range n {
|
||||
ids[i] = uuid.New()
|
||||
connectTimes[i] = now.Add(time.Duration(i) * time.Second)
|
||||
orgIDs[i] = ws.OrganizationID
|
||||
ownerIDs[i] = ws.OwnerID
|
||||
wsIDs[i] = ws.ID
|
||||
wsNames[i] = ws.Name
|
||||
agentNames[i] = "agent"
|
||||
types[i] = database.ConnectionTypeSsh
|
||||
codes[i] = 0
|
||||
codeValids[i] = false
|
||||
ips[i] = defaultIP
|
||||
userAgents[i] = ""
|
||||
userIDs[i] = uuid.Nil
|
||||
slugOrPorts[i] = ""
|
||||
connIDs[i] = uuid.New()
|
||||
disconnectReasons[i] = ""
|
||||
disconnectTimes[i] = zeroTime
|
||||
}
|
||||
_, err = db.UpsertConnectionLog(ctx, disconnectParams)
|
||||
|
||||
err := db.BatchUpsertConnectionLogs(ctx, database.BatchUpsertConnectionLogsParams{
|
||||
ID: ids,
|
||||
ConnectTime: connectTimes,
|
||||
OrganizationID: orgIDs,
|
||||
WorkspaceOwnerID: ownerIDs,
|
||||
WorkspaceID: wsIDs,
|
||||
WorkspaceName: wsNames,
|
||||
AgentName: agentNames,
|
||||
Type: types,
|
||||
Code: codes,
|
||||
CodeValid: codeValids,
|
||||
Ip: ips,
|
||||
UserAgent: userAgents,
|
||||
UserID: userIDs,
|
||||
SlugOrPort: slugOrPorts,
|
||||
ConnectionID: connIDs,
|
||||
DisconnectReason: disconnectReasons,
|
||||
DisconnectTime: disconnectTimes,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
thirdRows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{})
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{LimitOpt: 10})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secondRows, 1)
|
||||
// The close reason shouldn't be updated
|
||||
require.Equal(t, secondRows, thirdRows)
|
||||
require.Len(t, rows, n, "each unique connection_id should produce its own row")
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+117
-114
@@ -7338,6 +7338,123 @@ func (q *sqlQuerier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg U
|
||||
return i, err
|
||||
}
|
||||
|
||||
const batchUpsertConnectionLogs = `-- name: BatchUpsertConnectionLogs :exec
|
||||
INSERT INTO connection_logs (
|
||||
id, connect_time, organization_id, workspace_owner_id, workspace_id,
|
||||
workspace_name, agent_name, type, code, ip, user_agent, user_id,
|
||||
slug_or_port, connection_id, disconnect_reason, disconnect_time
|
||||
)
|
||||
SELECT
|
||||
u.id,
|
||||
u.connect_time,
|
||||
u.organization_id,
|
||||
u.workspace_owner_id,
|
||||
u.workspace_id,
|
||||
u.workspace_name,
|
||||
u.agent_name,
|
||||
u.type,
|
||||
-- Use the validity flag to distinguish "no code" (NULL) from a
|
||||
-- legitimate zero exit code.
|
||||
CASE WHEN u.code_valid THEN u.code ELSE NULL END,
|
||||
u.ip,
|
||||
NULLIF(u.user_agent, ''),
|
||||
NULLIF(u.user_id, '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
NULLIF(u.slug_or_port, ''),
|
||||
NULLIF(u.connection_id, '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
NULLIF(u.disconnect_reason, ''),
|
||||
NULLIF(u.disconnect_time, '0001-01-01 00:00:00Z'::timestamptz)
|
||||
FROM (
|
||||
SELECT
|
||||
unnest($1::uuid[]) AS id,
|
||||
unnest($2::timestamptz[]) AS connect_time,
|
||||
unnest($3::uuid[]) AS organization_id,
|
||||
unnest($4::uuid[]) AS workspace_owner_id,
|
||||
unnest($5::uuid[]) AS workspace_id,
|
||||
unnest($6::text[]) AS workspace_name,
|
||||
unnest($7::text[]) AS agent_name,
|
||||
unnest($8::connection_type[]) AS type,
|
||||
unnest($9::int4[]) AS code,
|
||||
unnest($10::bool[]) AS code_valid,
|
||||
unnest($11::inet[]) AS ip,
|
||||
unnest($12::text[]) AS user_agent,
|
||||
unnest($13::uuid[]) AS user_id,
|
||||
unnest($14::text[]) AS slug_or_port,
|
||||
unnest($15::uuid[]) AS connection_id,
|
||||
unnest($16::text[]) AS disconnect_reason,
|
||||
unnest($17::timestamptz[]) AS disconnect_time
|
||||
) AS u
|
||||
ON CONFLICT (connection_id, workspace_id, agent_name)
|
||||
DO UPDATE SET
|
||||
-- Pick the earliest real connect_time. The zero sentinel
|
||||
-- ('0001-01-01') means the batch didn't know the connect_time
|
||||
-- (e.g. a pure disconnect event), so we keep the existing value.
|
||||
connect_time = CASE
|
||||
WHEN EXCLUDED.connect_time = '0001-01-01 00:00:00Z'::timestamptz
|
||||
THEN connection_logs.connect_time
|
||||
WHEN connection_logs.connect_time = '0001-01-01 00:00:00Z'::timestamptz
|
||||
THEN EXCLUDED.connect_time
|
||||
ELSE LEAST(connection_logs.connect_time, EXCLUDED.connect_time)
|
||||
END,
|
||||
disconnect_time = CASE
|
||||
WHEN connection_logs.disconnect_time IS NULL
|
||||
THEN EXCLUDED.disconnect_time
|
||||
ELSE connection_logs.disconnect_time
|
||||
END,
|
||||
disconnect_reason = CASE
|
||||
WHEN connection_logs.disconnect_reason IS NULL
|
||||
THEN EXCLUDED.disconnect_reason
|
||||
ELSE connection_logs.disconnect_reason
|
||||
END,
|
||||
code = CASE
|
||||
WHEN connection_logs.code IS NULL
|
||||
THEN EXCLUDED.code
|
||||
ELSE connection_logs.code
|
||||
END
|
||||
`
|
||||
|
||||
type BatchUpsertConnectionLogsParams struct {
|
||||
ID []uuid.UUID `db:"id" json:"id"`
|
||||
ConnectTime []time.Time `db:"connect_time" json:"connect_time"`
|
||||
OrganizationID []uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
WorkspaceOwnerID []uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
|
||||
WorkspaceID []uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceName []string `db:"workspace_name" json:"workspace_name"`
|
||||
AgentName []string `db:"agent_name" json:"agent_name"`
|
||||
Type []ConnectionType `db:"type" json:"type"`
|
||||
Code []int32 `db:"code" json:"code"`
|
||||
CodeValid []bool `db:"code_valid" json:"code_valid"`
|
||||
Ip []pqtype.Inet `db:"ip" json:"ip"`
|
||||
UserAgent []string `db:"user_agent" json:"user_agent"`
|
||||
UserID []uuid.UUID `db:"user_id" json:"user_id"`
|
||||
SlugOrPort []string `db:"slug_or_port" json:"slug_or_port"`
|
||||
ConnectionID []uuid.UUID `db:"connection_id" json:"connection_id"`
|
||||
DisconnectReason []string `db:"disconnect_reason" json:"disconnect_reason"`
|
||||
DisconnectTime []time.Time `db:"disconnect_time" json:"disconnect_time"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUpsertConnectionLogsParams) error {
|
||||
_, err := q.db.ExecContext(ctx, batchUpsertConnectionLogs,
|
||||
pq.Array(arg.ID),
|
||||
pq.Array(arg.ConnectTime),
|
||||
pq.Array(arg.OrganizationID),
|
||||
pq.Array(arg.WorkspaceOwnerID),
|
||||
pq.Array(arg.WorkspaceID),
|
||||
pq.Array(arg.WorkspaceName),
|
||||
pq.Array(arg.AgentName),
|
||||
pq.Array(arg.Type),
|
||||
pq.Array(arg.Code),
|
||||
pq.Array(arg.CodeValid),
|
||||
pq.Array(arg.Ip),
|
||||
pq.Array(arg.UserAgent),
|
||||
pq.Array(arg.UserID),
|
||||
pq.Array(arg.SlugOrPort),
|
||||
pq.Array(arg.ConnectionID),
|
||||
pq.Array(arg.DisconnectReason),
|
||||
pq.Array(arg.DisconnectTime),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const countConnectionLogs = `-- name: CountConnectionLogs :one
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
@@ -7753,120 +7870,6 @@ func (q *sqlQuerier) GetConnectionLogsOffset(ctx context.Context, arg GetConnect
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const upsertConnectionLog = `-- name: UpsertConnectionLog :one
|
||||
INSERT INTO connection_logs (
|
||||
id,
|
||||
connect_time,
|
||||
organization_id,
|
||||
workspace_owner_id,
|
||||
workspace_id,
|
||||
workspace_name,
|
||||
agent_name,
|
||||
type,
|
||||
code,
|
||||
ip,
|
||||
user_agent,
|
||||
user_id,
|
||||
slug_or_port,
|
||||
connection_id,
|
||||
disconnect_reason,
|
||||
disconnect_time
|
||||
) VALUES
|
||||
($1, $15, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
|
||||
-- If we've only received a disconnect event, mark the event as immediately
|
||||
-- closed.
|
||||
CASE
|
||||
WHEN $16::connection_status = 'disconnected'
|
||||
THEN $15 :: timestamp with time zone
|
||||
ELSE NULL
|
||||
END)
|
||||
ON CONFLICT (connection_id, workspace_id, agent_name)
|
||||
DO UPDATE SET
|
||||
-- No-op if the connection is still open.
|
||||
disconnect_time = CASE
|
||||
WHEN $16::connection_status = 'disconnected'
|
||||
-- Can only be set once
|
||||
AND connection_logs.disconnect_time IS NULL
|
||||
THEN EXCLUDED.connect_time
|
||||
ELSE connection_logs.disconnect_time
|
||||
END,
|
||||
disconnect_reason = CASE
|
||||
WHEN $16::connection_status = 'disconnected'
|
||||
-- Can only be set once
|
||||
AND connection_logs.disconnect_reason IS NULL
|
||||
THEN EXCLUDED.disconnect_reason
|
||||
ELSE connection_logs.disconnect_reason
|
||||
END,
|
||||
code = CASE
|
||||
WHEN $16::connection_status = 'disconnected'
|
||||
-- Can only be set once
|
||||
AND connection_logs.code IS NULL
|
||||
THEN EXCLUDED.code
|
||||
ELSE connection_logs.code
|
||||
END
|
||||
RETURNING id, connect_time, organization_id, workspace_owner_id, workspace_id, workspace_name, agent_name, type, ip, code, user_agent, user_id, slug_or_port, connection_id, disconnect_time, disconnect_reason
|
||||
`
|
||||
|
||||
type UpsertConnectionLogParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"`
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
WorkspaceName string `db:"workspace_name" json:"workspace_name"`
|
||||
AgentName string `db:"agent_name" json:"agent_name"`
|
||||
Type ConnectionType `db:"type" json:"type"`
|
||||
Code sql.NullInt32 `db:"code" json:"code"`
|
||||
Ip pqtype.Inet `db:"ip" json:"ip"`
|
||||
UserAgent sql.NullString `db:"user_agent" json:"user_agent"`
|
||||
UserID uuid.NullUUID `db:"user_id" json:"user_id"`
|
||||
SlugOrPort sql.NullString `db:"slug_or_port" json:"slug_or_port"`
|
||||
ConnectionID uuid.NullUUID `db:"connection_id" json:"connection_id"`
|
||||
DisconnectReason sql.NullString `db:"disconnect_reason" json:"disconnect_reason"`
|
||||
Time time.Time `db:"time" json:"time"`
|
||||
ConnectionStatus ConnectionStatus `db:"connection_status" json:"connection_status"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error) {
|
||||
row := q.db.QueryRowContext(ctx, upsertConnectionLog,
|
||||
arg.ID,
|
||||
arg.OrganizationID,
|
||||
arg.WorkspaceOwnerID,
|
||||
arg.WorkspaceID,
|
||||
arg.WorkspaceName,
|
||||
arg.AgentName,
|
||||
arg.Type,
|
||||
arg.Code,
|
||||
arg.Ip,
|
||||
arg.UserAgent,
|
||||
arg.UserID,
|
||||
arg.SlugOrPort,
|
||||
arg.ConnectionID,
|
||||
arg.DisconnectReason,
|
||||
arg.Time,
|
||||
arg.ConnectionStatus,
|
||||
)
|
||||
var i ConnectionLog
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ConnectTime,
|
||||
&i.OrganizationID,
|
||||
&i.WorkspaceOwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.WorkspaceName,
|
||||
&i.AgentName,
|
||||
&i.Type,
|
||||
&i.Ip,
|
||||
&i.Code,
|
||||
&i.UserAgent,
|
||||
&i.UserID,
|
||||
&i.SlugOrPort,
|
||||
&i.ConnectionID,
|
||||
&i.DisconnectTime,
|
||||
&i.DisconnectReason,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteCryptoKey = `-- name: DeleteCryptoKey :one
|
||||
UPDATE crypto_keys
|
||||
SET secret = NULL, secret_key_id = NULL
|
||||
|
||||
@@ -251,55 +251,75 @@ DELETE FROM connection_logs
|
||||
USING old_logs
|
||||
WHERE connection_logs.id = old_logs.id;
|
||||
|
||||
-- name: UpsertConnectionLog :one
|
||||
-- name: BatchUpsertConnectionLogs :exec
|
||||
INSERT INTO connection_logs (
|
||||
id,
|
||||
connect_time,
|
||||
organization_id,
|
||||
workspace_owner_id,
|
||||
workspace_id,
|
||||
workspace_name,
|
||||
agent_name,
|
||||
type,
|
||||
code,
|
||||
ip,
|
||||
user_agent,
|
||||
user_id,
|
||||
slug_or_port,
|
||||
connection_id,
|
||||
disconnect_reason,
|
||||
disconnect_time
|
||||
) VALUES
|
||||
($1, @time, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
|
||||
-- If we've only received a disconnect event, mark the event as immediately
|
||||
-- closed.
|
||||
CASE
|
||||
WHEN @connection_status::connection_status = 'disconnected'
|
||||
THEN @time :: timestamp with time zone
|
||||
ELSE NULL
|
||||
END)
|
||||
id, connect_time, organization_id, workspace_owner_id, workspace_id,
|
||||
workspace_name, agent_name, type, code, ip, user_agent, user_id,
|
||||
slug_or_port, connection_id, disconnect_reason, disconnect_time
|
||||
)
|
||||
SELECT
|
||||
u.id,
|
||||
u.connect_time,
|
||||
u.organization_id,
|
||||
u.workspace_owner_id,
|
||||
u.workspace_id,
|
||||
u.workspace_name,
|
||||
u.agent_name,
|
||||
u.type,
|
||||
-- Use the validity flag to distinguish "no code" (NULL) from a
|
||||
-- legitimate zero exit code.
|
||||
CASE WHEN u.code_valid THEN u.code ELSE NULL END,
|
||||
u.ip,
|
||||
NULLIF(u.user_agent, ''),
|
||||
NULLIF(u.user_id, '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
NULLIF(u.slug_or_port, ''),
|
||||
NULLIF(u.connection_id, '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
NULLIF(u.disconnect_reason, ''),
|
||||
NULLIF(u.disconnect_time, '0001-01-01 00:00:00Z'::timestamptz)
|
||||
FROM (
|
||||
SELECT
|
||||
unnest(sqlc.arg('id')::uuid[]) AS id,
|
||||
unnest(sqlc.arg('connect_time')::timestamptz[]) AS connect_time,
|
||||
unnest(sqlc.arg('organization_id')::uuid[]) AS organization_id,
|
||||
unnest(sqlc.arg('workspace_owner_id')::uuid[]) AS workspace_owner_id,
|
||||
unnest(sqlc.arg('workspace_id')::uuid[]) AS workspace_id,
|
||||
unnest(sqlc.arg('workspace_name')::text[]) AS workspace_name,
|
||||
unnest(sqlc.arg('agent_name')::text[]) AS agent_name,
|
||||
unnest(sqlc.arg('type')::connection_type[]) AS type,
|
||||
unnest(sqlc.arg('code')::int4[]) AS code,
|
||||
unnest(sqlc.arg('code_valid')::bool[]) AS code_valid,
|
||||
unnest(sqlc.arg('ip')::inet[]) AS ip,
|
||||
unnest(sqlc.arg('user_agent')::text[]) AS user_agent,
|
||||
unnest(sqlc.arg('user_id')::uuid[]) AS user_id,
|
||||
unnest(sqlc.arg('slug_or_port')::text[]) AS slug_or_port,
|
||||
unnest(sqlc.arg('connection_id')::uuid[]) AS connection_id,
|
||||
unnest(sqlc.arg('disconnect_reason')::text[]) AS disconnect_reason,
|
||||
unnest(sqlc.arg('disconnect_time')::timestamptz[]) AS disconnect_time
|
||||
) AS u
|
||||
ON CONFLICT (connection_id, workspace_id, agent_name)
|
||||
DO UPDATE SET
|
||||
-- No-op if the connection is still open.
|
||||
disconnect_time = CASE
|
||||
WHEN @connection_status::connection_status = 'disconnected'
|
||||
-- Can only be set once
|
||||
AND connection_logs.disconnect_time IS NULL
|
||||
THEN EXCLUDED.connect_time
|
||||
ELSE connection_logs.disconnect_time
|
||||
END,
|
||||
disconnect_reason = CASE
|
||||
WHEN @connection_status::connection_status = 'disconnected'
|
||||
-- Can only be set once
|
||||
AND connection_logs.disconnect_reason IS NULL
|
||||
THEN EXCLUDED.disconnect_reason
|
||||
ELSE connection_logs.disconnect_reason
|
||||
END,
|
||||
code = CASE
|
||||
WHEN @connection_status::connection_status = 'disconnected'
|
||||
-- Can only be set once
|
||||
AND connection_logs.code IS NULL
|
||||
THEN EXCLUDED.code
|
||||
ELSE connection_logs.code
|
||||
END
|
||||
RETURNING *;
|
||||
-- Pick the earliest real connect_time. The zero sentinel
|
||||
-- ('0001-01-01') means the batch didn't know the connect_time
|
||||
-- (e.g. a pure disconnect event), so we keep the existing value.
|
||||
connect_time = CASE
|
||||
WHEN EXCLUDED.connect_time = '0001-01-01 00:00:00Z'::timestamptz
|
||||
THEN connection_logs.connect_time
|
||||
WHEN connection_logs.connect_time = '0001-01-01 00:00:00Z'::timestamptz
|
||||
THEN EXCLUDED.connect_time
|
||||
ELSE LEAST(connection_logs.connect_time, EXCLUDED.connect_time)
|
||||
END,
|
||||
disconnect_time = CASE
|
||||
WHEN connection_logs.disconnect_time IS NULL
|
||||
THEN EXCLUDED.disconnect_time
|
||||
ELSE connection_logs.disconnect_time
|
||||
END,
|
||||
disconnect_reason = CASE
|
||||
WHEN connection_logs.disconnect_reason IS NULL
|
||||
THEN EXCLUDED.disconnect_reason
|
||||
ELSE connection_logs.disconnect_reason
|
||||
END,
|
||||
code = CASE
|
||||
WHEN connection_logs.code IS NULL
|
||||
THEN EXCLUDED.code
|
||||
ELSE connection_logs.code
|
||||
END;
|
||||
|
||||
@@ -535,7 +535,7 @@ func (p *DBTokenProvider) connLogInitRequest(w http.ResponseWriter, r *http.Requ
|
||||
Int32: statusCode,
|
||||
Valid: true,
|
||||
},
|
||||
Ip: database.ParseIP(ip),
|
||||
IP: database.ParseIP(ip),
|
||||
UserAgent: sql.NullString{Valid: userAgent != "", String: userAgent},
|
||||
UserID: uuid.NullUUID{
|
||||
UUID: userID,
|
||||
|
||||
@@ -1281,7 +1281,7 @@ func assertConnLogContains(t *testing.T, rr *httptest.ResponseRecorder, r *http.
|
||||
WorkspaceName: workspace.Name,
|
||||
AgentName: agentName,
|
||||
Type: typ,
|
||||
Ip: database.ParseIP(r.RemoteAddr),
|
||||
IP: database.ParseIP(r.RemoteAddr),
|
||||
UserAgent: sql.NullString{Valid: r.UserAgent() != "", String: r.UserAgent()},
|
||||
Code: sql.NullInt32{
|
||||
Int32: int32(resp.StatusCode), // nolint:gosec
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -144,10 +145,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
}
|
||||
|
||||
if options.ConnectionLogger == nil {
|
||||
options.ConnectionLogger = connectionlog.NewConnectionLogger(
|
||||
connectionlog.NewDBBackend(options.Database),
|
||||
connLogger := connectionlog.New(
|
||||
connectionlog.NewDBBatcher(ctx, options.Database, options.Logger),
|
||||
connectionlog.NewSlogBackend(options.Logger),
|
||||
)
|
||||
options.ConnectionLogger = connLogger
|
||||
}
|
||||
|
||||
meshTLSConfig, err := replicasync.CreateDERPMeshTLSConfig(options.AccessURL.Hostname(), options.TLSCertificates)
|
||||
@@ -822,6 +824,12 @@ func (api *API) Close() error {
|
||||
api.Options.CheckInactiveUsersCancelFunc()
|
||||
}
|
||||
|
||||
// Close the connection logger to flush any remaining batched
|
||||
// entries before shutting down the database connection.
|
||||
if cl, ok := api.Options.ConnectionLogger.(io.Closer); ok {
|
||||
_ = cl.Close()
|
||||
}
|
||||
|
||||
return api.AGPL.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -2,31 +2,70 @@ package connectionlog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
agpl "github.com/coder/coder/v2/coderd/connectionlog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
auditbackends "github.com/coder/coder/v2/enterprise/audit/backends"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultBatchSize is the maximum number of connection log entries
|
||||
// to batch before forcing a flush.
|
||||
defaultBatchSize = 1000
|
||||
|
||||
// defaultFlushInterval is how frequently to flush batched connection
|
||||
// log entries to the database. Five seconds balances near-real-time
|
||||
// audit visibility with write efficiency.
|
||||
defaultFlushInterval = 5 * time.Second
|
||||
|
||||
// retryQueueSize is the capacity of the bounded retry channel.
|
||||
// Failed batches beyond this limit are dropped.
|
||||
retryQueueSize = 10
|
||||
|
||||
// shutdownWriteTimeout bounds how long a final write attempt
|
||||
// can take during shutdown when the batcher context is already
|
||||
// canceled.
|
||||
shutdownWriteTimeout = 10 * time.Second
|
||||
|
||||
// maxRetries is the number of times to retry a failed batch
|
||||
// write before dropping it and moving on.
|
||||
maxRetries = 3
|
||||
|
||||
// retryInterval is the fixed delay between retry attempts.
|
||||
retryInterval = time.Second
|
||||
)
|
||||
|
||||
// Backend is a destination for connection log events. Backends that
|
||||
// also implement io.Closer will be closed when the ConnectionLogger
|
||||
// is closed.
|
||||
type Backend interface {
|
||||
Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error
|
||||
}
|
||||
|
||||
func NewConnectionLogger(backends ...Backend) agpl.ConnectionLogger {
|
||||
return &connectionLogger{
|
||||
// ConnectionLogger fans out each connection log event to every
|
||||
// registered backend.
|
||||
type ConnectionLogger struct {
|
||||
backends []Backend
|
||||
}
|
||||
|
||||
// New creates a ConnectionLogger that dispatches to the given
|
||||
// backends.
|
||||
func New(backends ...Backend) *ConnectionLogger {
|
||||
return &ConnectionLogger{
|
||||
backends: backends,
|
||||
}
|
||||
}
|
||||
|
||||
type connectionLogger struct {
|
||||
backends []Backend
|
||||
}
|
||||
|
||||
func (c *connectionLogger) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error {
|
||||
func (c *ConnectionLogger) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error {
|
||||
var errs error
|
||||
for _, backend := range c.backends {
|
||||
err := backend.Upsert(ctx, clog)
|
||||
@@ -37,24 +76,443 @@ func (c *connectionLogger) Upsert(ctx context.Context, clog database.UpsertConne
|
||||
return errs
|
||||
}
|
||||
|
||||
type dbBackend struct {
|
||||
db database.Store
|
||||
// Close closes all backends that implement io.Closer.
|
||||
func (c *ConnectionLogger) Close() error {
|
||||
var errs error
|
||||
for _, backend := range c.backends {
|
||||
if closer, ok := backend.(io.Closer); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
errs = multierror.Append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
func NewDBBackend(db database.Store) Backend {
|
||||
return &dbBackend{db: db}
|
||||
// DBBatcherOption is a functional option for configuring a DBBatcher.
|
||||
type DBBatcherOption func(b *DBBatcher)
|
||||
|
||||
// WithBatchSize sets the maximum number of entries to accumulate
|
||||
// before forcing a flush.
|
||||
func WithBatchSize(size int) DBBatcherOption {
|
||||
return func(b *DBBatcher) {
|
||||
b.maxBatchSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func (b *dbBackend) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error {
|
||||
//nolint:gocritic // This is the Connection Logger
|
||||
_, err := b.db.UpsertConnectionLog(dbauthz.AsConnectionLogger(ctx), clog)
|
||||
return err
|
||||
// WithFlushInterval sets how frequently the batcher flushes to the
|
||||
// database.
|
||||
func WithFlushInterval(d time.Duration) DBBatcherOption {
|
||||
return func(b *DBBatcher) {
|
||||
b.interval = d
|
||||
}
|
||||
}
|
||||
|
||||
// WithClock sets the clock, useful for testing.
|
||||
func WithClock(clock quartz.Clock) DBBatcherOption {
|
||||
return func(b *DBBatcher) {
|
||||
b.clock = clock
|
||||
}
|
||||
}
|
||||
|
||||
// DBBatcher batches connection log upserts and periodically flushes
|
||||
// them to the database to reduce per-event write pressure.
|
||||
type DBBatcher struct {
|
||||
store database.Store
|
||||
log slog.Logger
|
||||
|
||||
itemCh chan database.UpsertConnectionLogParams
|
||||
|
||||
// dedupedBatch holds entries keyed by connection ID so that
|
||||
// PostgreSQL never sees the same row twice in one INSERT …
|
||||
// ON CONFLICT DO UPDATE. Connection IDs are globally unique
|
||||
// (each new session gets a fresh UUID). Entries with a NULL
|
||||
// connection_id (web events) go into nullConnIDBatch instead
|
||||
// because NULL != NULL in SQL unique constraints.
|
||||
dedupedBatch map[uuid.UUID]batchEntry
|
||||
nullConnIDBatch []batchEntry
|
||||
maxBatchSize int
|
||||
|
||||
// retryCh is a bounded channel of failed batches awaiting
|
||||
// retry. A single retry worker goroutine processes this
|
||||
// channel, retrying each batch up to maxRetries times before
|
||||
// dropping it. If the channel is full, new failures are
|
||||
// dropped immediately.
|
||||
retryCh chan database.BatchUpsertConnectionLogsParams
|
||||
|
||||
clock quartz.Clock
|
||||
timer *quartz.Timer
|
||||
interval time.Duration
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewDBBatcher creates a DBBatcher that batches writes to the database
|
||||
// and starts its background processing loop. Close must be called to
|
||||
// flush remaining entries on shutdown.
|
||||
func NewDBBatcher(ctx context.Context, store database.Store, log slog.Logger, opts ...DBBatcherOption) *DBBatcher {
|
||||
b := &DBBatcher{
|
||||
store: store,
|
||||
log: log,
|
||||
clock: quartz.NewReal(),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(b)
|
||||
}
|
||||
|
||||
if b.interval == 0 {
|
||||
b.interval = defaultFlushInterval
|
||||
}
|
||||
if b.maxBatchSize == 0 {
|
||||
b.maxBatchSize = defaultBatchSize
|
||||
}
|
||||
|
||||
b.timer = b.clock.NewTimer(b.interval)
|
||||
b.itemCh = make(chan database.UpsertConnectionLogParams, b.maxBatchSize)
|
||||
b.dedupedBatch = make(map[uuid.UUID]batchEntry, b.maxBatchSize)
|
||||
b.retryCh = make(chan database.BatchUpsertConnectionLogsParams, retryQueueSize)
|
||||
|
||||
b.ctx, b.cancel = context.WithCancel(ctx)
|
||||
b.wg.Add(2)
|
||||
go func() {
|
||||
defer b.wg.Done()
|
||||
b.run(b.ctx)
|
||||
}()
|
||||
go func() {
|
||||
defer b.wg.Done()
|
||||
b.retryLoop()
|
||||
}()
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Upsert enqueues a connection log entry for batched writing. It
|
||||
// blocks if the internal buffer is full, ensuring no logs are dropped.
|
||||
// It returns an error if the batcher or caller context is canceled.
|
||||
func (b *DBBatcher) Upsert(ctx context.Context, clog database.UpsertConnectionLogParams) error {
|
||||
if b.ctx.Err() != nil {
|
||||
return b.ctx.Err()
|
||||
}
|
||||
|
||||
select {
|
||||
case b.itemCh <- clog:
|
||||
return nil
|
||||
case <-b.ctx.Done():
|
||||
return b.ctx.Err()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Close cancels the batcher context, waits for the run loop and
|
||||
// retry worker to exit.
|
||||
func (b *DBBatcher) Close() error {
|
||||
b.cancel()
|
||||
if b.timer != nil {
|
||||
b.timer.Stop()
|
||||
}
|
||||
b.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// addToBatch inserts an item into the batch, deduplicating by conflict
|
||||
// key on the fly. For entries with the same key, disconnect events are
|
||||
// preferred over connect events, and later events are preferred over
|
||||
// earlier ones.
|
||||
//
|
||||
// This is safe because each new connection gets a fresh UUID (see
|
||||
// agent/agent.go and agent/agentssh), so the only duplicate for the
|
||||
// same (connection_id, workspace_id, agent_name) is a connect/disconnect
|
||||
// pair for the same session. A "reconnect" always uses a new ID.
|
||||
func (b *DBBatcher) addToBatch(item database.UpsertConnectionLogParams) {
|
||||
entry := batchEntry{
|
||||
UpsertConnectionLogParams: item,
|
||||
}
|
||||
if item.ConnectionStatus == database.ConnectionStatusDisconnected {
|
||||
// For standalone disconnect events, use the disconnect
|
||||
// time as both connect and disconnect time. This matches
|
||||
// the single-row UpsertConnectionLog behavior which uses
|
||||
// @time for connect_time regardless of status. The SQL
|
||||
// LEAST logic will correct connect_time if the real
|
||||
// connect event arrives in a later batch.
|
||||
entry.connectTime = item.Time
|
||||
entry.disconnectTime = item.Time
|
||||
} else {
|
||||
entry.connectTime = item.Time
|
||||
}
|
||||
|
||||
if !item.ConnectionID.Valid {
|
||||
b.nullConnIDBatch = append(b.nullConnIDBatch, entry)
|
||||
return
|
||||
}
|
||||
connID := item.ConnectionID.UUID
|
||||
existing, ok := b.dedupedBatch[connID]
|
||||
if !ok {
|
||||
b.dedupedBatch[connID] = entry
|
||||
return
|
||||
}
|
||||
// When merging entries for the same connection, always preserve
|
||||
// the earliest non-zero connect_time and latest disconnect_time
|
||||
// so the row records the full session span.
|
||||
if !existing.connectTime.IsZero() && existing.connectTime.Before(entry.connectTime) {
|
||||
entry.connectTime = existing.connectTime
|
||||
}
|
||||
if existing.disconnectTime.After(entry.disconnectTime) {
|
||||
entry.disconnectTime = existing.disconnectTime
|
||||
}
|
||||
|
||||
// Prefer disconnect over connect (superset of info).
|
||||
// If same status, prefer the later event.
|
||||
if item.ConnectionStatus == database.ConnectionStatusDisconnected &&
|
||||
existing.ConnectionStatus != database.ConnectionStatusDisconnected {
|
||||
b.dedupedBatch[connID] = entry
|
||||
} else if item.Time.After(existing.Time) {
|
||||
b.dedupedBatch[connID] = entry
|
||||
}
|
||||
}
|
||||
|
||||
// batchLen returns the total number of entries currently buffered.
|
||||
func (b *DBBatcher) batchLen() int {
|
||||
return len(b.dedupedBatch) + len(b.nullConnIDBatch)
|
||||
}
|
||||
|
||||
func (b *DBBatcher) run(ctx context.Context) {
|
||||
//nolint:gocritic // System-level batch operation for connection logs.
|
||||
authCtx := dbauthz.AsConnectionLogger(ctx)
|
||||
for ctx.Err() == nil {
|
||||
select {
|
||||
case item := <-b.itemCh:
|
||||
b.addToBatch(item)
|
||||
|
||||
if b.batchLen() >= b.maxBatchSize {
|
||||
b.flush(authCtx)
|
||||
b.timer.Reset(b.interval, "connectionLogBatcher", "capacityFlush")
|
||||
}
|
||||
|
||||
case <-b.timer.C:
|
||||
b.flush(authCtx)
|
||||
b.timer.Reset(b.interval, "connectionLogBatcher", "scheduledFlush")
|
||||
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
b.log.Debug(ctx, "context done, flushing before exit")
|
||||
|
||||
// Drain any remaining items from the channel.
|
||||
for {
|
||||
select {
|
||||
case item := <-b.itemCh:
|
||||
b.addToBatch(item)
|
||||
default:
|
||||
if b.batchLen() > 0 {
|
||||
b.shutdownBatch(b.buildParams())
|
||||
}
|
||||
// Signal the retry worker to skip delays and close
|
||||
// the channel so it exits after processing any
|
||||
// remaining items.
|
||||
// Mark the batcher as closed so that any subsequent
|
||||
// Upsert calls fail immediately instead of sending
|
||||
// into itemCh after the run loop has exited.
|
||||
close(b.retryCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// batchEntry wraps a connection log event with explicit connect and
|
||||
// disconnect times. When a connect and disconnect for the same session
|
||||
// are merged into one entry, connectTime preserves the original
|
||||
// session start while disconnectTime records when it ended.
|
||||
type batchEntry struct {
|
||||
database.UpsertConnectionLogParams
|
||||
connectTime time.Time
|
||||
disconnectTime time.Time
|
||||
}
|
||||
|
||||
// flush builds the batch params, clears the in-memory batch, and
|
||||
// writes to the database. On failure, the batch is queued for retry
|
||||
// by the single retry worker goroutine. If the retry queue is full,
|
||||
// the batch is dropped.
|
||||
func (b *DBBatcher) flush(ctx context.Context) {
|
||||
count := b.batchLen()
|
||||
if count == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
params := b.buildParams()
|
||||
|
||||
// Clear the batch before writing so the run loop can start
|
||||
// accumulating new entries.
|
||||
b.dedupedBatch = make(map[uuid.UUID]batchEntry, b.maxBatchSize)
|
||||
b.nullConnIDBatch = nil
|
||||
|
||||
// Use the batcher's context for normal operation so Close()
|
||||
// can cancel hung writes. During shutdown (ctx already canceled),
|
||||
// fall back to a bounded timeout.
|
||||
writeCtx := b.ctx
|
||||
if writeCtx.Err() != nil {
|
||||
var cancel context.CancelFunc
|
||||
writeCtx, cancel = context.WithTimeout(context.Background(), shutdownWriteTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
//nolint:gocritic // System-level batch operation for connection logs.
|
||||
err := b.store.BatchUpsertConnectionLogs(dbauthz.AsConnectionLogger(writeCtx), params)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.log.Error(ctx, "batch upsert failed, queueing for retry",
|
||||
slog.Error(err), slog.F("count", count))
|
||||
|
||||
// Don't retry on shutdown.
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case b.retryCh <- params:
|
||||
default:
|
||||
b.log.Error(ctx, "retry queue full, dropping batch",
|
||||
slog.F("dropped", count))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *DBBatcher) buildParams() database.BatchUpsertConnectionLogsParams {
|
||||
count := b.batchLen()
|
||||
var (
|
||||
ids = make([]uuid.UUID, 0, count)
|
||||
connectTime = make([]time.Time, 0, count)
|
||||
organizationID = make([]uuid.UUID, 0, count)
|
||||
workspaceOwnerID = make([]uuid.UUID, 0, count)
|
||||
workspaceID = make([]uuid.UUID, 0, count)
|
||||
workspaceName = make([]string, 0, count)
|
||||
agentName = make([]string, 0, count)
|
||||
connType = make([]database.ConnectionType, 0, count)
|
||||
code = make([]int32, 0, count)
|
||||
codeValid = make([]bool, 0, count)
|
||||
ip = make([]pqtype.Inet, 0, count)
|
||||
userAgent = make([]string, 0, count)
|
||||
userID = make([]uuid.UUID, 0, count)
|
||||
slugOrPort = make([]string, 0, count)
|
||||
connectionID = make([]uuid.UUID, 0, count)
|
||||
disconnectReason = make([]string, 0, count)
|
||||
disconnectTime = make([]time.Time, 0, count)
|
||||
)
|
||||
|
||||
appendEntry := func(e batchEntry) {
|
||||
ids = append(ids, e.ID)
|
||||
connectTime = append(connectTime, e.connectTime)
|
||||
organizationID = append(organizationID, e.OrganizationID)
|
||||
workspaceOwnerID = append(workspaceOwnerID, e.WorkspaceOwnerID)
|
||||
workspaceID = append(workspaceID, e.WorkspaceID)
|
||||
workspaceName = append(workspaceName, e.WorkspaceName)
|
||||
agentName = append(agentName, e.AgentName)
|
||||
connType = append(connType, e.Type)
|
||||
code = append(code, e.Code.Int32)
|
||||
codeValid = append(codeValid, e.Code.Valid)
|
||||
ip = append(ip, e.IP)
|
||||
userAgent = append(userAgent, e.UserAgent.String)
|
||||
userID = append(userID, e.UserID.UUID)
|
||||
slugOrPort = append(slugOrPort, e.SlugOrPort.String)
|
||||
connectionID = append(connectionID, e.ConnectionID.UUID)
|
||||
disconnectReason = append(disconnectReason, e.DisconnectReason.String)
|
||||
disconnectTime = append(disconnectTime, e.disconnectTime)
|
||||
}
|
||||
|
||||
for _, entry := range b.dedupedBatch {
|
||||
appendEntry(entry)
|
||||
}
|
||||
for _, entry := range b.nullConnIDBatch {
|
||||
appendEntry(entry)
|
||||
}
|
||||
|
||||
return database.BatchUpsertConnectionLogsParams{
|
||||
ID: ids,
|
||||
ConnectTime: connectTime,
|
||||
OrganizationID: organizationID,
|
||||
WorkspaceOwnerID: workspaceOwnerID,
|
||||
WorkspaceID: workspaceID,
|
||||
WorkspaceName: workspaceName,
|
||||
AgentName: agentName,
|
||||
Type: connType,
|
||||
Code: code,
|
||||
CodeValid: codeValid,
|
||||
Ip: ip,
|
||||
UserAgent: userAgent,
|
||||
UserID: userID,
|
||||
SlugOrPort: slugOrPort,
|
||||
ConnectionID: connectionID,
|
||||
DisconnectReason: disconnectReason,
|
||||
DisconnectTime: disconnectTime,
|
||||
}
|
||||
}
|
||||
|
||||
// retryLoop is a single background goroutine that processes failed
|
||||
// batches from retryCh. Each batch is retried up to maxRetries times
|
||||
// with a fixed delay between attempts. When draining is set (shutdown),
|
||||
// batches get a single immediate write attempt instead. The loop exits
|
||||
// when retryCh is closed by the run goroutine.
|
||||
func (b *DBBatcher) retryLoop() {
|
||||
for params := range b.retryCh {
|
||||
b.retryBatch(params)
|
||||
}
|
||||
}
|
||||
|
||||
// retryBatch retries writing a batch up to maxRetries times with a
|
||||
// fixed delay between attempts. If the batcher context is canceled
|
||||
// during a wait, one final attempt is made before returning.
|
||||
func (b *DBBatcher) retryBatch(params database.BatchUpsertConnectionLogsParams) {
|
||||
count := len(params.ID)
|
||||
for attempt := range maxRetries {
|
||||
t := time.NewTimer(retryInterval)
|
||||
select {
|
||||
case <-b.ctx.Done():
|
||||
b.shutdownBatch(params)
|
||||
return
|
||||
case <-t.C:
|
||||
}
|
||||
|
||||
//nolint:gocritic // System-level batch operation for connection logs.
|
||||
err := b.store.BatchUpsertConnectionLogs(dbauthz.AsConnectionLogger(b.ctx), params)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.log.Warn(b.ctx, "batch retry failed",
|
||||
slog.Error(err),
|
||||
slog.F("count", count),
|
||||
slog.F("attempt", attempt+1),
|
||||
slog.F("max_attempts", maxRetries),
|
||||
)
|
||||
}
|
||||
|
||||
b.log.Error(b.ctx, "batch retries exhausted, dropping batch",
|
||||
slog.F("dropped", count))
|
||||
}
|
||||
|
||||
// shutdownBatch makes a single write attempt during shutdown with a
|
||||
// bounded timeout so it can't hang indefinitely.
|
||||
func (b *DBBatcher) shutdownBatch(params database.BatchUpsertConnectionLogsParams) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownWriteTimeout)
|
||||
defer cancel()
|
||||
//nolint:gocritic // System-level batch operation for connection logs.
|
||||
err := b.store.BatchUpsertConnectionLogs(dbauthz.AsConnectionLogger(ctx), params)
|
||||
if err != nil {
|
||||
b.log.Error(b.ctx, "batch write failed on shutdown, dropping batch",
|
||||
slog.Error(err), slog.F("dropped", len(params.ID)))
|
||||
}
|
||||
}
|
||||
|
||||
type connectionSlogBackend struct {
|
||||
exporter *auditbackends.SlogExporter
|
||||
}
|
||||
|
||||
// NewSlogBackend returns a Backend that logs connection events via
|
||||
// the structured logger.
|
||||
func NewSlogBackend(logger slog.Logger) Backend {
|
||||
return &connectionSlogBackend{
|
||||
exporter: auditbackends.NewSlogExporter(logger),
|
||||
|
||||
@@ -0,0 +1,529 @@
|
||||
package connectionlog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func Test_addToBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConnectThenDisconnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &DBBatcher{
|
||||
maxBatchSize: 100,
|
||||
dedupedBatch: make(map[uuid.UUID]batchEntry),
|
||||
}
|
||||
|
||||
wsID := uuid.New()
|
||||
connID := uuid.New()
|
||||
|
||||
connect := fakeConnectEvent(wsID, "agent1", connID)
|
||||
disconnect := fakeDisconnectEvent(wsID, "agent1", connID)
|
||||
|
||||
b.addToBatch(connect)
|
||||
b.addToBatch(disconnect)
|
||||
|
||||
require.Equal(t, 1, b.batchLen())
|
||||
key := connID
|
||||
got := b.dedupedBatch[key]
|
||||
require.Equal(t, disconnect.ID, got.ID)
|
||||
require.Equal(t, database.ConnectionStatusDisconnected, got.ConnectionStatus)
|
||||
// The connect_time should be preserved from the original
|
||||
// connect event, not overwritten by the disconnect's
|
||||
// timestamp.
|
||||
require.Equal(t, connect.Time, got.connectTime)
|
||||
require.Equal(t, disconnect.Time, got.disconnectTime)
|
||||
})
|
||||
|
||||
t.Run("DisconnectThenLaterConnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &DBBatcher{
|
||||
maxBatchSize: 100,
|
||||
dedupedBatch: make(map[uuid.UUID]batchEntry),
|
||||
}
|
||||
|
||||
wsID := uuid.New()
|
||||
connID := uuid.New()
|
||||
|
||||
disconnect := fakeDisconnectEvent(wsID, "agent1", connID)
|
||||
connect := fakeConnectEvent(wsID, "agent1", connID)
|
||||
connect.Time = disconnect.Time.Add(time.Second)
|
||||
|
||||
b.addToBatch(disconnect)
|
||||
b.addToBatch(connect)
|
||||
|
||||
require.Equal(t, 1, b.batchLen())
|
||||
key := connID
|
||||
// The later event wins when the incoming item is not a
|
||||
// disconnect. In practice, this case doesn't occur because
|
||||
// connection IDs are never reused.
|
||||
got := b.dedupedBatch[key]
|
||||
require.Equal(t, connect.ID, got.ID)
|
||||
// The disconnect's time should be preserved even though
|
||||
// the connect event replaced it.
|
||||
require.Equal(t, disconnect.Time, got.disconnectTime)
|
||||
})
|
||||
|
||||
t.Run("DisconnectThenEarlierConnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &DBBatcher{
|
||||
maxBatchSize: 100,
|
||||
dedupedBatch: make(map[uuid.UUID]batchEntry),
|
||||
}
|
||||
|
||||
wsID := uuid.New()
|
||||
connID := uuid.New()
|
||||
|
||||
disconnect := fakeDisconnectEvent(wsID, "agent1", connID)
|
||||
connect := fakeConnectEvent(wsID, "agent1", connID)
|
||||
connect.Time = disconnect.Time.Add(-time.Second)
|
||||
|
||||
b.addToBatch(disconnect)
|
||||
b.addToBatch(connect)
|
||||
|
||||
require.Equal(t, 1, b.batchLen())
|
||||
key := connID
|
||||
require.Equal(t, disconnect.ID, b.dedupedBatch[key].ID)
|
||||
})
|
||||
|
||||
t.Run("SameStatusKeepsLater", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &DBBatcher{
|
||||
maxBatchSize: 100,
|
||||
dedupedBatch: make(map[uuid.UUID]batchEntry),
|
||||
}
|
||||
|
||||
wsID := uuid.New()
|
||||
connID := uuid.New()
|
||||
|
||||
early := fakeConnectEvent(wsID, "agent1", connID)
|
||||
early.Time = time.Now()
|
||||
late := fakeConnectEvent(wsID, "agent1", connID)
|
||||
late.Time = early.Time.Add(time.Second)
|
||||
|
||||
b.addToBatch(early)
|
||||
b.addToBatch(late)
|
||||
|
||||
require.Equal(t, 1, b.batchLen())
|
||||
key := connID
|
||||
require.Equal(t, late.ID, b.dedupedBatch[key].ID)
|
||||
})
|
||||
|
||||
t.Run("NullConnIDsNeverDedup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &DBBatcher{
|
||||
maxBatchSize: 100,
|
||||
dedupedBatch: make(map[uuid.UUID]batchEntry),
|
||||
}
|
||||
|
||||
evt1 := fakeNullConnIDEvent()
|
||||
evt2 := fakeNullConnIDEvent()
|
||||
evt2.WorkspaceID = evt1.WorkspaceID
|
||||
evt2.AgentName = evt1.AgentName
|
||||
|
||||
b.addToBatch(evt1)
|
||||
b.addToBatch(evt2)
|
||||
|
||||
require.Equal(t, 2, b.batchLen())
|
||||
require.Len(t, b.nullConnIDBatch, 2)
|
||||
require.Empty(t, b.dedupedBatch)
|
||||
})
|
||||
|
||||
t.Run("MixedNullAndNonNull", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &DBBatcher{
|
||||
maxBatchSize: 100,
|
||||
dedupedBatch: make(map[uuid.UUID]batchEntry),
|
||||
}
|
||||
|
||||
wsID := uuid.New()
|
||||
regular := fakeConnectEvent(wsID, "agent1", uuid.New())
|
||||
nullEvt := fakeNullConnIDEvent()
|
||||
nullEvt.WorkspaceID = wsID
|
||||
nullEvt.AgentName = "agent1"
|
||||
|
||||
b.addToBatch(regular)
|
||||
b.addToBatch(nullEvt)
|
||||
|
||||
require.Equal(t, 2, b.batchLen())
|
||||
require.Len(t, b.dedupedBatch, 1)
|
||||
require.Len(t, b.nullConnIDBatch, 1)
|
||||
})
|
||||
|
||||
t.Run("StandaloneDisconnectUsesTimeAsConnectTime", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &DBBatcher{
|
||||
maxBatchSize: 100,
|
||||
dedupedBatch: make(map[uuid.UUID]batchEntry),
|
||||
}
|
||||
|
||||
connID := uuid.New()
|
||||
disconnect := fakeDisconnectEvent(uuid.New(), "agent1", connID)
|
||||
|
||||
b.addToBatch(disconnect)
|
||||
|
||||
got := b.dedupedBatch[connID]
|
||||
// A standalone disconnect must not leave connectTime as
|
||||
// zero — that would insert a year-0001 connect_time in
|
||||
// the DB. It should use the disconnect's own timestamp,
|
||||
// matching the single-row UpsertConnectionLog behavior.
|
||||
require.False(t, got.connectTime.IsZero(),
|
||||
"standalone disconnect must have non-zero connectTime")
|
||||
require.Equal(t, disconnect.Time, got.connectTime)
|
||||
require.Equal(t, disconnect.Time, got.disconnectTime)
|
||||
})
|
||||
|
||||
t.Run("DuplicateDisconnectsPreserveConnectTime", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := &DBBatcher{
|
||||
maxBatchSize: 100,
|
||||
dedupedBatch: make(map[uuid.UUID]batchEntry),
|
||||
}
|
||||
|
||||
wsID := uuid.New()
|
||||
connID := uuid.New()
|
||||
|
||||
connect := fakeConnectEvent(wsID, "agent1", connID)
|
||||
disconnect1 := fakeDisconnectEvent(wsID, "agent1", connID)
|
||||
disconnect2 := fakeDisconnectEvent(wsID, "agent1", connID)
|
||||
disconnect2.Time = disconnect1.Time.Add(time.Second)
|
||||
|
||||
b.addToBatch(connect)
|
||||
b.addToBatch(disconnect1)
|
||||
b.addToBatch(disconnect2)
|
||||
|
||||
require.Equal(t, 1, b.batchLen())
|
||||
got := b.dedupedBatch[connID]
|
||||
// The second disconnect should win (later event) but the
|
||||
// original connect_time from the connect event must be
|
||||
// preserved, not regressed to the disconnect's timestamp.
|
||||
require.Equal(t, disconnect2.ID, got.ID)
|
||||
require.Equal(t, connect.Time, got.connectTime,
|
||||
"connect_time must not regress to disconnect timestamp")
|
||||
require.Equal(t, disconnect2.Time, got.disconnectTime)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_batcherFlush(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("DeduplicatesConnectDisconnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
|
||||
|
||||
wsID := uuid.New()
|
||||
connID := uuid.New()
|
||||
connect := fakeConnectEvent(wsID, "agent1", connID)
|
||||
disconnect := fakeDisconnectEvent(wsID, "agent1", connID)
|
||||
|
||||
// Expect a single batch with only the disconnect event.
|
||||
store.EXPECT().
|
||||
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
|
||||
expectedCount: 1,
|
||||
mustContainIDs: []uuid.UUID{disconnect.ID},
|
||||
mustNotContainIDs: []uuid.UUID{connect.ID},
|
||||
}).
|
||||
Return(nil).
|
||||
Times(1)
|
||||
|
||||
require.NoError(t, b.Upsert(ctx, connect))
|
||||
require.NoError(t, b.Upsert(ctx, disconnect))
|
||||
require.NoError(t, b.Close())
|
||||
})
|
||||
|
||||
t.Run("DoesNotDeduplicateNullConnIDs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
|
||||
|
||||
evt1 := fakeNullConnIDEvent()
|
||||
evt2 := fakeNullConnIDEvent()
|
||||
evt2.WorkspaceID = evt1.WorkspaceID
|
||||
evt2.AgentName = evt1.AgentName
|
||||
|
||||
store.EXPECT().
|
||||
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
|
||||
expectedCount: 2,
|
||||
mustContainIDs: []uuid.UUID{evt1.ID, evt2.ID},
|
||||
}).
|
||||
Return(nil).
|
||||
Times(1)
|
||||
|
||||
require.NoError(t, b.Upsert(ctx, evt1))
|
||||
require.NoError(t, b.Upsert(ctx, evt2))
|
||||
require.NoError(t, b.Close())
|
||||
})
|
||||
|
||||
t.Run("DoesNotDeduplicateDifferentConnectionIDs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
|
||||
|
||||
wsID := uuid.New()
|
||||
evt1 := fakeConnectEvent(wsID, "agent1", uuid.New())
|
||||
evt2 := fakeConnectEvent(wsID, "agent1", uuid.New())
|
||||
|
||||
store.EXPECT().
|
||||
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
|
||||
expectedCount: 2,
|
||||
mustContainIDs: []uuid.UUID{evt1.ID, evt2.ID},
|
||||
}).
|
||||
Return(nil).
|
||||
Times(1)
|
||||
|
||||
require.NoError(t, b.Upsert(ctx, evt1))
|
||||
require.NoError(t, b.Upsert(ctx, evt2))
|
||||
require.NoError(t, b.Close())
|
||||
})
|
||||
|
||||
t.Run("CloseFlushesMultipleEvents", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
|
||||
|
||||
evt1 := fakeConnectEvent(uuid.New(), "agent1", uuid.New())
|
||||
evt2 := fakeConnectEvent(uuid.New(), "agent2", uuid.New())
|
||||
|
||||
store.EXPECT().
|
||||
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
|
||||
expectedCount: 2,
|
||||
mustContainIDs: []uuid.UUID{evt1.ID, evt2.ID},
|
||||
}).
|
||||
Return(nil).
|
||||
Times(1)
|
||||
|
||||
require.NoError(t, b.Upsert(ctx, evt1))
|
||||
require.NoError(t, b.Upsert(ctx, evt2))
|
||||
require.NoError(t, b.Close())
|
||||
})
|
||||
|
||||
t.Run("RetriesOnTransientFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
scheduledTrap := clock.Trap().TimerReset("connectionLogBatcher", "scheduledFlush")
|
||||
defer scheduledTrap.Close()
|
||||
|
||||
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
|
||||
|
||||
evt := fakeConnectEvent(uuid.New(), "agent1", uuid.New())
|
||||
|
||||
// First call (synchronous in flush) fails, then the
|
||||
// retry worker retries after the backoff and succeeds.
|
||||
gomock.InOrder(
|
||||
store.EXPECT().
|
||||
BatchUpsertConnectionLogs(gomock.Any(), gomock.Any()).
|
||||
Return(xerrors.New("transient error")).
|
||||
Times(1),
|
||||
store.EXPECT().
|
||||
BatchUpsertConnectionLogs(gomock.Any(), batchParamsMatcher{
|
||||
expectedCount: 1,
|
||||
mustContainIDs: []uuid.UUID{evt.ID},
|
||||
}).
|
||||
Return(nil).
|
||||
Times(1),
|
||||
)
|
||||
|
||||
require.NoError(t, b.Upsert(ctx, evt))
|
||||
|
||||
// Trigger a scheduled flush while the batcher is still
|
||||
// running. The synchronous write fails and queues to
|
||||
// retryCh. The retry worker picks it up after a real-
|
||||
// time 1s delay and succeeds.
|
||||
clock.Advance(defaultFlushInterval).MustWait(ctx)
|
||||
scheduledTrap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
// Wait for the retry to complete (real-time 1s delay).
|
||||
require.NoError(t, b.Close())
|
||||
})
|
||||
|
||||
t.Run("ShutdownDrainsRetryQueue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctrl := gomock.NewController(t)
|
||||
store := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
scheduledTrap := clock.Trap().TimerReset("connectionLogBatcher", "scheduledFlush")
|
||||
defer scheduledTrap.Close()
|
||||
|
||||
b := NewDBBatcher(ctx, store, log, WithClock(clock), WithBatchSize(100))
|
||||
|
||||
evt := fakeConnectEvent(uuid.New(), "agent1", uuid.New())
|
||||
|
||||
// Track all successfully written IDs.
|
||||
var writtenIDs []uuid.UUID
|
||||
var mu sync.Mutex
|
||||
firstCall := true
|
||||
store.EXPECT().
|
||||
BatchUpsertConnectionLogs(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, p database.BatchUpsertConnectionLogsParams) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// First call (synchronous flush) fails, queueing
|
||||
// the batch for retry.
|
||||
if firstCall {
|
||||
firstCall = false
|
||||
return xerrors.New("transient error")
|
||||
}
|
||||
// Drain/retry attempts succeed.
|
||||
writtenIDs = append(writtenIDs, p.ID...)
|
||||
return nil
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
// Send event and trigger flush — fails, queues.
|
||||
require.NoError(t, b.Upsert(ctx, evt))
|
||||
clock.Advance(defaultFlushInterval).MustWait(ctx)
|
||||
scheduledTrap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
// Close triggers shutdown. The retry worker drains
|
||||
// retryCh and writes the batch via writeBatch.
|
||||
require.NoError(t, b.Close())
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.Contains(t, writtenIDs, evt.ID,
|
||||
"event should be written during shutdown drain")
|
||||
})
|
||||
}
|
||||
|
||||
// batchParamsMatcher validates BatchUpsertConnectionLogsParams by
|
||||
// checking count and specific IDs.
|
||||
type batchParamsMatcher struct {
|
||||
expectedCount int
|
||||
mustContainIDs []uuid.UUID
|
||||
mustNotContainIDs []uuid.UUID
|
||||
}
|
||||
|
||||
func (m batchParamsMatcher) Matches(x interface{}) bool {
|
||||
params, ok := x.(database.BatchUpsertConnectionLogsParams)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if m.expectedCount > 0 && len(params.ID) != m.expectedCount {
|
||||
return false
|
||||
}
|
||||
idSet := make(map[uuid.UUID]struct{}, len(params.ID))
|
||||
for _, id := range params.ID {
|
||||
idSet[id] = struct{}{}
|
||||
}
|
||||
for _, id := range m.mustContainIDs {
|
||||
if _, ok := idSet[id]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, id := range m.mustNotContainIDs {
|
||||
if _, ok := idSet[id]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (batchParamsMatcher) String() string {
|
||||
return "batch upsert params matcher"
|
||||
}
|
||||
|
||||
func fakeConnectEvent(workspaceID uuid.UUID, agentName string, connectionID uuid.UUID) database.UpsertConnectionLogParams {
|
||||
return database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: time.Now(),
|
||||
OrganizationID: uuid.New(),
|
||||
WorkspaceOwnerID: uuid.New(),
|
||||
WorkspaceID: workspaceID,
|
||||
WorkspaceName: "test-workspace",
|
||||
AgentName: agentName,
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
}
|
||||
}
|
||||
|
||||
func fakeDisconnectEvent(workspaceID uuid.UUID, agentName string, connectionID uuid.UUID) database.UpsertConnectionLogParams {
|
||||
return database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: time.Now().Add(time.Second),
|
||||
OrganizationID: uuid.New(),
|
||||
WorkspaceOwnerID: uuid.New(),
|
||||
WorkspaceID: workspaceID,
|
||||
WorkspaceName: "test-workspace",
|
||||
AgentName: agentName,
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connectionID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusDisconnected,
|
||||
Code: sql.NullInt32{Int32: 0, Valid: true},
|
||||
DisconnectReason: sql.NullString{String: "normal", Valid: true},
|
||||
}
|
||||
}
|
||||
|
||||
func fakeNullConnIDEvent() database.UpsertConnectionLogParams {
|
||||
return database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: time.Now(),
|
||||
OrganizationID: uuid.New(),
|
||||
WorkspaceOwnerID: uuid.New(),
|
||||
WorkspaceID: uuid.New(),
|
||||
WorkspaceName: "test-workspace",
|
||||
AgentName: "test-agent",
|
||||
Type: database.ConnectionTypeWorkspaceApp,
|
||||
ConnectionID: uuid.NullUUID{},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,371 @@
|
||||
package connectionlog_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/connectionlog"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func createWorkspace(t *testing.T, db database.Store) database.WorkspaceTable {
|
||||
t.Helper()
|
||||
u := dbgen.User(t, db, database.User{})
|
||||
o := dbgen.Organization(t, db, database.Organization{})
|
||||
tpl := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: o.ID,
|
||||
CreatedBy: u.ID,
|
||||
})
|
||||
return dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
ID: uuid.New(),
|
||||
OwnerID: u.ID,
|
||||
OrganizationID: o.ID,
|
||||
AutomaticUpdates: database.AutomaticUpdatesNever,
|
||||
TemplateID: tpl.ID,
|
||||
})
|
||||
}
|
||||
|
||||
func testIP() pqtype.Inet {
|
||||
return pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
||||
},
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBBackendIntegration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SingleConnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
ws := createWorkspace(t, db)
|
||||
|
||||
//nolint:gocritic // Test needs system context for the batcher.
|
||||
backend := connectionlog.NewDBBatcher(
|
||||
dbauthz.AsConnectionLogger(ctx), db, log,
|
||||
connectionlog.WithClock(clock),
|
||||
connectionlog.WithBatchSize(100),
|
||||
)
|
||||
|
||||
connID := uuid.New()
|
||||
connectTime := dbtime.Now()
|
||||
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: connectTime,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: "main",
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
IP: testIP(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
|
||||
LimitOpt: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
require.Equal(t, connID, rows[0].ConnectionLog.ConnectionID.UUID)
|
||||
require.False(t, rows[0].ConnectionLog.DisconnectTime.Valid)
|
||||
})
|
||||
|
||||
t.Run("ConnectThenDisconnectSeparateBatches", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
ws := createWorkspace(t, db)
|
||||
|
||||
connID := uuid.New()
|
||||
connectTime := dbtime.Now()
|
||||
|
||||
// First batcher: insert connect, close to flush.
|
||||
//nolint:gocritic // Test needs system context for the batcher.
|
||||
b1 := connectionlog.NewDBBatcher(
|
||||
dbauthz.AsConnectionLogger(ctx), db, log,
|
||||
connectionlog.WithClock(clock),
|
||||
connectionlog.WithBatchSize(100),
|
||||
)
|
||||
err := b1.Upsert(ctx, database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: connectTime,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: "main",
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
IP: testIP(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, b1.Close())
|
||||
|
||||
// Second batcher: insert disconnect, close to flush.
|
||||
//nolint:gocritic // Test needs system context for the batcher.
|
||||
b2 := connectionlog.NewDBBatcher(
|
||||
dbauthz.AsConnectionLogger(ctx), db, log,
|
||||
connectionlog.WithClock(clock),
|
||||
connectionlog.WithBatchSize(100),
|
||||
)
|
||||
disconnectTime := connectTime.Add(5 * time.Second)
|
||||
err = b2.Upsert(ctx, database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: disconnectTime,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: "main",
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusDisconnected,
|
||||
Code: sql.NullInt32{Int32: 0, Valid: true},
|
||||
DisconnectReason: sql.NullString{String: "client left", Valid: true},
|
||||
IP: testIP(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, b2.Close())
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
|
||||
LimitOpt: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1, "connect+disconnect should produce one row")
|
||||
require.True(t, rows[0].ConnectionLog.DisconnectTime.Valid)
|
||||
require.Equal(t, "client left", rows[0].ConnectionLog.DisconnectReason.String)
|
||||
})
|
||||
|
||||
t.Run("ConnectAndDisconnectSameBatch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
ws := createWorkspace(t, db)
|
||||
|
||||
//nolint:gocritic // Test needs system context for the batcher.
|
||||
backend := connectionlog.NewDBBatcher(
|
||||
dbauthz.AsConnectionLogger(ctx), db, log,
|
||||
connectionlog.WithClock(clock),
|
||||
connectionlog.WithBatchSize(100),
|
||||
)
|
||||
|
||||
connID := uuid.New()
|
||||
connectTime := dbtime.Now()
|
||||
disconnectTime := connectTime.Add(time.Second)
|
||||
|
||||
// Both events in the same batch window.
|
||||
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: connectTime,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: "main",
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
IP: testIP(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = backend.Upsert(ctx, database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: disconnectTime,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: "main",
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: connID, Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusDisconnected,
|
||||
Code: sql.NullInt32{Int32: 0, Valid: true},
|
||||
DisconnectReason: sql.NullString{String: "done", Valid: true},
|
||||
IP: testIP(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close drains channel and flushes — dedup keeps disconnect.
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
|
||||
LimitOpt: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
require.True(t, rows[0].ConnectionLog.DisconnectTime.Valid)
|
||||
require.Equal(t, "done", rows[0].ConnectionLog.DisconnectReason.String)
|
||||
})
|
||||
|
||||
t.Run("MultipleIndependentConnections", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
ws := createWorkspace(t, db)
|
||||
|
||||
//nolint:gocritic // Test needs system context for the batcher.
|
||||
backend := connectionlog.NewDBBatcher(
|
||||
dbauthz.AsConnectionLogger(ctx), db, log,
|
||||
connectionlog.WithClock(clock),
|
||||
connectionlog.WithBatchSize(100),
|
||||
)
|
||||
|
||||
now := dbtime.Now()
|
||||
for i := 0; i < 5; i++ {
|
||||
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: now,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: "main",
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
IP: testIP(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
|
||||
LimitOpt: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 5)
|
||||
})
|
||||
|
||||
t.Run("NullConnectionIDWebEvents", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
ws := createWorkspace(t, db)
|
||||
|
||||
//nolint:gocritic // Test needs system context for the batcher.
|
||||
backend := connectionlog.NewDBBatcher(
|
||||
dbauthz.AsConnectionLogger(ctx), db, log,
|
||||
connectionlog.WithClock(clock),
|
||||
connectionlog.WithBatchSize(100),
|
||||
)
|
||||
|
||||
now := dbtime.Now()
|
||||
for i := 0; i < 2; i++ {
|
||||
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: now,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: "main",
|
||||
Type: database.ConnectionTypeWorkspaceApp,
|
||||
ConnectionID: uuid.NullUUID{},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
IP: testIP(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
|
||||
LimitOpt: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 2, "null connection_id events should not be deduplicated")
|
||||
})
|
||||
|
||||
t.Run("CloseFlushesToDB", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
ws := createWorkspace(t, db)
|
||||
|
||||
//nolint:gocritic // Test needs system context for the batcher.
|
||||
backend := connectionlog.NewDBBatcher(
|
||||
dbauthz.AsConnectionLogger(ctx), db, log,
|
||||
connectionlog.WithClock(clock),
|
||||
connectionlog.WithBatchSize(100),
|
||||
)
|
||||
|
||||
err := backend.Upsert(ctx, database.UpsertConnectionLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: dbtime.Now(),
|
||||
OrganizationID: ws.OrganizationID,
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: "main",
|
||||
Type: database.ConnectionTypeSsh,
|
||||
ConnectionID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
ConnectionStatus: database.ConnectionStatusConnected,
|
||||
IP: testIP(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close without advancing clock — final flush should write.
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.GetConnectionLogsOffset(ctx, database.GetConnectionLogsOffsetParams{
|
||||
LimitOpt: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
})
|
||||
}
|
||||
@@ -227,7 +227,7 @@ func TestConnectionLogs(t *testing.T) {
|
||||
Int32: 0,
|
||||
Valid: false,
|
||||
},
|
||||
Ip: pqtype.Inet{IPNet: net.IPNet{
|
||||
IP: pqtype.Inet{IPNet: net.IPNet{
|
||||
IP: net.ParseIP("192.168.0.1"),
|
||||
Mask: net.CIDRMask(8, 32),
|
||||
}, Valid: true},
|
||||
|
||||
@@ -784,7 +784,7 @@ func TestIssueSignedAppToken(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, connectionLogger.Contains(t, database.UpsertConnectionLogParams{
|
||||
Ip: parsedFakeClientIP,
|
||||
IP: parsedFakeClientIP,
|
||||
}))
|
||||
})
|
||||
|
||||
@@ -812,7 +812,7 @@ func TestIssueSignedAppToken(t *testing.T) {
|
||||
}
|
||||
|
||||
require.True(t, connectionLogger.Contains(t, database.UpsertConnectionLogParams{
|
||||
Ip: parsedFakeClientIP,
|
||||
IP: parsedFakeClientIP,
|
||||
}))
|
||||
})
|
||||
}
|
||||
@@ -1020,7 +1020,7 @@ func TestReconnectingPTYSignedToken(t *testing.T) {
|
||||
// validate it here.
|
||||
|
||||
require.True(t, connectionLogger.Contains(t, database.UpsertConnectionLogParams{
|
||||
Ip: pqtype.Inet{
|
||||
IP: pqtype.Inet{
|
||||
Valid: true, IPNet: net.IPNet{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Mask: net.CIDRMask(32, 32),
|
||||
|
||||
Reference in New Issue
Block a user