forked from ebhomengo/niki
1
0
Fork 0
niki/pkg/query_transaction/sql/querier.go

104 lines
1.8 KiB
Go
Raw Normal View History

package querier
import (
	"context"
	"database/sql"
	"errors"
	"sync"
	"sync/atomic"
)
// contextKey is a dedicated string type for context keys declared in this
// package, so they cannot collide with plain string keys used elsewhere.
type contextKey string
// QuerierContextKey is the context key under which
// GetQuerierFromContextOrNew looks for a *Querier.
const QuerierContextKey contextKey = "querier"
// conn is the handle a Querier operates on. It combines transaction control
// (Commit, Rollback, Begin, BeginTx) with query execution (QueryRow*, Exec*)
// so that both SQLDB and SQLTx can be stored in Querier.conn.
type conn interface {
// Commit and Rollback finish the current unit of work.
Commit() error
Rollback() error
// Begin and BeginTx start a transaction; the signatures mirror *sql.DB.
Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
// QueryRow and QueryRowContext run a query expected to return at most one row.
QueryRow(query string, args ...any) *sql.Row
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
// Exec and ExecContext run a statement that returns no rows.
Exec(query string, args ...any) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// Querier carries a database handle that is chosen on the first call to
// Continue: a transaction if Begin was called beforehand, otherwise the
// plain database handle passed to Continue.
type Querier struct {
// txRequested is set by Begin and read by Continue to decide whether a
// transaction must be opened.
txRequested atomic.Bool
// initOnce guards the one-time assignment of conn inside Continue.
initOnce sync.Once
// conn stays nil until Continue assigns it.
conn conn
}
// GetQuerierFromContextOrNew returns the *Querier stored in ctx under
// QuerierContextKey. When ctx holds no value for that key, holds a value of
// another type, or holds a nil *Querier, a fresh zero-value Querier is
// returned instead, so the result is never nil.
//
// The new Querier is not written back into ctx; callers that want to share it
// must store it themselves.
func GetQuerierFromContextOrNew(ctx context.Context) *Querier {
	// A typed nil pointer satisfies the type assertion, so it has to be
	// rejected explicitly; otherwise callers would receive nil and panic on
	// the first method call.
	if q, ok := ctx.Value(QuerierContextKey).(*Querier); ok && q != nil {
		return q
	}
	// The zero value is ready to use: no transaction requested, the once not
	// yet fired, and no connection bound.
	return &Querier{}
}
// Begin records that a transaction is wanted and hands back the same Querier
// so the call can be chained. The flag is consulted by Continue when it binds
// the connection.
func (q *Querier) Begin() *Querier {
	const wantTx = true
	q.txRequested.Store(wantTx)
	return q
}
func (q *Querier) Continue(ctx context.Context, conn *SQLDB) (*Querier, error) {
var iErr error
q.initOnce.Do(func() {
if q.txRequested.Load() {
tx, bErr := conn.BeginTx(ctx, nil)
if bErr != nil {
iErr = bErr
return
}
q.conn = &SQLTx{tx}
} else {
2024-07-17 11:04:42 +00:00
q.conn = conn
}
})
return q, iErr
}
// Commit commits the work done through the bound connection by delegating to
// its Commit method.
//
// It returns an error, rather than panicking on a nil interface, when no
// connection has been bound yet (Continue was never called or it failed).
func (q *Querier) Commit() error {
	if q.conn == nil {
		return errors.New("querier: commit called before a connection was established")
	}
	return q.conn.Commit()
}
// Rollback discards the work done through the bound connection by delegating
// to its Rollback method.
//
// It returns an error, rather than panicking on a nil interface, when no
// connection has been bound yet (Continue was never called or it failed).
func (q *Querier) Rollback() error {
	if q.conn == nil {
		return errors.New("querier: rollback called before a connection was established")
	}
	return q.conn.Rollback()
}
// Conn exposes the connection currently bound to the Querier. The result is
// nil until Continue has bound one.
func (q *Querier) Conn() conn {
	bound := q.conn
	return bound
}
// SQLTx wraps *sql.Tx and adds the Begin and BeginTx methods that *sql.Tx
// lacks, so a transaction can be stored wherever the conn interface is
// expected.
type SQLTx struct {
	*sql.Tx
}

// Begin returns the transaction already in progress. A new transaction is
// not started: the caller is already inside one, so it simply joins it.
//
// Returning the wrapped transaction (instead of a zero-value sql.Tx, which
// cannot be used) keeps the returned handle usable.
func (tx *SQLTx) Begin() (*sql.Tx, error) {
	return tx.Tx, nil
}

// BeginTx behaves like Begin. The context and options are ignored because
// the wrapped transaction has already been started.
func (tx *SQLTx) BeginTx(_ context.Context, _ *sql.TxOptions) (*sql.Tx, error) {
	return tx.Tx, nil
}
// SQLDB wraps *sql.DB and adds the Commit and Rollback methods that *sql.DB
// lacks, so a plain database handle can be stored wherever the conn
// interface is expected.
type SQLDB struct {
	*sql.DB
}

// Commit does nothing and always reports success: outside a transaction
// there is nothing to commit.
func (db *SQLDB) Commit() error { return nil }

// Rollback does nothing and always reports success: outside a transaction
// there is nothing to roll back.
func (db *SQLDB) Rollback() error { return nil }