From 8d5204a0529b6cefba6a802e141f4b5d921c9680 Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Tue, 12 Dec 2023 17:19:13 +0800 Subject: [PATCH 1/8] deprecate properly --- coredb/engine.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/coredb/engine.go b/coredb/engine.go index 72b2743..954e4cd 100644 --- a/coredb/engine.go +++ b/coredb/engine.go @@ -43,6 +43,7 @@ func getDB(dbname string, mode DBMode) *sql.DB { } // FetchByPK returns a row of T type with given primary key value +// // Deprecated: use the function with context func FetchByPK[T any](dbname string, tableName string, pkName []string, val ...any) *T { sql := "WHERE `" + pkName[0] + "` = ?" @@ -54,6 +55,7 @@ func FetchByPK[T any](dbname string, tableName string, pkName []string, val ...a } // FetchByPKs returns rows of T type with given primary key values +// // Deprecated: use the function with context func FetchByPKs[T any](dbname string, tableName string, pkName string, vals []any) []*T { if len(vals) == 0 { @@ -71,6 +73,7 @@ func FetchByPKs[T any](dbname string, tableName string, pkName string, vals []an } // FetchByPKFromMaster returns a row of T type with given primary key value +// // Deprecated: use the function with context func FetchByPKFromMaster[T any](dbname string, tableName string, pkName []string, val ...any) *T { sql := "WHERE `" + pkName[0] + "` = ?" @@ -82,6 +85,7 @@ func FetchByPKFromMaster[T any](dbname string, tableName string, pkName []string } // FetchByPKsFromMaster returns rows of T type with given primary key values +// // Deprecated: use the function with context func FetchByPKsFromMaster[T any](dbname string, tableName string, pkName string, vals []any) []*T { if len(vals) == 0 { @@ -99,6 +103,7 @@ func FetchByPKsFromMaster[T any](dbname string, tableName string, pkName string, } // Exec given query with given db info & params +// // Deprecated: use the function with context func Exec(dbname string, query string, params ...any) (sql.Result, error) { mydb := getDB(dbname, DBModeWrite) @@ -106,6 +111,7 @@ func Exec(dbname string, query string, params ...any) (sql.Result, error) { } // FindOne returns a row from given table type with where query +// // Deprecated: use the function with context func FindOne[T any](dbname string, tableName string, where WhereQuery) *T { u := new(T) @@ -131,6 +137,7 @@ func FindOne[T any](dbname string, tableName string, where WhereQuery) *T { } // Find returns rows from given table type with where query +// // Deprecated: use the function with context func Find[T any](dbname string, tableName string, where WhereQuery) ([]*T, error) { columnsNames := GetColumnsNames[T]() @@ -142,6 +149,7 @@ func Find[T any](dbname string, tableName string, where WhereQuery) ([]*T, error } // FindOneFromMaster using master DB returns a row from given table type with where query +// // Deprecated: use the function with context func FindOneFromMaster[T any](dbname string, tableName string, where WhereQuery) *T { u := new(T) @@ -167,6 +175,7 @@ func FindOneFromMaster[T any](dbname string, tableName string, where WhereQuery) } // FindFromMaster using master DB returns rows from given table type with where query +// // Deprecated: use the function with context func FindFromMaster[T any](dbname string, tableName string, where WhereQuery) ([]*T, error) { columnsNames := GetColumnsNames[T]() @@ -178,6 +187,7 @@ func FindFromMaster[T any](dbname string, tableName string, where WhereQuery) ([ } // QueryInt single int result by query, handy for count(*) querys +// // Deprecated: use the function with context func QueryInt(dbname string, query string, params ...any) (result int, err error) { mydb := getDB(dbname, DBModeRead) @@ -186,6 +196,7 @@ func QueryInt(dbname string, query string, params ...any) (result int, err error } // QueryIntFromMaster single int result by query, handy for count(*) querys +// // Deprecated: use the function with context func QueryIntFromMaster(dbname string, query string, params ...any) (result int, err error) { mydb := getDB(dbname, DBModeReadFromWrite) @@ -194,6 +205,7 @@ func QueryIntFromMaster(dbname string, query string, params ...any) (result int, } // Query rows from given table type with where query & params +// // Deprecated: use the function with context func Query[T any](dbname string, query string, params ...any) (result []*T, err error) { mydb := getDB(dbname, DBModeRead) @@ -217,6 +229,7 @@ func Query[T any](dbname string, query string, params ...any) (result []*T, err } // Query rows from master DB from given table type with where query & params +// // Deprecated: use the function with context func QueryFromMaster[T any](dbname string, query string, params ...any) (result []*T, err error) { mydb := getDB(dbname, DBModeReadFromWrite) From d01ad804f9a361998e4a3d5304e7367bbc770bc1 Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Tue, 12 Dec 2023 21:50:45 +0800 Subject: [PATCH 2/8] support transaction --- coredb/engine.go | 50 ++- coredb/engine_ctx.go | 4 +- coredb/scan.go | 122 ++++++ coredb/tx.go | 158 +++++++ golalib/golalib_test.go | 1 + golalib/testdata/worker.sql | 5 + golalib/testdata/worker/worker.go | 485 ++++++++++++++++++++++ golalib/testdata/worker/worker_ctx.go | 258 ++++++++++++ golalib/testdata/worker/worker_idx.go | 105 +++++ golalib/testdata/worker/worker_idx_ctx.go | 28 ++ tests/tx_test.go | 114 +++++ tests/user_test.go | 9 + 12 files changed, 1329 insertions(+), 10 deletions(-) create mode 100644 coredb/scan.go create mode 100644 coredb/tx.go create mode 100644 golalib/testdata/worker.sql create mode 100644 golalib/testdata/worker/worker.go create mode 100644 golalib/testdata/worker/worker_ctx.go create mode 100644 golalib/testdata/worker/worker_idx.go create mode 100644 golalib/testdata/worker/worker_idx_ctx.go create mode 100644 tests/tx_test.go diff --git a/coredb/engine.go b/coredb/engine.go index 954e4cd..5ff8a5d 100644 --- a/coredb/engine.go +++ b/coredb/engine.go @@ -191,7 +191,7 @@ func FindFromMaster[T any](dbname string, tableName string, where WhereQuery) ([ // Deprecated: use the function with context func QueryInt(dbname string, query string, params ...any) (result int, err error) { mydb := getDB(dbname, DBModeRead) - mydb.QueryRow(query, params...).Scan(&result) + err = mydb.QueryRow(query, params...).Scan(&result) return } @@ -200,7 +200,7 @@ func QueryInt(dbname string, query string, params ...any) (result int, err error // Deprecated: use the function with context func QueryIntFromMaster(dbname string, query string, params ...any) (result int, err error) { mydb := getDB(dbname, DBModeReadFromWrite) - mydb.QueryRow(query, params...).Scan(&result) + err = mydb.QueryRow(query, params...).Scan(&result) return } @@ -257,7 +257,8 @@ func GetColumnsNames[T any]() (joinedColumnNames string) { var o *T t := reflect.TypeOf(o) typeColumnNamesLock.RLock() - joinedColumnNames, ok := typeColumnNames[t] + var ok bool + joinedColumnNames, ok = typeColumnNames[t] typeColumnNamesLock.RUnlock() if ok { return @@ -282,15 +283,48 @@ func GetColumnsNames[T any]() (joinedColumnNames string) { return } -// StrutForScan returns value pointers of given obj -func StrutForScan[T any](u *T) (pointers []any) { - val := reflect.ValueOf(u).Elem() - pointers = make([]any, 0, val.NumField()) +// GetColumnsNamesReflect returns column names joined by `,` of given type +func GetColumnsNamesReflect(o any) (joinedColumnNames string) { + t := reflect.TypeOf(o) + elemType := t.Elem() + switch t.Kind() { + case reflect.Ptr: + // 是指针,尝试获取其指向的元素类型 + if elemType.Kind() == reflect.Slice { + elemType = elemType.Elem().Elem() // 如果指针指向切片,获取切片元素的类型 + } + case reflect.Slice: + // 是切片,获取切片元素的类型 + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() // 切片元素是指针,获取其指向的类型 + } + default: + // 既不是指针也不是切片,返回错误 + panic("coredb: o is neither a pointer nor a slice") + } + + typeColumnNamesLock.RLock() + var ok bool + joinedColumnNames, ok = typeColumnNames[elemType] + typeColumnNamesLock.RUnlock() + if ok { + return + } + + var columnNames []string + val := reflect.New(elemType).Elem() for i := 0; i < val.NumField(); i++ { valueField := val.Field(i) if f, ok := valueField.Addr().Interface().(ColumnType); ok { - pointers = append(pointers, f.GetValPointer()) + columnNames = append(columnNames, "`"+f.GetColumnName()+"`") } } + + joinedColumnNames = strings.Join(columnNames, ",") + + typeColumnNamesLock.Lock() + typeColumnNames[t] = joinedColumnNames + typeColumnNamesLock.Unlock() + return } diff --git a/coredb/engine_ctx.go b/coredb/engine_ctx.go index 49d44ee..c0b6efd 100644 --- a/coredb/engine_ctx.go +++ b/coredb/engine_ctx.go @@ -129,14 +129,14 @@ func FindFromMasterCtx[T any](ctx context.Context, dbname string, tableName stri // QueryIntCtx single int result by query, handy for count(*) querys func QueryIntCtx(ctx context.Context, dbname string, query string, params ...any) (result int, err error) { mydb := getDB(dbname, DBModeRead) - mydb.QueryRowContext(ctx, query, params...).Scan(&result) + err = mydb.QueryRowContext(ctx, query, params...).Scan(&result) return } // QueryIntFromMasterCtx single int result by query, handy for count(*) querys func QueryIntFromMasterCtx(ctx context.Context, dbname string, query string, params ...any) (result int, err error) { mydb := getDB(dbname, DBModeReadFromWrite) - mydb.QueryRowContext(ctx, query, params...).Scan(&result) + err = mydb.QueryRowContext(ctx, query, params...).Scan(&result) return } diff --git a/coredb/scan.go b/coredb/scan.go new file mode 100644 index 0000000..d6ba429 --- /dev/null +++ b/coredb/scan.go @@ -0,0 +1,122 @@ +package coredb + +import ( + "database/sql" + "reflect" +) + +// An InvalidScanError describes an invalid argument passed to Scan. +type InvalidScanError struct { + Type reflect.Type +} + +func (e *InvalidScanError) Error() string { + if e.Type == nil { + return "coredb: target is nil" + } + + if e.Type.Kind() != reflect.Pointer { + return "coredb: target must be a non-nil pointer, got " + e.Type.String() + } + return "coredb: nil " + e.Type.String() + ")" +} + +// RowsToStructSlice converts the rows of a SQL query result into a slice of structs. +// +// It takes a pointer to a sql.Rows object as input. +// The function also uses a generic type T, which represents the type of the struct. +// +// The function returns a slice of pointers to T structs and an error. +func RowsToStructSlice[T any](rows *sql.Rows) (result []*T, err error) { + var u *T + for rows.Next() { + u = new(T) + data := StrutForScan(u) + err = rows.Scan(data...) + if err != nil { + return + } + result = append(result, u) + } + return +} + +// RowToStruct converts a database row into a struct. +// +// It takes a pointer to a sql.Row and returns a pointer to the converted struct and an error. +func RowToStruct[T any](row *sql.Row) (result *T, err error) { + result = new(T) + data := StrutForScan(result) + err = row.Scan(data...) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return +} + +// StrutForScan returns value pointers of given obj +func StrutForScan(u any) (pointers []any) { + val := reflect.ValueOf(u) + if val.Kind() != reflect.Pointer || val.IsNil() { + err := &InvalidScanError{reflect.TypeOf(u)} + panic(err) + } + + val = val.Elem() + + pointers = make([]any, 0, val.NumField()) + for i := 0; i < val.NumField(); i++ { + valueField := val.Field(i) + if f, ok := valueField.Addr().Interface().(ColumnType); ok { + pointers = append(pointers, f.GetValPointer()) + } + } + return +} + +func RowsToStructSliceReflect(rows *sql.Rows, out any) (err error) { + sliceValue := reflect.ValueOf(out) + if sliceValue.Kind() != reflect.Ptr || sliceValue.IsNil() { + return &InvalidScanError{Type: sliceValue.Type()} + } + sliceValue = sliceValue.Elem() + if sliceValue.Kind() != reflect.Slice { + return &InvalidScanError{Type: reflect.TypeOf(out)} + } + elementType := sliceValue.Type().Elem() + if elementType.Kind() != reflect.Ptr { + return &InvalidScanError{Type: reflect.TypeOf(out)} + } + elementType = elementType.Elem() + + for rows.Next() { + v := reflect.New(elementType) + data := StrutForScan(v.Interface()) + err = rows.Scan(data...) + if err != nil { + return + } + sliceValue.Set(reflect.Append(sliceValue, v)) + } + return +} + +func RowToStructReflect(row *sql.Row, v any) (err error) { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return &InvalidScanError{reflect.TypeOf(v)} + } + + data := StrutForScan(v) + err = row.Scan(data...) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err + } + return +} diff --git a/coredb/tx.go b/coredb/tx.go new file mode 100644 index 0000000..6d2f1db --- /dev/null +++ b/coredb/tx.go @@ -0,0 +1,158 @@ +package coredb + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log" +) + +// BeginTx returns a custom db.Tx based on opts. This method exists for flexibility. +// Make sure you call Commit or Rollback on the returned Tx. +// Refer to https://go.dev/doc/database/execute-transactions on how to use the returned Tx. +func BeginTx(ctx context.Context, dbname string, opts *sql.TxOptions) (tx *sql.Tx, err error) { + mydb := getDB(dbname, DBModeRead) + return mydb.BeginTx(ctx, opts) +} + +// DefaultTxOpts is package variable with default transaction level +var DefaultTxOpts = sql.TxOptions{ + Isolation: sql.LevelDefault, + ReadOnly: false, +} + +// TxContext interface for DAO operations with context. +type TxContext interface { + Exec(query string, args ...any) (sql.Result, error) + Query(results any, query string, params ...any) (err error) + QueryInt(query string, params ...any) (result int, err error) + FindOne(result any, tableName string, whereSQL string, params ...any) error + Find(results any, tableName string, whereSQL string, params ...any) error +} + +// tx represents transaction with context as inner object. +type tx struct { + ctx context.Context //nolint:containedctx + Tx *sql.Tx +} + +// Exec executes query with params. +func (t *tx) Exec(query string, params ...any) (sql.Result, error) { + return t.Tx.ExecContext(t.ctx, query, params...) +} + +// Query loads data from db. +func (t *tx) Query(results any, query string, params ...any) error { + rows, err := t.Tx.QueryContext(t.ctx, query, params...) + if err != nil { + return err + } + return RowsToStructSliceReflect(rows, results) +} + +func (t *tx) QueryInt(query string, params ...any) (result int, err error) { + err = t.Tx.QueryRowContext(t.ctx, query, params...).Scan(&result) + return +} + +func (t *tx) FindOne(result any, tableName string, whereSQL string, params ...any) error { + columnsNames := GetColumnsNamesReflect(result) + data := StrutForScan(result) + query := fmt.Sprintf("SELECT %s FROM `%s` %s", columnsNames, + tableName, whereSQL) + err2 := t.Tx.QueryRowContext(t.ctx, query, params...).Scan(data...) + + if err2 != nil { + // It's on purpose the hide the error + // But should re-consider later + if err2 == sql.ErrNoRows { + return nil + } + return err2 + } + + return nil +} + +func (t *tx) Find(results any, tableName string, whereSQL string, params ...any) error { + columnsNames := GetColumnsNamesReflect(results) + query := fmt.Sprintf("SELECT %s FROM `%s` %s", columnsNames, + tableName, whereSQL) + return t.Query(results, query, params...) +} + +// Commit this transaction. +func (t *tx) Commit() error { + return t.Tx.Commit() +} + +// Rollback cancel this transaction. +func (t *tx) Rollback() error { + return t.Tx.Rollback() +} + +// Connector for sql database. +type Connector interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// TxProvider ... +type TxProvider struct { + conn Connector +} + +// NewTxProvider ... +func NewTxProvider(dbname string) *TxProvider { + mydb := getDB(dbname, DBModeWrite) + return &TxProvider{ + conn: mydb, + } +} + +// acquireWithOpts transaction from db +func (t *TxProvider) acquireWithOpts(ctx context.Context, opts *sql.TxOptions) (*tx, error) { + trx, err := t.conn.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + return &tx{ + ctx: ctx, + Tx: trx, + }, nil +} + +// TxWithOpts ... +func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, opts *sql.TxOptions) error { + tx, err := t.acquireWithOpts(ctx, opts) + if err != nil { + return err + } + + defer func() { + //nolint:gocritic + if r := recover(); r != nil { + log.Printf("Recovering from panic in TxWithOpts error is: %v \n", r) + _ = tx.Rollback() + err, _ = r.(error) + } else if err != nil { + err = tx.Rollback() + } else { + err = tx.Commit() + } + + if ctx.Err() != nil && errors.Is(err, context.DeadlineExceeded) { + log.Printf("query response time exceeded the configured timeout") + } + }() + + err = fn(tx) + + return err +} + +// Tx runs fn in transaction. +func (t *TxProvider) Tx(ctx context.Context, fn func(TxContext) error) error { + return t.TxWithOpts(ctx, fn, &DefaultTxOpts) +} diff --git a/golalib/golalib_test.go b/golalib/golalib_test.go index b9274a5..7cc6701 100644 --- a/golalib/golalib_test.go +++ b/golalib/golalib_test.go @@ -31,6 +31,7 @@ var testTables = []string{ "gifts", "gifts_nn", "gifts_with_default", "gifts_nn_with_default", "wallet", + "worker", } var testDataPath = "testdata" + string(filepath.Separator) diff --git a/golalib/testdata/worker.sql b/golalib/testdata/worker.sql new file mode 100644 index 0000000..369c5ab --- /dev/null +++ b/golalib/testdata/worker.sql @@ -0,0 +1,5 @@ +CREATE TABLE worker ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + age INT NOT NULL +); diff --git a/golalib/testdata/worker/worker.go b/golalib/testdata/worker/worker.go new file mode 100644 index 0000000..e37f7d7 --- /dev/null +++ b/golalib/testdata/worker/worker.go @@ -0,0 +1,485 @@ +// Code generated by gola 0.1.1; DO NOT EDIT. + +package worker + +import ( + "database/sql" + "encoding/json" + "reflect" + "strings" + + "github.com/olachat/gola/v2/coredb" +) + +const DBName string = "testdata" +const TableName string = "worker" + +// Worker represents `worker` table +type Worker struct { + // int(11) + Id `json:"id"` + // varchar(255) + Name `json:"name"` + // int(11) + Age `json:"age"` +} + +type withPK interface { + GetId() int +} + +// FetchByPK returns a row from `worker` table with given primary key value +// +// Deprecated: use the function with context +func FetchByPK(val int) *Worker { + return coredb.FetchByPK[Worker](DBName, TableName, []string{"id"}, val) +} + +// FetchFieldsByPK returns a row with selected fields from worker table with given primary key value +// +// Deprecated: use the function with context +func FetchFieldsByPK[T any](val int) *T { + return coredb.FetchByPK[T](DBName, TableName, []string{"id"}, val) +} + +// FetchByPKs returns rows with from `worker` table with given primary key values +// +// Deprecated: use the function with context +func FetchByPKs(vals ...int) []*Worker { + pks := coredb.GetAnySlice(vals) + return coredb.FetchByPKs[Worker](DBName, TableName, "id", pks) +} + +// FetchFieldsByPKs returns rows with selected fields from `worker` table with given primary key values +// +// Deprecated: use the function with context +func FetchFieldsByPKs[T any](vals ...int) []*T { + pks := coredb.GetAnySlice(vals) + return coredb.FetchByPKs[T](DBName, TableName, "id", pks) +} + +// FindOne returns a row from `worker` table with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func FindOne(whereSQL string, params ...any) *Worker { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindOne[Worker](DBName, TableName, w) +} + +// FindOneFields returns a row with selected fields from `worker` table with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func FindOneFields[T any](whereSQL string, params ...any) *T { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindOne[T](DBName, TableName, w) +} + +// Find returns rows from `worker` table with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func Find(whereSQL string, params ...any) ([]*Worker, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.Find[Worker](DBName, TableName, w) +} + +// FindFields returns rows with selected fields from `worker` table with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func FindFields[T any](whereSQL string, params ...any) ([]*T, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.Find[T](DBName, TableName, w) +} + +// Count returns select count(*) with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func Count(whereSQL string, params ...any) (int, error) { + return coredb.QueryInt(DBName, "SELECT COUNT(*) FROM `worker` "+whereSQL, params...) +} + +// FetchByPK returns a row from `worker` table with given primary key value +// +// Deprecated: use the function with context +func FetchByPKFromMaster(val int) *Worker { + return coredb.FetchByPKFromMaster[Worker](DBName, TableName, []string{"id"}, val) +} + +// FetchFieldsByPK returns a row with selected fields from worker table with given primary key value +// +// Deprecated: use the function with context +func FetchFieldsByPKFromMaster[T any](val int) *T { + return coredb.FetchByPKFromMaster[T](DBName, TableName, []string{"id"}, val) +} + +// FetchByPKs returns rows with from `worker` table with given primary key values +// +// Deprecated: use the function with context +func FetchByPKsFromMaster(vals ...int) []*Worker { + pks := coredb.GetAnySlice(vals) + return coredb.FetchByPKsFromMaster[Worker](DBName, TableName, "id", pks) +} + +// FetchFieldsByPKs returns rows with selected fields from `worker` table with given primary key values +// +// Deprecated: use the function with context +func FetchFieldsByPKsFromMaster[T any](vals ...int) []*T { + pks := coredb.GetAnySlice(vals) + return coredb.FetchByPKsFromMaster[T](DBName, TableName, "id", pks) +} + +// FindOne returns a row from `worker` table with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func FindOneFromMaster(whereSQL string, params ...any) *Worker { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindOneFromMaster[Worker](DBName, TableName, w) +} + +// FindOneFields returns a row with selected fields from `worker` table with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func FindOneFieldsFromMaster[T any](whereSQL string, params ...any) *T { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindOneFromMaster[T](DBName, TableName, w) +} + +// Find returns rows from `worker` table with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func FindFromMaster(whereSQL string, params ...any) ([]*Worker, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindFromMaster[Worker](DBName, TableName, w) +} + +// FindFields returns rows with selected fields from `worker` table with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func FindFieldsFromMaster[T any](whereSQL string, params ...any) ([]*T, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindFromMaster[T](DBName, TableName, w) +} + +// Count returns select count(*) with arbitary where query +// whereSQL must start with "where ..." +// +// Deprecated: use the function with context +func CountFromMaster(whereSQL string, params ...any) (int, error) { + return coredb.QueryIntFromMaster(DBName, "SELECT COUNT(*) FROM `worker` "+whereSQL, params...) +} + +// Column types + +// Id field +type Id struct { + isAssigned bool + val int +} + +func (c *Id) GetId() int { + return c.val +} + +func (c *Id) GetColumnName() string { + return "id" +} + +func (c *Id) GetValPointer() any { + return &c.val +} + +func (c *Id) getIdForDB() int { + return c.val +} + +func (c *Id) MarshalJSON() ([]byte, error) { + return json.Marshal(&c.val) +} + +func (c *Id) UnmarshalJSON(data []byte) error { + if err := json.Unmarshal(data, &c.val); err != nil { + return err + } + + return nil +} + +// Name field +type Name struct { + _updated bool + val string +} + +func (c *Name) GetName() string { + return c.val +} + +func (c *Name) SetName(val string) bool { + if c.val == val { + return false + } + c._updated = true + c.val = val + return true +} + +func (c *Name) IsUpdated() bool { + return c._updated +} + +func (c *Name) resetUpdated() { + c._updated = false +} + +func (c *Name) GetColumnName() string { + return "name" +} + +func (c *Name) GetValPointer() any { + return &c.val +} + +func (c *Name) getNameForDB() string { + return c.val +} + +func (c *Name) MarshalJSON() ([]byte, error) { + return json.Marshal(&c.val) +} + +func (c *Name) UnmarshalJSON(data []byte) error { + if err := json.Unmarshal(data, &c.val); err != nil { + return err + } + + return nil +} + +// Age field +type Age struct { + _updated bool + val int +} + +func (c *Age) GetAge() int { + return c.val +} + +func (c *Age) SetAge(val int) bool { + if c.val == val { + return false + } + c._updated = true + c.val = val + return true +} + +func (c *Age) IsUpdated() bool { + return c._updated +} + +func (c *Age) resetUpdated() { + c._updated = false +} + +func (c *Age) GetColumnName() string { + return "age" +} + +func (c *Age) GetValPointer() any { + return &c.val +} + +func (c *Age) getAgeForDB() int { + return c.val +} + +func (c *Age) MarshalJSON() ([]byte, error) { + return json.Marshal(&c.val) +} + +func (c *Age) UnmarshalJSON(data []byte) error { + if err := json.Unmarshal(data, &c.val); err != nil { + return err + } + + return nil +} + +// New return new *Worker with default values +func New() *Worker { + return &Worker{ + Id{}, + Name{}, + Age{}, + } +} + +// NewWithPK takes "id" +// and returns new *Worker with given PK +func NewWithPK(val int) *Worker { + c := &Worker{ + Id{}, + Name{}, + Age{}, + } + c.Id.val = val + c.Id.isAssigned = true + return c +} + +const insertWithoutPK string = "INSERT INTO `worker` (`name`, `age`) values (?, ?)" +const insertWithPK string = "INSERT INTO `worker` (`id`, `name`, `age`) values (?, ?, ?)" + +// Insert Worker struct to `worker` table +// Deprecated: use the function with context +func (c *Worker) Insert() error { + var result sql.Result + var err error + if c.Id.isAssigned { + result, err = coredb.Exec(DBName, insertWithPK, c.getIdForDB(), c.getNameForDB(), c.getAgeForDB()) + if err != nil { + return err + } + } else { + result, err = coredb.Exec(DBName, insertWithoutPK, c.getNameForDB(), c.getAgeForDB()) + if err != nil { + return err + } + + id, err := result.LastInsertId() + if err != nil { + return err + } + c.Id.val = int(id) + } + + affectedRows, err := result.RowsAffected() + if err != nil { + return err + } + if affectedRows == 0 { + return coredb.ErrAvoidInsert + } + + c.resetUpdated() + return nil +} + +func (c *Worker) resetUpdated() { + c.Name.resetUpdated() + c.Age.resetUpdated() +} + +// Update Worker struct in `worker` table +// Deprecated: use the function with context +func (obj *Worker) Update() (bool, error) { + var updatedFields []string + var params []any + if obj.Name.IsUpdated() { + updatedFields = append(updatedFields, "`name` = ?") + params = append(params, obj.getNameForDB()) + } + if obj.Age.IsUpdated() { + updatedFields = append(updatedFields, "`age` = ?") + params = append(params, obj.getAgeForDB()) + } + + if len(updatedFields) == 0 { + return false, nil + } + + sql := "UPDATE `worker` SET " + sql = sql + strings.Join(updatedFields, ",") + " WHERE `id` = ?" + params = append(params, obj.GetId()) + + result, err := coredb.Exec(DBName, sql, params...) + if err != nil { + return false, err + } + + affectedRows, err := result.RowsAffected() + if err != nil { + return false, err + } + if affectedRows == 0 { + return false, coredb.ErrAvoidUpdate + } + + obj.resetUpdated() + return true, nil +} + +// Update Worker struct with given fields in `worker` table +// Deprecated: use the function with context +func Update(obj withPK) (bool, error) { + var updatedFields []string + var params []any + var resetFuncs []func() + + val := reflect.ValueOf(obj).Elem() + updatedFields = make([]string, 0, val.NumField()) + params = make([]any, 0, val.NumField()) + + for i := 0; i < val.NumField(); i++ { + col := val.Field(i).Addr().Interface() + + switch c := col.(type) { + case *Name: + if c.IsUpdated() { + updatedFields = append(updatedFields, "`name` = ?") + params = append(params, c.getNameForDB()) + resetFuncs = append(resetFuncs, c.resetUpdated) + } + case *Age: + if c.IsUpdated() { + updatedFields = append(updatedFields, "`age` = ?") + params = append(params, c.getAgeForDB()) + resetFuncs = append(resetFuncs, c.resetUpdated) + } + } + } + + if len(updatedFields) == 0 { + return false, nil + } + + sql := "UPDATE `worker` SET " + sql = sql + strings.Join(updatedFields, ",") + " WHERE `id` = ?" + params = append(params, obj.GetId()) + + result, err := coredb.Exec(DBName, sql, params...) + if err != nil { + return false, err + } + + affectedRows, err := result.RowsAffected() + if err != nil { + return false, err + } + if affectedRows == 0 { + return false, coredb.ErrAvoidUpdate + } + + for _, f := range resetFuncs { + f() + } + return true, nil +} + +const deleteSql string = "DELETE FROM `worker` WHERE `id` = ?" + +// DeleteByPK delete a row from worker table with given primary key value +// Deprecated: use the function with context +func DeleteByPK(val int) error { + _, err := coredb.Exec(DBName, deleteSql, val) + return err +} diff --git a/golalib/testdata/worker/worker_ctx.go b/golalib/testdata/worker/worker_ctx.go new file mode 100644 index 0000000..231d2b1 --- /dev/null +++ b/golalib/testdata/worker/worker_ctx.go @@ -0,0 +1,258 @@ +// Code generated by gola 0.1.1; DO NOT EDIT. + +package worker + +import ( + "context" + "database/sql" + "reflect" + "strings" + + "github.com/olachat/gola/v2/coredb" +) + +// FetchByPK returns a row from `worker` table with given primary key value +func FetchByPKCtx(ctx context.Context, val int) (*Worker, error) { + return coredb.FetchByPKCtx[Worker](ctx, DBName, TableName, []string{"id"}, val) +} + +// FetchFieldsByPK returns a row with selected fields from worker table with given primary key value +func FetchFieldsByPKCtx[T any](ctx context.Context, val int) (*T, error) { + return coredb.FetchByPKCtx[T](ctx, DBName, TableName, []string{"id"}, val) +} + +// FetchByPKs returns rows with from `worker` table with given primary key values +func FetchByPKsCtx(ctx context.Context, vals ...int) ([]*Worker, error) { + pks := coredb.GetAnySlice(vals) + return coredb.FetchByPKsCtx[Worker](ctx, DBName, TableName, "id", pks) +} + +// FetchFieldsByPKs returns rows with selected fields from `worker` table with given primary key values +func FetchFieldsByPKsCtx[T any](ctx context.Context, vals ...int) ([]*T, error) { + pks := coredb.GetAnySlice(vals) + return coredb.FetchByPKsCtx[T](ctx, DBName, TableName, "id", pks) +} + +// FindOne returns a row from `worker` table with arbitary where query +// whereSQL must start with "where ..." +func FindOneCtx(ctx context.Context, whereSQL string, params ...any) (*Worker, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindOneCtx[Worker](ctx, DBName, TableName, w) +} + +// FindOneFields returns a row with selected fields from `worker` table with arbitary where query +// whereSQL must start with "where ..." +func FindOneFieldsCtx[T any](ctx context.Context, whereSQL string, params ...any) (*T, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindOneCtx[T](ctx, DBName, TableName, w) +} + +// Find returns rows from `worker` table with arbitary where query +// whereSQL must start with "where ..." +func FindCtx(ctx context.Context, whereSQL string, params ...any) ([]*Worker, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindCtx[Worker](ctx, DBName, TableName, w) +} + +// FindFields returns rows with selected fields from `worker` table with arbitary where query +// whereSQL must start with "where ..." +func FindFieldsCtx[T any](ctx context.Context, whereSQL string, params ...any) ([]*T, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindCtx[T](ctx, DBName, TableName, w) +} + +// Count returns select count(*) with arbitary where query +// whereSQL must start with "where ..." +func CountCtx(ctx context.Context, whereSQL string, params ...any) (int, error) { + return coredb.QueryIntCtx(ctx, DBName, "SELECT COUNT(*) FROM `worker` "+whereSQL, params...) +} + +// FetchByPK returns a row from `worker` table with given primary key value +func FetchByPKFromMasterCtx(ctx context.Context, val int) (*Worker, error) { + return coredb.FetchByPKFromMasterCtx[Worker](ctx, DBName, TableName, []string{"id"}, val) +} + +// FetchFieldsByPK returns a row with selected fields from worker table with given primary key value +func FetchFieldsByPKFromMasterCtx[T any](ctx context.Context, val int) (*T, error) { + return coredb.FetchByPKFromMasterCtx[T](ctx, DBName, TableName, []string{"id"}, val) +} + +// FetchByPKs returns rows with from `worker` table with given primary key values +func FetchByPKsFromMasterCtx(ctx context.Context, vals ...int) ([]*Worker, error) { + pks := coredb.GetAnySlice(vals) + return coredb.FetchByPKsFromMasterCtx[Worker](ctx, DBName, TableName, "id", pks) +} + +// FetchFieldsByPKs returns rows with selected fields from `worker` table with given primary key values +func FetchFieldsByPKsFromMasterCtx[T any](ctx context.Context, vals ...int) ([]*T, error) { + pks := coredb.GetAnySlice(vals) + return coredb.FetchByPKsFromMasterCtx[T](ctx, DBName, TableName, "id", pks) +} + +// FindOne returns a row from `worker` table with arbitary where query +// whereSQL must start with "where ..." +func FindOneFromMasterCtx(ctx context.Context, whereSQL string, params ...any) (*Worker, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindOneFromMasterCtx[Worker](ctx, DBName, TableName, w) +} + +// FindOneFields returns a row with selected fields from `worker` table with arbitary where query +// whereSQL must start with "where ..." +func FindOneFieldsFromMasterCtx[T any](ctx context.Context, whereSQL string, params ...any) (*T, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindOneFromMasterCtx[T](ctx, DBName, TableName, w) +} + +// Find returns rows from `worker` table with arbitary where query +// whereSQL must start with "where ..." +func FindFromMasterCtx(ctx context.Context, whereSQL string, params ...any) ([]*Worker, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindFromMasterCtx[Worker](ctx, DBName, TableName, w) +} + +// FindFields returns rows with selected fields from `worker` table with arbitary where query +// whereSQL must start with "where ..." +func FindFieldsFromMasterCtx[T any](ctx context.Context, whereSQL string, params ...any) ([]*T, error) { + w := coredb.NewWhere(whereSQL, params...) + return coredb.FindFromMasterCtx[T](ctx, DBName, TableName, w) +} + +// Count returns select count(*) with arbitary where query +// whereSQL must start with "where ..." +func CountFromMasterCtx(ctx context.Context, whereSQL string, params ...any) (int, error) { + return coredb.QueryIntFromMasterCtx(ctx, DBName, "SELECT COUNT(*) FROM `worker` "+whereSQL, params...) +} + +// Insert Worker struct to `worker` table +func (c *Worker) InsertCtx(ctx context.Context) error { + var result sql.Result + var err error + if c.Id.isAssigned { + result, err = coredb.ExecCtx(ctx, DBName, insertWithPK, c.getIdForDB(), c.getNameForDB(), c.getAgeForDB()) + if err != nil { + return err + } + } else { + result, err = coredb.ExecCtx(ctx, DBName, insertWithoutPK, c.getNameForDB(), c.getAgeForDB()) + if err != nil { + return err + } + + id, err := result.LastInsertId() + if err != nil { + return err + } + c.Id.val = int(id) + } + + affectedRows, err := result.RowsAffected() + if err != nil { + return err + } + if affectedRows == 0 { + return coredb.ErrAvoidInsert + } + + c.resetUpdated() + return nil +} + +// Update Worker struct in `worker` table +func (obj *Worker) UpdateCtx(ctx context.Context) (bool, error) { + var updatedFields []string + var params []any + if obj.Name.IsUpdated() { + updatedFields = append(updatedFields, "`name` = ?") + params = append(params, obj.getNameForDB()) + } + if obj.Age.IsUpdated() { + updatedFields = append(updatedFields, "`age` = ?") + params = append(params, obj.getAgeForDB()) + } + + if len(updatedFields) == 0 { + return false, nil + } + + sql := "UPDATE `worker` SET " + sql = sql + strings.Join(updatedFields, ",") + " WHERE `id` = ?" + params = append(params, obj.GetId()) + + result, err := coredb.ExecCtx(ctx, DBName, sql, params...) + if err != nil { + return false, err + } + + affectedRows, err := result.RowsAffected() + if err != nil { + return false, err + } + if affectedRows == 0 { + return false, coredb.ErrAvoidUpdate + } + + obj.resetUpdated() + return true, nil +} + +// Update Worker struct with given fields in `worker` table +func UpdateCtx(ctx context.Context, obj withPK) (bool, error) { + var updatedFields []string + var params []any + var resetFuncs []func() + + val := reflect.ValueOf(obj).Elem() + updatedFields = make([]string, 0, val.NumField()) + params = make([]any, 0, val.NumField()) + + for i := 0; i < val.NumField(); i++ { + col := val.Field(i).Addr().Interface() + + switch c := col.(type) { + case *Name: + if c.IsUpdated() { + updatedFields = append(updatedFields, "`name` = ?") + params = append(params, c.getNameForDB()) + resetFuncs = append(resetFuncs, c.resetUpdated) + } + case *Age: + if c.IsUpdated() { + updatedFields = append(updatedFields, "`age` = ?") + params = append(params, c.getAgeForDB()) + resetFuncs = append(resetFuncs, c.resetUpdated) + } + } + } + + if len(updatedFields) == 0 { + return false, nil + } + + sql := "UPDATE `worker` SET " + sql = sql + strings.Join(updatedFields, ",") + " WHERE `id` = ?" + params = append(params, obj.GetId()) + + result, err := coredb.ExecCtx(ctx, DBName, sql, params...) + if err != nil { + return false, err + } + + affectedRows, err := result.RowsAffected() + if err != nil { + return false, err + } + if affectedRows == 0 { + return false, coredb.ErrAvoidUpdate + } + + for _, f := range resetFuncs { + f() + } + return true, nil +} + +// DeleteByPK delete a row from worker table with given primary key value +func DeleteByPKCtx(ctx context.Context, val int) error { + _, err := coredb.ExecCtx(ctx, DBName, deleteSql, val) + return err +} diff --git a/golalib/testdata/worker/worker_idx.go b/golalib/testdata/worker/worker_idx.go new file mode 100644 index 0000000..c1d6627 --- /dev/null +++ b/golalib/testdata/worker/worker_idx.go @@ -0,0 +1,105 @@ +// Code generated by gola 0.1.1; DO NOT EDIT. + +package worker + +import ( + "fmt" + "strings" + + "github.com/olachat/gola/v2/coredb" +) + +type orderBy int + +type idxQuery[T any] struct { + whereSql string + limitSql string + orders []string + whereParams []any +} + +// order by enum & interface +const ( + IdAsc orderBy = iota + IdDesc + NameAsc + NameDesc + AgeAsc + AgeDesc +) + +func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { + q.orders = make([]string, len(args)) + for i, arg := range args { + switch arg { + case IdAsc: + q.orders[i] = "`id` asc" + case IdDesc: + q.orders[i] = "`id` desc" + case NameAsc: + q.orders[i] = "`name` asc" + case NameDesc: + q.orders[i] = "`name` desc" + case AgeAsc: + q.orders[i] = "`age` asc" + case AgeDesc: + q.orders[i] = "`age` desc" + } + } + return q +} + +func (q *idxQuery[T]) All() []*T { + result, _ := coredb.Find[T](DBName, TableName, q) + return result +} + +func (q *idxQuery[T]) Limit(offset, limit int) []*T { + q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) + result, _ := coredb.Find[T](DBName, TableName, q) + return result +} + +func (q *idxQuery[T]) AllFromMaster() []*T { + result, _ := coredb.FindFromMaster[T](DBName, TableName, q) + return result +} + +func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { + q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) + result, _ := coredb.FindFromMaster[T](DBName, TableName, q) + return result +} + +type order[T any] interface { + OrderBy(args ...orderBy) coredb.ReadQuery[T] +} + +type orderReadQuery[T any] interface { + order[T] + coredb.ReadQuery[T] +} + +type iQuery[T any] interface { + orderReadQuery[T] +} + +// Find methods + +// Select returns rows from `worker` table with index awared query +func Select() iQuery[Worker] { + return new(idxQuery[Worker]) +} + +// SelectFields returns rows with selected fields from `worker` table with index awared query +func SelectFields[T any]() iQuery[T] { + return new(idxQuery[T]) +} + +func (q *idxQuery[T]) GetWhere() (whereSql string, params []any) { + var orderSql string + if len(q.orders) > 0 { + orderSql = " order by " + strings.Join(q.orders, ",") + } + return q.whereSql + orderSql + q.limitSql, q.whereParams +} diff --git a/golalib/testdata/worker/worker_idx_ctx.go b/golalib/testdata/worker/worker_idx_ctx.go new file mode 100644 index 0000000..8aa533e --- /dev/null +++ b/golalib/testdata/worker/worker_idx_ctx.go @@ -0,0 +1,28 @@ +// Code generated by gola 0.1.1; DO NOT EDIT. + +package worker + +import ( + "context" + "fmt" + + "github.com/olachat/gola/v2/coredb" +) + +func (q *idxQuery[T]) AllCtx(ctx context.Context) ([]*T, error) { + return coredb.FindCtx[T](ctx, DBName, TableName, q) +} + +func (q *idxQuery[T]) LimitCtx(ctx context.Context, offset, limit int) ([]*T, error) { + q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) + return coredb.FindCtx[T](ctx, DBName, TableName, q) +} + +func (q *idxQuery[T]) AllFromMasterCtx(ctx context.Context) ([]*T, error) { + return coredb.FindFromMasterCtx[T](ctx, DBName, TableName, q) +} + +func (q *idxQuery[T]) LimitFromMasterCtx(ctx context.Context, offset, limit int) ([]*T, error) { + q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) + return coredb.FindFromMasterCtx[T](ctx, DBName, TableName, q) +} diff --git a/tests/tx_test.go b/tests/tx_test.go new file mode 100644 index 0000000..d2052c1 --- /dev/null +++ b/tests/tx_test.go @@ -0,0 +1,114 @@ +package tests + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + + "github.com/olachat/gola/v2/coredb" + "github.com/olachat/gola/v2/golalib/testdata/worker" +) + +func TestBeginTx(t *testing.T) { + as := assert.New(t) + + prov := coredb.NewTxProvider("newdb") + err := prov.Tx(context.Background(), func(tx coredb.TxContext) error { + _, err := tx.Exec("truncate table worker") + as.Nil(err) + + var workers []*worker.Worker + err = tx.Find(&workers, "worker", "where id > ?", 0) + as.Nil(err) + as.Equal(0, len(workers)) + + _, err = tx.Exec("insert into worker (name,age) values (?, ?)", "peter", 18) + as.Nil(err) + + _, err = tx.Exec("insert into worker (name,age) values (?, ?)", "john", 28) + as.Nil(err) + return err + }) + as.Nil(err) + + err = prov.Tx(context.Background(), func(tx coredb.TxContext) error { + var workers []*worker.Worker + err := tx.Find(&workers, "worker", "where id > ?", 0) + as.Nil(err) + as.Equal(2, len(workers)) + as.Equal("peter", workers[0].GetName()) + as.Equal(18, workers[0].GetAge()) + as.Equal("john", workers[1].GetName()) + as.Equal(28, workers[1].GetAge()) + + var w worker.Worker + err = tx.FindOne(&w, "worker", "where id = ?", 1) + as.Nil(err) + as.Equal("peter", w.GetName()) + as.Equal(18, w.GetAge()) + + r, err := tx.QueryInt("select count(1) from worker") + as.Nil(err) + as.Equal(2, r) + + var workers2 []*worker.Worker + err = tx.Query(&workers2, "select * from worker where id > ?", 0) + as.Nil(err) + as.Equal(2, len(workers2)) + as.Equal("peter", workers2[0].GetName()) + as.Equal(18, workers2[0].GetAge()) + as.Equal("john", workers2[1].GetName()) + as.Equal(28, workers2[1].GetAge()) + return nil + }) + + prov.Tx(context.Background(), func(tx coredb.TxContext) error { + _, err := tx.Exec("insert into worker (name,age) values (?, ?)", "winson", 19) + as.Nil(err) + + return errors.New("abort") + }) + + prov.Tx(context.Background(), func(tx coredb.TxContext) error { + var w []*worker.Worker + err := tx.Find(&w, "worker", "where id > ?", 0) + as.Nil(err) + as.Equal(2, len(w)) + as.Equal("peter", w[0].GetName()) + as.Equal(18, w[0].GetAge()) + as.Equal("john", w[1].GetName()) + as.Equal(28, w[1].GetAge()) + return nil + }) + as.Nil(err) + +} + +func open() (db *sql.DB, err error) { + dsn := "root:123456@tcp(127.0.0.1:3307)/newdb" + if !strings.Contains(dsn, "?parseTime=true") { + dsn += "?parseTime=true" + } + + maxIdle := 3.0 + + maxOpen := 50.0 + + maxLifetime := 30.0 + + db, err = sql.Open("mysql", dsn) + if err != nil { + return nil, err + } + + db.SetConnMaxIdleTime(time.Duration(maxIdle) * time.Second) + db.SetConnMaxLifetime(time.Duration(maxLifetime) * time.Second) + db.SetMaxOpenConns(int(maxOpen)) + return +} diff --git a/tests/user_test.go b/tests/user_test.go index 0fd7239..a820b5a 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -52,7 +52,16 @@ func init() { panic(err) } + realdb, err := open() + + if err != nil { + panic(err) + } + coredb.Setup(func(dbname string, mode coredb.DBMode) *sql.DB { + if dbname == "newdb" { + return realdb + } return db }) From 398d11c968ad2d38f2380ac5b67eb0549b92696b Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Thu, 14 Dec 2023 10:52:11 +0800 Subject: [PATCH 3/8] add deprecated msg to generated idx file --- golalib/testdata/account/account_idx.go | 4 ++++ golalib/testdata/blogs/blogs_idx.go | 4 ++++ golalib/testdata/gifts/gifts_idx.go | 4 ++++ golalib/testdata/gifts_nn/gifts_nn_idx.go | 4 ++++ .../gifts_nn_with_default/gifts_nn_with_default_idx.go | 4 ++++ golalib/testdata/gifts_with_default/gifts_with_default_idx.go | 4 ++++ golalib/testdata/profile/profile_idx.go | 4 ++++ golalib/testdata/room/room_idx.go | 4 ++++ .../testdata/song_user_favourites/song_user_favourites_idx.go | 4 ++++ golalib/testdata/songs/songs_idx.go | 4 ++++ golalib/testdata/users/users_idx.go | 4 ++++ golalib/testdata/wallet/wallet_idx.go | 4 ++++ golalib/testdata/worker/worker_idx.go | 4 ++++ ormtpl/01_struct_idx.gogo | 4 ++++ 14 files changed, 56 insertions(+) diff --git a/golalib/testdata/account/account_idx.go b/golalib/testdata/account/account_idx.go index 9225cbc..805f0a2 100644 --- a/golalib/testdata/account/account_idx.go +++ b/golalib/testdata/account/account_idx.go @@ -55,22 +55,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/blogs/blogs_idx.go b/golalib/testdata/blogs/blogs_idx.go index 8243ad4..5323ac7 100644 --- a/golalib/testdata/blogs/blogs_idx.go +++ b/golalib/testdata/blogs/blogs_idx.go @@ -91,22 +91,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/gifts/gifts_idx.go b/golalib/testdata/gifts/gifts_idx.go index 6c0d3fa..873958c 100644 --- a/golalib/testdata/gifts/gifts_idx.go +++ b/golalib/testdata/gifts/gifts_idx.go @@ -109,22 +109,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/gifts_nn/gifts_nn_idx.go b/golalib/testdata/gifts_nn/gifts_nn_idx.go index ef9ddd0..a5b5037 100644 --- a/golalib/testdata/gifts_nn/gifts_nn_idx.go +++ b/golalib/testdata/gifts_nn/gifts_nn_idx.go @@ -109,22 +109,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/gifts_nn_with_default/gifts_nn_with_default_idx.go b/golalib/testdata/gifts_nn_with_default/gifts_nn_with_default_idx.go index b55460c..ca9c6f5 100644 --- a/golalib/testdata/gifts_nn_with_default/gifts_nn_with_default_idx.go +++ b/golalib/testdata/gifts_nn_with_default/gifts_nn_with_default_idx.go @@ -115,22 +115,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/gifts_with_default/gifts_with_default_idx.go b/golalib/testdata/gifts_with_default/gifts_with_default_idx.go index c8921ff..c041912 100644 --- a/golalib/testdata/gifts_with_default/gifts_with_default_idx.go +++ b/golalib/testdata/gifts_with_default/gifts_with_default_idx.go @@ -115,22 +115,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/profile/profile_idx.go b/golalib/testdata/profile/profile_idx.go index d180565..c07f9c8 100644 --- a/golalib/testdata/profile/profile_idx.go +++ b/golalib/testdata/profile/profile_idx.go @@ -49,22 +49,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/room/room_idx.go b/golalib/testdata/room/room_idx.go index 8e6016c..7a5763e 100644 --- a/golalib/testdata/room/room_idx.go +++ b/golalib/testdata/room/room_idx.go @@ -61,22 +61,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/song_user_favourites/song_user_favourites_idx.go b/golalib/testdata/song_user_favourites/song_user_favourites_idx.go index 0c4f414..947866e 100644 --- a/golalib/testdata/song_user_favourites/song_user_favourites_idx.go +++ b/golalib/testdata/song_user_favourites/song_user_favourites_idx.go @@ -67,22 +67,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/songs/songs_idx.go b/golalib/testdata/songs/songs_idx.go index 1d27dc3..f9dd3be 100644 --- a/golalib/testdata/songs/songs_idx.go +++ b/golalib/testdata/songs/songs_idx.go @@ -73,22 +73,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/users/users_idx.go b/golalib/testdata/users/users_idx.go index 0a7ddb2..9aef931 100644 --- a/golalib/testdata/users/users_idx.go +++ b/golalib/testdata/users/users_idx.go @@ -103,22 +103,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/wallet/wallet_idx.go b/golalib/testdata/wallet/wallet_idx.go index 9a4095b..0e9a68b 100644 --- a/golalib/testdata/wallet/wallet_idx.go +++ b/golalib/testdata/wallet/wallet_idx.go @@ -55,22 +55,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/golalib/testdata/worker/worker_idx.go b/golalib/testdata/worker/worker_idx.go index c1d6627..b21fe5d 100644 --- a/golalib/testdata/worker/worker_idx.go +++ b/golalib/testdata/worker/worker_idx.go @@ -49,22 +49,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) diff --git a/ormtpl/01_struct_idx.gogo b/ormtpl/01_struct_idx.gogo index 91647c6..701ecdc 100644 --- a/ormtpl/01_struct_idx.gogo +++ b/ormtpl/01_struct_idx.gogo @@ -46,22 +46,26 @@ func (q *idxQuery[T]) OrderBy(args ...orderBy) coredb.ReadQuery[T] { return q } +// deprecated: use the function with context func (q *idxQuery[T]) All() []*T { result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) Limit(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.Find[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) AllFromMaster() []*T { result, _ := coredb.FindFromMaster[T](DBName, TableName, q) return result } +// deprecated: use the function with context func (q *idxQuery[T]) LimitFromMaster(offset, limit int) []*T { q.limitSql = fmt.Sprintf(" limit %d, %d", offset, limit) result, _ := coredb.FindFromMaster[T](DBName, TableName, q) From 39af31a224ce9ac527c7669155e618fb7338212f Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Thu, 14 Dec 2023 10:59:34 +0800 Subject: [PATCH 4/8] add comments --- coredb/tx.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/coredb/tx.go b/coredb/tx.go index 6d2f1db..f34ff95 100644 --- a/coredb/tx.go +++ b/coredb/tx.go @@ -24,10 +24,30 @@ var DefaultTxOpts = sql.TxOptions{ // TxContext interface for DAO operations with context. type TxContext interface { + // Exec executes a query without returning any rows. + // The args are for any placeholder parameters in the query. Exec(query string, args ...any) (sql.Result, error) + + // Query executes a SELECT query and scans the resulting rows into the provided 'results' destination. + // It accepts a SQL query and an optional list of parameters for placeholder substitution. + // NOTE: results must be a pointer to a slice of struct pointers. Query(results any, query string, params ...any) (err error) + + // QueryInt executes a SELECT query expected to return a single integer value. + // Commonly used for COUNT(*) operations or where the result is inherently an integer. + // Multiple params for query placeholders are supported. QueryInt(query string, params ...any) (result int, err error) + + // FindOne fetches a single record from the database and populates 'result'. + // It requires the name of the table, an optional WHERE clause ('whereSQL'), and + // parameters to substitute into the WHERE clause's placeholders. + // NOTE: result must be a non-nil pointer to a struct. FindOne(result any, tableName string, whereSQL string, params ...any) error + + // Find executes a SELECT query based on the given 'tableName' and 'whereSQL', + // placing all matching records into the 'results' slice. + // Parameters for the WHERE clause's placeholders can be passed with 'params'. + // NOTE: results must be a pointer to a slice of struct pointers. Find(results any, tableName string, whereSQL string, params ...any) error } From 273bc8a105553a35cd6cadf3ba41142a42dfd92d Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Thu, 14 Dec 2023 11:21:42 +0800 Subject: [PATCH 5/8] remove heavy checking --- coredb/scan.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/coredb/scan.go b/coredb/scan.go index d6ba429..1ffda63 100644 --- a/coredb/scan.go +++ b/coredb/scan.go @@ -28,6 +28,8 @@ func (e *InvalidScanError) Error() string { // // The function returns a slice of pointers to T structs and an error. func RowsToStructSlice[T any](rows *sql.Rows) (result []*T, err error) { + defer rows.Close() + var u *T for rows.Next() { u = new(T) @@ -38,6 +40,7 @@ func RowsToStructSlice[T any](rows *sql.Rows) (result []*T, err error) { } result = append(result, u) } + err = rows.Err() return } @@ -59,14 +62,7 @@ func RowToStruct[T any](row *sql.Row) (result *T, err error) { // StrutForScan returns value pointers of given obj func StrutForScan(u any) (pointers []any) { - val := reflect.ValueOf(u) - if val.Kind() != reflect.Pointer || val.IsNil() { - err := &InvalidScanError{reflect.TypeOf(u)} - panic(err) - } - - val = val.Elem() - + val := reflect.ValueOf(u).Elem() pointers = make([]any, 0, val.NumField()) for i := 0; i < val.NumField(); i++ { valueField := val.Field(i) From a18c33f19b3efdba9d41ebf9583b9de9db9d0204 Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Thu, 14 Dec 2023 13:20:57 +0800 Subject: [PATCH 6/8] handle rollback error --- coredb/scan.go | 35 +++++++++------- coredb/tx.go | 31 ++++++++------ tests/tx_test.go | 102 +++++++++++++++++++++++++++++++---------------- 3 files changed, 105 insertions(+), 63 deletions(-) diff --git a/coredb/scan.go b/coredb/scan.go index 1ffda63..28ca5cd 100644 --- a/coredb/scan.go +++ b/coredb/scan.go @@ -51,11 +51,8 @@ func RowToStruct[T any](row *sql.Row) (result *T, err error) { result = new(T) data := StrutForScan(result) err = row.Scan(data...) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err + if err == sql.ErrNoRows { + err = nil } return } @@ -74,20 +71,24 @@ func StrutForScan(u any) (pointers []any) { } func RowsToStructSliceReflect(rows *sql.Rows, out any) (err error) { + if rows == nil { + return + } sliceValue := reflect.ValueOf(out) if sliceValue.Kind() != reflect.Ptr || sliceValue.IsNil() { - return &InvalidScanError{Type: sliceValue.Type()} + panic(&InvalidScanError{Type: sliceValue.Type()}) } sliceValue = sliceValue.Elem() if sliceValue.Kind() != reflect.Slice { - return &InvalidScanError{Type: reflect.TypeOf(out)} + panic(&InvalidScanError{Type: reflect.TypeOf(out)}) } elementType := sliceValue.Type().Elem() if elementType.Kind() != reflect.Ptr { - return &InvalidScanError{Type: reflect.TypeOf(out)} + panic(&InvalidScanError{Type: reflect.TypeOf(out)}) } elementType = elementType.Elem() + var elements []reflect.Value for rows.Next() { v := reflect.New(elementType) data := StrutForScan(v.Interface()) @@ -95,24 +96,28 @@ func RowsToStructSliceReflect(rows *sql.Rows, out any) (err error) { if err != nil { return } - sliceValue.Set(reflect.Append(sliceValue, v)) + elements = append(elements, v.Elem()) + } + + sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), len(elements), len(elements))) + for i, v := range elements { + sliceValue.Index(i).Set(v.Addr()) } + + err = rows.Err() return } func RowToStructReflect(row *sql.Row, v any) (err error) { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Pointer || rv.IsNil() { - return &InvalidScanError{reflect.TypeOf(v)} + panic(&InvalidScanError{reflect.TypeOf(v)}) } data := StrutForScan(v) err = row.Scan(data...) - if err != nil { - if err == sql.ErrNoRows { - return nil - } - return err + if err == sql.ErrNoRows { + return nil } return } diff --git a/coredb/tx.go b/coredb/tx.go index f34ff95..59a344d 100644 --- a/coredb/tx.go +++ b/coredb/tx.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "fmt" - "log" ) // BeginTx returns a custom db.Tx based on opts. This method exists for flexibility. @@ -144,8 +143,9 @@ func (t *TxProvider) acquireWithOpts(ctx context.Context, opts *sql.TxOptions) ( } // TxWithOpts ... -func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, opts *sql.TxOptions) error { - tx, err := t.acquireWithOpts(ctx, opts) +func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, opts *sql.TxOptions) (err error) { + var trx *tx + trx, err = t.acquireWithOpts(ctx, opts) if err != nil { return err } @@ -153,21 +153,26 @@ func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, o defer func() { //nolint:gocritic if r := recover(); r != nil { - log.Printf("Recovering from panic in TxWithOpts error is: %v \n", r) - _ = tx.Rollback() - err, _ = r.(error) + _ = trx.Rollback() + var ok bool + err, ok = r.(error) + if !ok { + err = fmt.Errorf("%v", r) + } } else if err != nil { - err = tx.Rollback() + errRollback := trx.Rollback() + if errors.Is(errRollback, sql.ErrTxDone) && ctx.Err() != nil { + errRollback = nil + } + if errRollback != nil { + err = fmt.Errorf("%v encountered. but rollback failed: %w", err, errRollback) + } } else { - err = tx.Commit() - } - - if ctx.Err() != nil && errors.Is(err, context.DeadlineExceeded) { - log.Printf("query response time exceeded the configured timeout") + err = trx.Commit() } }() - err = fn(tx) + err = fn(trx) return err } diff --git a/tests/tx_test.go b/tests/tx_test.go index d2052c1..68db8c4 100644 --- a/tests/tx_test.go +++ b/tests/tx_test.go @@ -4,73 +4,74 @@ import ( "context" "database/sql" "errors" + "fmt" + "reflect" "strings" - "testing" "time" _ "github.com/go-sql-driver/mysql" - "github.com/stretchr/testify/assert" "github.com/olachat/gola/v2/coredb" "github.com/olachat/gola/v2/golalib/testdata/worker" ) -func TestBeginTx(t *testing.T) { - as := assert.New(t) +func ExampleNewTxProvider() { prov := coredb.NewTxProvider("newdb") err := prov.Tx(context.Background(), func(tx coredb.TxContext) error { _, err := tx.Exec("truncate table worker") - as.Nil(err) + panicOnErr(err) var workers []*worker.Worker err = tx.Find(&workers, "worker", "where id > ?", 0) - as.Nil(err) - as.Equal(0, len(workers)) + panicOnErr(err) + mustEqual(0, len(workers)) + fmt.Println("no of workers:", len(workers)) // uncomment to run test + // Output: no of workers: 0 _, err = tx.Exec("insert into worker (name,age) values (?, ?)", "peter", 18) - as.Nil(err) + panicOnErr(err) _, err = tx.Exec("insert into worker (name,age) values (?, ?)", "john", 28) - as.Nil(err) + panicOnErr(err) return err }) - as.Nil(err) + panicOnErr(err) err = prov.Tx(context.Background(), func(tx coredb.TxContext) error { var workers []*worker.Worker err := tx.Find(&workers, "worker", "where id > ?", 0) - as.Nil(err) - as.Equal(2, len(workers)) - as.Equal("peter", workers[0].GetName()) - as.Equal(18, workers[0].GetAge()) - as.Equal("john", workers[1].GetName()) - as.Equal(28, workers[1].GetAge()) + panicOnErr(err) + mustEqual(2, len(workers)) + mustEqual("peter", workers[0].GetName()) + mustEqual(18, workers[0].GetAge()) + mustEqual("john", workers[1].GetName()) + mustEqual(28, workers[1].GetAge()) var w worker.Worker err = tx.FindOne(&w, "worker", "where id = ?", 1) - as.Nil(err) - as.Equal("peter", w.GetName()) - as.Equal(18, w.GetAge()) + panicOnErr(err) + mustEqual("peter", w.GetName()) + mustEqual(18, w.GetAge()) r, err := tx.QueryInt("select count(1) from worker") - as.Nil(err) - as.Equal(2, r) + panicOnErr(err) + mustEqual(2, r) var workers2 []*worker.Worker err = tx.Query(&workers2, "select * from worker where id > ?", 0) - as.Nil(err) - as.Equal(2, len(workers2)) - as.Equal("peter", workers2[0].GetName()) - as.Equal(18, workers2[0].GetAge()) - as.Equal("john", workers2[1].GetName()) - as.Equal(28, workers2[1].GetAge()) + panicOnErr(err) + mustEqual(2, len(workers2)) + mustEqual("peter", workers2[0].GetName()) + mustEqual(18, workers2[0].GetAge()) + mustEqual("john", workers2[1].GetName()) + mustEqual(28, workers2[1].GetAge()) return nil }) prov.Tx(context.Background(), func(tx coredb.TxContext) error { _, err := tx.Exec("insert into worker (name,age) values (?, ?)", "winson", 19) - as.Nil(err) + panicOnErr(err) return errors.New("abort") }) @@ -78,16 +79,47 @@ func TestBeginTx(t *testing.T) { prov.Tx(context.Background(), func(tx coredb.TxContext) error { var w []*worker.Worker err := tx.Find(&w, "worker", "where id > ?", 0) - as.Nil(err) - as.Equal(2, len(w)) - as.Equal("peter", w[0].GetName()) - as.Equal(18, w[0].GetAge()) - as.Equal("john", w[1].GetName()) - as.Equal(28, w[1].GetAge()) + panicOnErr(err) + mustEqual(2, len(w)) + mustEqual("peter", w[0].GetName()) + mustEqual(18, w[0].GetAge()) + mustEqual("john", w[1].GetName()) + mustEqual(28, w[1].GetAge()) return nil }) - as.Nil(err) + panicOnErr(err) + prov2 := coredb.NewTxProvider("newdb") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + err = prov2.Tx(ctx, func(tx coredb.TxContext) error { + _, err := tx.Exec("insert into worker (name,age) values (?, ?)", "winson", 19) + if err != nil { + return err + } + var w []*worker.Worker + time.Sleep(10 * time.Millisecond) + err = tx.Find(&w, "worker", "where age = ?", 28) + if err != nil { + return err + } + return nil + }) + if !errors.Is(err, context.DeadlineExceeded) { + panic(err) + } + +} + +func panicOnErr(err error) { + if err != nil { + panic(err) + } +} +func mustEqual(a, b interface{}) { + if !reflect.DeepEqual(a, b) { + panic(fmt.Sprintf("%v != %v", a, b)) + } } func open() (db *sql.DB, err error) { From 5add6885498517b897f63b99affb024179793f91 Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Thu, 14 Dec 2023 15:25:52 +0800 Subject: [PATCH 7/8] add missing rows.close --- coredb/scan.go | 1 + 1 file changed, 1 insertion(+) diff --git a/coredb/scan.go b/coredb/scan.go index 28ca5cd..00fc38a 100644 --- a/coredb/scan.go +++ b/coredb/scan.go @@ -74,6 +74,7 @@ func RowsToStructSliceReflect(rows *sql.Rows, out any) (err error) { if rows == nil { return } + defer rows.Close() sliceValue := reflect.ValueOf(out) if sliceValue.Kind() != reflect.Ptr || sliceValue.IsNil() { panic(&InvalidScanError{Type: sliceValue.Type()}) From 742a7e45ae9c21122d91bdfe362c243bb1f9a0d8 Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Thu, 14 Dec 2023 15:36:12 +0800 Subject: [PATCH 8/8] shift rows.close() to caller --- coredb/scan.go | 3 --- coredb/tx.go | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/coredb/scan.go b/coredb/scan.go index 00fc38a..200991c 100644 --- a/coredb/scan.go +++ b/coredb/scan.go @@ -28,8 +28,6 @@ func (e *InvalidScanError) Error() string { // // The function returns a slice of pointers to T structs and an error. func RowsToStructSlice[T any](rows *sql.Rows) (result []*T, err error) { - defer rows.Close() - var u *T for rows.Next() { u = new(T) @@ -74,7 +72,6 @@ func RowsToStructSliceReflect(rows *sql.Rows, out any) (err error) { if rows == nil { return } - defer rows.Close() sliceValue := reflect.ValueOf(out) if sliceValue.Kind() != reflect.Ptr || sliceValue.IsNil() { panic(&InvalidScanError{Type: sliceValue.Type()}) diff --git a/coredb/tx.go b/coredb/tx.go index 59a344d..bcbbbc4 100644 --- a/coredb/tx.go +++ b/coredb/tx.go @@ -67,6 +67,7 @@ func (t *tx) Query(results any, query string, params ...any) error { if err != nil { return err } + defer rows.Close() return RowsToStructSliceReflect(rows, results) }