forked from ebhomengo/niki
104 lines
1.8 KiB
Go
104 lines
1.8 KiB
Go
package querier
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"sync"
	"sync/atomic"
)
|
|
|
|
// contextKey is a package-private string type used for context keys, so keys
// defined here cannot collide with plain string keys set by other packages.
type contextKey string

// QuerierContextKey is the context key under which
// GetQuerierFromContextOrNew looks up a *Querier.
const QuerierContextKey contextKey = "querier"
|
|
|
|
// conn is the set of operations a Querier needs from its underlying handle.
// It combines transaction control (Commit, Rollback, Begin, BeginTx) with the
// single-row query and exec methods, so that a plain database handle and an
// open transaction can be stored in the same field.
type conn interface {
	Commit() error
	Rollback() error
	Begin() (*sql.Tx, error)
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
	QueryRow(query string, args ...any) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	Exec(query string, args ...any) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
|
|
|
|
type Querier struct {
|
|
txRequested atomic.Bool
|
|
initOnce sync.Once
|
|
conn conn
|
|
}
|
|
|
|
func GetQuerierFromContextOrNew(ctx context.Context) *Querier {
|
|
q, ok := ctx.Value(QuerierContextKey).(*Querier)
|
|
if !ok {
|
|
q = &Querier{
|
|
txRequested: atomic.Bool{},
|
|
initOnce: sync.Once{},
|
|
conn: nil,
|
|
}
|
|
}
|
|
|
|
return q
|
|
}
|
|
|
|
func (q *Querier) Begin() *Querier {
|
|
q.txRequested.Store(true)
|
|
|
|
return q
|
|
}
|
|
|
|
func (q *Querier) Continue(ctx context.Context, conn *SQLDB) (*Querier, error) {
|
|
var iErr error
|
|
q.initOnce.Do(func() {
|
|
if q.txRequested.Load() {
|
|
tx, bErr := conn.BeginTx(ctx, nil)
|
|
if bErr != nil {
|
|
iErr = bErr
|
|
|
|
return
|
|
}
|
|
q.conn = &SQLTx{tx}
|
|
} else {
|
|
q.conn = conn
|
|
}
|
|
})
|
|
|
|
return q, iErr
|
|
}
|
|
|
|
func (q *Querier) Commit() error {
|
|
return q.conn.Commit()
|
|
}
|
|
|
|
func (q *Querier) Rollback() error {
|
|
return q.conn.Rollback()
|
|
}
|
|
|
|
func (q *Querier) Conn() conn {
|
|
return q.conn
|
|
}
|
|
|
|
// SQLTx wraps *sql.Tx and adds Begin and BeginTx, which *sql.Tx does not
// provide, so that a transaction has the same method set as a database handle.
type SQLTx struct {
	*sql.Tx
}

// Begin does not start a nested transaction. It returns a pointer to a
// zero-value sql.Tx and a nil error.
//
// NOTE(review): the returned Tx is not attached to any database, so callers
// presumably must not run statements on it or commit it — confirm.
func (tx *SQLTx) Begin() (*sql.Tx, error) {
	return &sql.Tx{}, nil
}

// BeginTx ignores both arguments and, like Begin, returns a pointer to a
// zero-value sql.Tx and a nil error without starting a transaction.
func (tx *SQLTx) BeginTx(_ context.Context, _ *sql.TxOptions) (*sql.Tx, error) {
	return &sql.Tx{}, nil
}
|
|
|
|
// SQLDB wraps *sql.DB and adds Commit and Rollback, which *sql.DB does not
// provide, so that a plain database handle has the same method set as a
// transaction.
type SQLDB struct {
	*sql.DB
}

// Commit is a no-op that always returns nil; a plain database handle has no
// transaction to commit.
func (db *SQLDB) Commit() error {
	return nil
}

// Rollback is a no-op that always returns nil; a plain database handle has no
// transaction to roll back.
func (db *SQLDB) Rollback() error {
	return nil
}
|