Skip to content

Commit

Permalink
handle rollback error
Browse files Browse the repository at this point in the history
  • Loading branch information
yinloo-ola committed Dec 14, 2023
1 parent 273bc8a commit a18c33f
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 63 deletions.
35 changes: 20 additions & 15 deletions coredb/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,8 @@ func RowToStruct[T any](row *sql.Row) (result *T, err error) {
result = new(T)
data := StrutForScan(result)
err = row.Scan(data...)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
if err == sql.ErrNoRows {
err = nil
}
return
}
Expand All @@ -74,45 +71,53 @@ func StrutForScan(u any) (pointers []any) {
}

func RowsToStructSliceReflect(rows *sql.Rows, out any) (err error) {
if rows == nil {
return
}
sliceValue := reflect.ValueOf(out)
if sliceValue.Kind() != reflect.Ptr || sliceValue.IsNil() {
return &InvalidScanError{Type: sliceValue.Type()}
panic(&InvalidScanError{Type: sliceValue.Type()})
}
sliceValue = sliceValue.Elem()
if sliceValue.Kind() != reflect.Slice {
return &InvalidScanError{Type: reflect.TypeOf(out)}
panic(&InvalidScanError{Type: reflect.TypeOf(out)})
}
elementType := sliceValue.Type().Elem()
if elementType.Kind() != reflect.Ptr {
return &InvalidScanError{Type: reflect.TypeOf(out)}
panic(&InvalidScanError{Type: reflect.TypeOf(out)})
}
elementType = elementType.Elem()

var elements []reflect.Value
for rows.Next() {
v := reflect.New(elementType)
data := StrutForScan(v.Interface())
err = rows.Scan(data...)
if err != nil {
return
}
sliceValue.Set(reflect.Append(sliceValue, v))
elements = append(elements, v.Elem())
}

sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), len(elements), len(elements)))
for i, v := range elements {
sliceValue.Index(i).Set(v.Addr())
}

err = rows.Err()
return
}

func RowToStructReflect(row *sql.Row, v any) (err error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Pointer || rv.IsNil() {
return &InvalidScanError{reflect.TypeOf(v)}
panic(&InvalidScanError{reflect.TypeOf(v)})
}

data := StrutForScan(v)
err = row.Scan(data...)
if err != nil {
if err == sql.ErrNoRows {
return nil
}
return err
if err == sql.ErrNoRows {
return nil
}
return
}
31 changes: 18 additions & 13 deletions coredb/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"log"
)

// BeginTx returns a custom db.Tx based on opts. This method exists for flexibility.
Expand Down Expand Up @@ -144,30 +143,36 @@ func (t *TxProvider) acquireWithOpts(ctx context.Context, opts *sql.TxOptions) (
}

// TxWithOpts ...
func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, opts *sql.TxOptions) error {
tx, err := t.acquireWithOpts(ctx, opts)
func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, opts *sql.TxOptions) (err error) {
var trx *tx
trx, err = t.acquireWithOpts(ctx, opts)
if err != nil {
return err
}

defer func() {
//nolint:gocritic
if r := recover(); r != nil {
log.Printf("Recovering from panic in TxWithOpts error is: %v \n", r)
_ = tx.Rollback()
err, _ = r.(error)
_ = trx.Rollback()
var ok bool
err, ok = r.(error)
if !ok {
err = fmt.Errorf("%v", r)
}
} else if err != nil {
err = tx.Rollback()
errRollback := trx.Rollback()
if errors.Is(errRollback, sql.ErrTxDone) && ctx.Err() != nil {
errRollback = nil
}
if errRollback != nil {
err = fmt.Errorf("%v encountered. but rollback failed: %w", err, errRollback)
}
} else {
err = tx.Commit()
}

if ctx.Err() != nil && errors.Is(err, context.DeadlineExceeded) {
log.Printf("query response time exceeded the configured timeout")
err = trx.Commit()
}
}()

err = fn(tx)
err = fn(trx)

return err
}
Expand Down
102 changes: 67 additions & 35 deletions tests/tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,90 +4,122 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"testing"
"time"

_ "github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"

"github.com/olachat/gola/v2/coredb"
"github.com/olachat/gola/v2/golalib/testdata/worker"
)

