// Package querier lets a unit of work lazily decide whether its statements
// run directly on a database handle or inside a single transaction, and
// share that decision through a context.Context.
package querier

import (
	"context"
	"database/sql"
	"errors"
	"sync"
	"sync/atomic"
)

type contextKey string

// QuerierContextKey is the context key under which a *Querier is looked up
// by GetQuerierFromContextOrNew.
const QuerierContextKey contextKey = "querier"

var (
	// ErrNotInitialized is returned by Querier.Commit and Querier.Rollback
	// when no connection is bound, i.e. Continue was never called or it
	// failed to open the requested transaction.
	ErrNotInitialized = errors.New("querier: connection not initialized; call Continue first")

	// ErrNestedTx is returned by SQLTx.Begin and SQLTx.BeginTx: the handle
	// is already a transaction and nested transactions are not supported.
	ErrNestedTx = errors.New("querier: nested transactions are not supported")
)

// conn is the set of operations a Querier exposes from its bound handle.
// *SQLDB (plain database handle) and *SQLTx (open transaction) satisfy it.
type conn interface {
	Commit() error
	Rollback() error
	Begin() (*sql.Tx, error)
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
	QueryRow(query string, args ...any) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	Exec(query string, args ...any) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

// Compile-time checks that both adapters implement conn.
var (
	_ conn = (*SQLDB)(nil)
	_ conn = (*SQLTx)(nil)
)

// Querier binds a unit of work to either a *SQLDB or a transaction opened
// on it. Optionally call Begin, then Continue to bind the connection; the
// binding is decided exactly once and never changes afterwards.
//
// Continue is safe for concurrent use (it is guarded by sync.Once). Commit,
// Rollback and Conn read the binding without synchronization, so they should
// only be called after Continue has returned.
type Querier struct {
	txRequested atomic.Bool
	initOnce    sync.Once
	// initErr is the outcome of the one-time binding, kept so that every
	// Continue call reports it rather than only the first one.
	initErr error
	conn    conn
}

// GetQuerierFromContextOrNew returns the *Querier stored in ctx under
// QuerierContextKey. If none is stored (or the stored pointer is nil), it
// returns a fresh, unbound Querier; the new Querier is not added to ctx.
func GetQuerierFromContextOrNew(ctx context.Context) *Querier {
	if q, ok := ctx.Value(QuerierContextKey).(*Querier); ok && q != nil {
		return q
	}
	return &Querier{}
}

// Begin requests that Continue bind q to a transaction rather than to the
// plain database handle. It only takes effect if called before the first
// Continue; once the binding is made, Begin has no effect. It returns q so
// calls can be chained.
func (q *Querier) Begin() *Querier {
	q.txRequested.Store(true)
	return q
}

// Continue binds q to a connection on its first call and returns q.
//
// On the first call, if Begin was called beforehand, a transaction is opened
// with db.BeginTx(ctx, nil) and q's statements run inside it; otherwise q
// uses db directly. Subsequent calls ignore ctx and db and return the error
// (possibly nil) from that first call, so a failed BeginTx keeps being
// reported instead of turning into a nil connection.
func (q *Querier) Continue(ctx context.Context, db *SQLDB) (*Querier, error) {
	q.initOnce.Do(func() {
		if !q.txRequested.Load() {
			q.conn = db
			return
		}
		tx, err := db.BeginTx(ctx, nil)
		if err != nil {
			q.initErr = err
			return
		}
		q.conn = &SQLTx{Tx: tx}
	})
	return q, q.initErr
}

// Commit commits the bound connection: the transaction when q is
// transactional, a no-op returning nil when q is bound to a *SQLDB.
// It returns ErrNotInitialized if no connection is bound.
func (q *Querier) Commit() error {
	if q.conn == nil {
		return ErrNotInitialized
	}
	return q.conn.Commit()
}

// Rollback rolls back the bound connection: the transaction when q is
// transactional, a no-op returning nil when q is bound to a *SQLDB.
// It returns ErrNotInitialized if no connection is bound.
func (q *Querier) Rollback() error {
	if q.conn == nil {
		return ErrNotInitialized
	}
	return q.conn.Rollback()
}

// Conn returns the bound connection, or nil if Continue has not bound one.
func (q *Querier) Conn() conn {
	return q.conn
}

// SQLTx adapts *sql.Tx to conn. Commit, Rollback, Exec, ExecContext,
// QueryRow and QueryRowContext come from the embedded *sql.Tx.
type SQLTx struct {
	*sql.Tx
}

// Begin always fails with ErrNestedTx because SQLTx is already a
// transaction. It returns a nil *sql.Tx rather than an unusable zero value.
func (tx *SQLTx) Begin() (*sql.Tx, error) {
	return nil, ErrNestedTx
}

// BeginTx always fails with ErrNestedTx; see Begin. ctx and opts are ignored.
func (tx *SQLTx) BeginTx(_ context.Context, _ *sql.TxOptions) (*sql.Tx, error) {
	return nil, ErrNestedTx
}

// SQLDB adapts *sql.DB to conn. Commit and Rollback are no-ops that return
// nil, so a non-transactional Querier can be finished the same way as a
// transactional one.
type SQLDB struct {
	*sql.DB
}

// Commit is a no-op and always returns nil.
func (db *SQLDB) Commit() error {
	return nil
}

// Rollback is a no-op and always returns nil.
func (db *SQLDB) Rollback() error {
	return nil
}