Compare commits
1 Commits
pubsub-buf
...
fix/nested
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
843d708470 |
@@ -774,7 +774,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
options.PrometheusRegistry.MustRegister(collectors.NewDBStatsCollector(sqlDB, ""))
|
||||
}
|
||||
|
||||
options.Database = database.New(sqlDB)
|
||||
options.Database = database.New(sqlDB, database.WithLogger(logger.Named("database")))
|
||||
ps, err := pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create pubsub: %w", err)
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// Store contains all queryable database functions.
|
||||
@@ -54,6 +56,12 @@ func WithSerialRetryCount(count int) func(*sqlQuerier) {
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogger(logger slog.Logger) func(*sqlQuerier) {
|
||||
return func(q *sqlQuerier) {
|
||||
q.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new database store using a SQL database connection.
|
||||
func New(sdb *sql.DB, opts ...func(*sqlQuerier)) Store {
|
||||
dbx := sqlx.NewDb(sdb, "postgres")
|
||||
@@ -117,9 +125,17 @@ type sqlQuerier struct {
|
||||
sdb *sqlx.DB
|
||||
db DBTX
|
||||
|
||||
// logger is used for critical logging when nested transactions
|
||||
// request stricter isolation levels than the outer transaction.
|
||||
logger slog.Logger
|
||||
|
||||
// serialRetryCount is the number of times to retry a transaction
|
||||
// if it fails with a serialization error.
|
||||
serialRetryCount int
|
||||
|
||||
// currentIsolation tracks the isolation level of the current
|
||||
// transaction, used to detect mismatched nested transactions.
|
||||
currentIsolation sql.IsolationLevel
|
||||
}
|
||||
|
||||
func (*sqlQuerier) Wrappers() []string {
|
||||
@@ -181,12 +197,25 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *TxOptions) error {
|
||||
return q.runTx(function, sqlOpts)
|
||||
}
|
||||
|
||||
// InTx performs database operations inside a transaction.
|
||||
// runTx performs database operations inside a transaction.
|
||||
func (q *sqlQuerier) runTx(function func(Store) error, txOpts *sql.TxOptions) error {
|
||||
if _, ok := q.db.(*sqlx.Tx); ok {
|
||||
// If the current inner "db" is already a transaction, we just reuse it.
|
||||
// We do not need to handle commit/rollback as the outer tx will handle
|
||||
// that.
|
||||
//
|
||||
// Check if the requested isolation level is stricter than the
|
||||
// current transaction's isolation level. If so, log a critical
|
||||
// error because the caller's correctness expectations cannot be
|
||||
// met by the weaker outer isolation level.
|
||||
if txOpts.Isolation > q.currentIsolation {
|
||||
q.logger.Critical(
|
||||
context.Background(),
|
||||
"nested transaction requested stricter isolation level than outer",
|
||||
slog.F("outer_isolation_level", q.currentIsolation.String()),
|
||||
slog.F("requested_isolation_level", txOpts.Isolation.String()),
|
||||
)
|
||||
}
|
||||
err := function(q)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("execute transaction: %w", err)
|
||||
@@ -207,7 +236,11 @@ func (q *sqlQuerier) runTx(function func(Store) error, txOpts *sql.TxOptions) er
|
||||
// couldn't roll back for some reason, extend returned error
|
||||
err = xerrors.Errorf("defer (%s): %w", rerr.Error(), err)
|
||||
}()
|
||||
err = function(&sqlQuerier{db: transaction})
|
||||
err = function(&sqlQuerier{
|
||||
db: transaction,
|
||||
logger: q.logger,
|
||||
currentIsolation: txOpts.Isolation,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("execute transaction: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,12 +3,14 @@ package database_test
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
@@ -94,3 +96,92 @@ func testSQLDB(t testing.TB) *sql.DB {
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestNestedInTxStricterIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
t.Run("CriticalLogOnStricterNested", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
sink := &capturedSink{}
|
||||
logger := slog.Make(sink)
|
||||
db := database.New(sqlDB, database.WithLogger(logger))
|
||||
|
||||
err := db.InTx(func(outer database.Store) error {
|
||||
return outer.InTx(func(_ database.Store) error {
|
||||
return nil
|
||||
}, &database.TxOptions{Isolation: sql.LevelSerializable})
|
||||
}, &database.TxOptions{Isolation: sql.LevelReadCommitted})
|
||||
require.NoError(t, err)
|
||||
|
||||
entries := sink.entries()
|
||||
require.Len(t, entries, 1, "expected exactly one critical log entry")
|
||||
require.Equal(t, slog.LevelCritical, entries[0].Level)
|
||||
require.Contains(t, entries[0].Message, "nested transaction requested stricter isolation level")
|
||||
})
|
||||
|
||||
t.Run("NoCriticalLogOnSameIsolation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
sink := &capturedSink{}
|
||||
logger := slog.Make(sink)
|
||||
db := database.New(sqlDB, database.WithLogger(logger))
|
||||
|
||||
err := db.InTx(func(outer database.Store) error {
|
||||
return outer.InTx(func(_ database.Store) error {
|
||||
return nil
|
||||
}, &database.TxOptions{Isolation: sql.LevelSerializable})
|
||||
}, &database.TxOptions{Isolation: sql.LevelSerializable})
|
||||
require.NoError(t, err)
|
||||
|
||||
entries := sink.entries()
|
||||
require.Empty(t, entries, "should not log when isolation levels match")
|
||||
})
|
||||
|
||||
t.Run("NoCriticalLogOnWeakerNested", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
sink := &capturedSink{}
|
||||
logger := slog.Make(sink)
|
||||
db := database.New(sqlDB, database.WithLogger(logger))
|
||||
|
||||
err := db.InTx(func(outer database.Store) error {
|
||||
return outer.InTx(func(_ database.Store) error {
|
||||
return nil
|
||||
}, &database.TxOptions{Isolation: sql.LevelReadCommitted})
|
||||
}, &database.TxOptions{Isolation: sql.LevelSerializable})
|
||||
require.NoError(t, err)
|
||||
|
||||
entries := sink.entries()
|
||||
require.Empty(t, entries, "should not log when nested requests weaker isolation")
|
||||
})
|
||||
}
|
||||
|
||||
// capturedSink is a slog.Sink that captures log entries for assertion
|
||||
// in tests.
|
||||
type capturedSink struct {
|
||||
mu sync.Mutex
|
||||
logs []slog.SinkEntry
|
||||
}
|
||||
|
||||
func (s *capturedSink) LogEntry(_ context.Context, e slog.SinkEntry) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.logs = append(s.logs, e)
|
||||
}
|
||||
|
||||
func (s *capturedSink) Sync() {}
|
||||
|
||||
func (s *capturedSink) entries() []slog.SinkEntry {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
dst := make([]slog.SinkEntry, len(s.logs))
|
||||
copy(dst, s.logs)
|
||||
return dst
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user