Skip to content

Commit

Permalink
Merge pull request #54 from olachat/yinloo/transaction-support
Browse files Browse the repository at this point in the history
add transaction support
  • Loading branch information
yinloo-ola authored Dec 14, 2023
2 parents 44053ba + 742a7e4 commit 87bfcfb
Show file tree
Hide file tree
Showing 25 changed files with 1,455 additions and 10 deletions.
63 changes: 55 additions & 8 deletions coredb/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func getDB(dbname string, mode DBMode) *sql.DB {
}

// FetchByPK returns a row of T type with given primary key value
//
// Deprecated: use the function with context
func FetchByPK[T any](dbname string, tableName string, pkName []string, val ...any) *T {
sql := "WHERE `" + pkName[0] + "` = ?"
Expand All @@ -54,6 +55,7 @@ func FetchByPK[T any](dbname string, tableName string, pkName []string, val ...a
}

// FetchByPKs returns rows of T type with given primary key values
//
// Deprecated: use the function with context
func FetchByPKs[T any](dbname string, tableName string, pkName string, vals []any) []*T {
if len(vals) == 0 {
Expand All @@ -71,6 +73,7 @@ func FetchByPKs[T any](dbname string, tableName string, pkName string, vals []an
}

// FetchByPKFromMaster returns a row of T type with given primary key value
//
// Deprecated: use the function with context
func FetchByPKFromMaster[T any](dbname string, tableName string, pkName []string, val ...any) *T {
sql := "WHERE `" + pkName[0] + "` = ?"
Expand All @@ -82,6 +85,7 @@ func FetchByPKFromMaster[T any](dbname string, tableName string, pkName []string
}

// FetchByPKsFromMaster returns rows of T type with given primary key values
//
// Deprecated: use the function with context
func FetchByPKsFromMaster[T any](dbname string, tableName string, pkName string, vals []any) []*T {
if len(vals) == 0 {
Expand All @@ -99,13 +103,15 @@ func FetchByPKsFromMaster[T any](dbname string, tableName string, pkName string,
}

// Exec given query with given db info & params
//
// Deprecated: use the function with context
func Exec(dbname string, query string, params ...any) (sql.Result, error) {
mydb := getDB(dbname, DBModeWrite)
return mydb.Exec(query, params...)
}

// FindOne returns a row from given table type with where query
//
// Deprecated: use the function with context
func FindOne[T any](dbname string, tableName string, where WhereQuery) *T {
u := new(T)
Expand All @@ -131,6 +137,7 @@ func FindOne[T any](dbname string, tableName string, where WhereQuery) *T {
}

// Find returns rows from given table type with where query
//
// Deprecated: use the function with context
func Find[T any](dbname string, tableName string, where WhereQuery) ([]*T, error) {
columnsNames := GetColumnsNames[T]()
Expand All @@ -142,6 +149,7 @@ func Find[T any](dbname string, tableName string, where WhereQuery) ([]*T, error
}

// FindOneFromMaster using master DB returns a row from given table type with where query
//
// Deprecated: use the function with context
func FindOneFromMaster[T any](dbname string, tableName string, where WhereQuery) *T {
u := new(T)
Expand All @@ -167,6 +175,7 @@ func FindOneFromMaster[T any](dbname string, tableName string, where WhereQuery)
}

// FindFromMaster using master DB returns rows from given table type with where query
//
// Deprecated: use the function with context
func FindFromMaster[T any](dbname string, tableName string, where WhereQuery) ([]*T, error) {
columnsNames := GetColumnsNames[T]()
Expand All @@ -178,22 +187,25 @@ func FindFromMaster[T any](dbname string, tableName string, where WhereQuery) ([
}

// QueryInt single int result by query, handy for count(*) querys
//
// Deprecated: use the function with context
func QueryInt(dbname string, query string, params ...any) (result int, err error) {
mydb := getDB(dbname, DBModeRead)
mydb.QueryRow(query, params...).Scan(&result)
err = mydb.QueryRow(query, params...).Scan(&result)
return
}

// QueryIntFromMaster single int result by query, handy for count(*) querys
//
// Deprecated: use the function with context
func QueryIntFromMaster(dbname string, query string, params ...any) (result int, err error) {
mydb := getDB(dbname, DBModeReadFromWrite)
mydb.QueryRow(query, params...).Scan(&result)
err = mydb.QueryRow(query, params...).Scan(&result)
return
}

// Query rows from given table type with where query & params
//
// Deprecated: use the function with context
func Query[T any](dbname string, query string, params ...any) (result []*T, err error) {
mydb := getDB(dbname, DBModeRead)
Expand All @@ -217,6 +229,7 @@ func Query[T any](dbname string, query string, params ...any) (result []*T, err
}

// Query rows from master DB from given table type with where query & params
//
// Deprecated: use the function with context
func QueryFromMaster[T any](dbname string, query string, params ...any) (result []*T, err error) {
mydb := getDB(dbname, DBModeReadFromWrite)
Expand Down Expand Up @@ -244,7 +257,8 @@ func GetColumnsNames[T any]() (joinedColumnNames string) {
var o *T
t := reflect.TypeOf(o)
typeColumnNamesLock.RLock()
joinedColumnNames, ok := typeColumnNames[t]
var ok bool
joinedColumnNames, ok = typeColumnNames[t]
typeColumnNamesLock.RUnlock()
if ok {
return
Expand All @@ -269,15 +283,48 @@ func GetColumnsNames[T any]() (joinedColumnNames string) {
return
}

// StrutForScan returns value pointers of given obj
func StrutForScan[T any](u *T) (pointers []any) {
val := reflect.ValueOf(u).Elem()
pointers = make([]any, 0, val.NumField())
// GetColumnsNamesReflect returns column names joined by `,` of given type
func GetColumnsNamesReflect(o any) (joinedColumnNames string) {
t := reflect.TypeOf(o)
elemType := t.Elem()
switch t.Kind() {
case reflect.Ptr:
// 是指针,尝试获取其指向的元素类型
if elemType.Kind() == reflect.Slice {
elemType = elemType.Elem().Elem() // 如果指针指向切片,获取切片元素的类型
}
case reflect.Slice:
// 是切片,获取切片元素的类型
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem() // 切片元素是指针,获取其指向的类型
}
default:
// 既不是指针也不是切片,返回错误
panic("coredb: o is neither a pointer nor a slice")
}

typeColumnNamesLock.RLock()
var ok bool
joinedColumnNames, ok = typeColumnNames[elemType]
typeColumnNamesLock.RUnlock()
if ok {
return
}

var columnNames []string
val := reflect.New(elemType).Elem()
for i := 0; i < val.NumField(); i++ {
valueField := val.Field(i)
if f, ok := valueField.Addr().Interface().(ColumnType); ok {
pointers = append(pointers, f.GetValPointer())
columnNames = append(columnNames, "`"+f.GetColumnName()+"`")
}
}

joinedColumnNames = strings.Join(columnNames, ",")

typeColumnNamesLock.Lock()
typeColumnNames[t] = joinedColumnNames
typeColumnNamesLock.Unlock()

return
}
4 changes: 2 additions & 2 deletions coredb/engine_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ func FindFromMasterCtx[T any](ctx context.Context, dbname string, tableName stri
// QueryIntCtx single int result by query, handy for count(*) querys
func QueryIntCtx(ctx context.Context, dbname string, query string, params ...any) (result int, err error) {
mydb := getDB(dbname, DBModeRead)
mydb.QueryRowContext(ctx, query, params...).Scan(&result)
err = mydb.QueryRowContext(ctx, query, params...).Scan(&result)
return
}

// QueryIntFromMasterCtx single int result by query, handy for count(*) querys
func QueryIntFromMasterCtx(ctx context.Context, dbname string, query string, params ...any) (result int, err error) {
mydb := getDB(dbname, DBModeReadFromWrite)
mydb.QueryRowContext(ctx, query, params...).Scan(&result)
err = mydb.QueryRowContext(ctx, query, params...).Scan(&result)
return
}

Expand Down
121 changes: 121 additions & 0 deletions coredb/scan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package coredb

import (
"database/sql"
"reflect"
)

// An InvalidScanError describes an invalid argument passed to Scan.
type InvalidScanError struct {
Type reflect.Type
}

func (e *InvalidScanError) Error() string {
if e.Type == nil {
return "coredb: target is nil"
}

if e.Type.Kind() != reflect.Pointer {
return "coredb: target must be a non-nil pointer, got " + e.Type.String()
}
return "coredb: nil " + e.Type.String() + ")"
}

// RowsToStructSlice converts the rows of a SQL query result into a slice of structs.
//
// It takes a pointer to a sql.Rows object as input.
// The function also uses a generic type T, which represents the type of the struct.
//
// The function returns a slice of pointers to T structs and an error.
func RowsToStructSlice[T any](rows *sql.Rows) (result []*T, err error) {
var u *T
for rows.Next() {
u = new(T)
data := StrutForScan(u)
err = rows.Scan(data...)
if err != nil {
return
}
result = append(result, u)
}
err = rows.Err()
return
}

// RowToStruct converts a database row into a struct.
//
// It takes a pointer to a sql.Row and returns a pointer to the converted struct and an error.
func RowToStruct[T any](row *sql.Row) (result *T, err error) {
result = new(T)
data := StrutForScan(result)
err = row.Scan(data...)
if err == sql.ErrNoRows {
err = nil
}
return
}

// StrutForScan returns value pointers of given obj
func StrutForScan(u any) (pointers []any) {
val := reflect.ValueOf(u).Elem()
pointers = make([]any, 0, val.NumField())
for i := 0; i < val.NumField(); i++ {
valueField := val.Field(i)
if f, ok := valueField.Addr().Interface().(ColumnType); ok {
pointers = append(pointers, f.GetValPointer())
}
}
return
}

func RowsToStructSliceReflect(rows *sql.Rows, out any) (err error) {
if rows == nil {
return
}
sliceValue := reflect.ValueOf(out)
if sliceValue.Kind() != reflect.Ptr || sliceValue.IsNil() {
panic(&InvalidScanError{Type: sliceValue.Type()})
}
sliceValue = sliceValue.Elem()
if sliceValue.Kind() != reflect.Slice {
panic(&InvalidScanError{Type: reflect.TypeOf(out)})
}
elementType := sliceValue.Type().Elem()
if elementType.Kind() != reflect.Ptr {
panic(&InvalidScanError{Type: reflect.TypeOf(out)})
}
elementType = elementType.Elem()

var elements []reflect.Value
for rows.Next() {
v := reflect.New(elementType)
data := StrutForScan(v.Interface())
err = rows.Scan(data...)
if err != nil {
return
}
elements = append(elements, v.Elem())
}

sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), len(elements), len(elements)))
for i, v := range elements {
sliceValue.Index(i).Set(v.Addr())
}

err = rows.Err()
return
}

func RowToStructReflect(row *sql.Row, v any) (err error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Pointer || rv.IsNil() {
panic(&InvalidScanError{reflect.TypeOf(v)})
}

data := StrutForScan(v)
err = row.Scan(data...)
if err == sql.ErrNoRows {
return nil
}
return
}
Loading

0 comments on commit 87bfcfb

Please sign in to comment.