func TestBeginTx(t *testing.T) {
as := assert.New(t)
func ExampleNewTxProvider() {

prov := coredb.NewTxProvider("newdb")
err := prov.Tx(context.Background(), func(tx coredb.TxContext) error {
_, err := tx.Exec("truncate table worker")
as.Nil(err)
panicOnErr(err)

var workers []*worker.Worker
err = tx.Find(&workers, "worker", "where id > ?", 0)
as.Nil(err)
as.Equal(0, len(workers))
panicOnErr(err)
mustEqual(0, len(workers))
fmt.Println("no of workers:", len(workers)) // uncomment to run test
// Output: no of workers: 0

_, err = tx.Exec("insert into worker (name,age) values (?, ?)", "peter", 18)
as.Nil(err)
panicOnErr(err)

_, err = tx.Exec("insert into worker (name,age) values (?, ?)", "john", 28)
as.Nil(err)
panicOnErr(err)
return err
})
as.Nil(err)
panicOnErr(err)

err = prov.Tx(context.Background(), func(tx coredb.TxContext) error {
var workers []*worker.Worker
err := tx.Find(&workers, "worker", "where id > ?", 0)
as.Nil(err)
as.Equal(2, len(workers))
as.Equal("peter", workers[0].GetName())
as.Equal(18, workers[0].GetAge())
as.Equal("john", workers[1].GetName())
as.Equal(28, workers[1].GetAge())
panicOnErr(err)
mustEqual(2, len(workers))
mustEqual("peter", workers[0].GetName())
mustEqual(18, workers[0].GetAge())
mustEqual("john", workers[1].GetName())
mustEqual(28, workers[1].GetAge())

var w worker.Worker
err = tx.FindOne(&w, "worker", "where id = ?", 1)
as.Nil(err)
as.Equal("peter", w.GetName())
as.Equal(18, w.GetAge())
panicOnErr(err)
mustEqual("peter", w.GetName())
mustEqual(18, w.GetAge())

r, err := tx.QueryInt("select count(1) from worker")
as.Nil(err)
as.Equal(2, r)
panicOnErr(err)
mustEqual(2, r)

var workers2 []*worker.Worker
err = tx.Query(&workers2, "select * from worker where id > ?", 0)
as.Nil(err)
as.Equal(2, len(workers2))
as.Equal("peter", workers2[0].GetName())
as.Equal(18, workers2[0].GetAge())
as.Equal("john", workers2[1].GetName())
as.Equal(28, workers2[1].GetAge())
panicOnErr(err)
mustEqual(2, len(workers2))
mustEqual("peter", workers2[0].GetName())
mustEqual(18, workers2[0].GetAge())
mustEqual("john", workers2[1].GetName())
mustEqual(28, workers2[1].GetAge())
return nil
})

prov.Tx(context.Background(), func(tx coredb.TxContext) error {
_, err := tx.Exec("insert into worker (name,age) values (?, ?)", "winson", 19)
as.Nil(err)
panicOnErr(err)

return errors.New("abort")
})

prov.Tx(context.Background(), func(tx coredb.TxContext) error {
var w []*worker.Worker
err := tx.Find(&w, "worker", "where id > ?", 0)
as.Nil(err)
as.Equal(2, len(w))
as.Equal("peter", w[0].GetName())
as.Equal(18, w[0].GetAge())
as.Equal("john", w[1].GetName())
as.Equal(28, w[1].GetAge())
panicOnErr(err)
mustEqual(2, len(w))
mustEqual("peter", w[0].GetName())
mustEqual(18, w[0].GetAge())
mustEqual("john", w[1].GetName())
mustEqual(28, w[1].GetAge())
return nil
})
as.Nil(err)
panicOnErr(err)

prov2 := coredb.NewTxProvider("newdb")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err = prov2.Tx(ctx, func(tx coredb.TxContext) error {
_, err := tx.Exec("insert into worker (name,age) values (?, ?)", "winson", 19)
if err != nil {
return err
}
var w []*worker.Worker
time.Sleep(10 * time.Millisecond)
err = tx.Find(&w, "worker", "where age = ?", 28)
if err != nil {
return err
}
return nil
})
if !errors.Is(err, context.DeadlineExceeded) {
panic(err)
}

}

func panicOnErr(err error) {
if err != nil {
panic(err)
}
}
func mustEqual(a, b interface{}) {
if !reflect.DeepEqual(a, b) {
panic(fmt.Sprintf("%v != %v", a, b))
}
}

func open() (db *sql.DB, err error) {
Expand Down

0 comments on commit a18c33f

Please sign in to comment.