Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add transaction support #54

Merged
merged 8 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 55 additions & 8 deletions coredb/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func getDB(dbname string, mode DBMode) *sql.DB {
}

// FetchByPK returns a row of T type with given primary key value
//
// Deprecated: use the function with context
func FetchByPK[T any](dbname string, tableName string, pkName []string, val ...any) *T {
sql := "WHERE `" + pkName[0] + "` = ?"
Expand All @@ -54,6 +55,7 @@ func FetchByPK[T any](dbname string, tableName string, pkName []string, val ...a
}

// FetchByPKs returns rows of T type with given primary key values
//
// Deprecated: use the function with context
func FetchByPKs[T any](dbname string, tableName string, pkName string, vals []any) []*T {
if len(vals) == 0 {
Expand All @@ -71,6 +73,7 @@ func FetchByPKs[T any](dbname string, tableName string, pkName string, vals []an
}

// FetchByPKFromMaster returns a row of T type with given primary key value
//
// Deprecated: use the function with context
func FetchByPKFromMaster[T any](dbname string, tableName string, pkName []string, val ...any) *T {
sql := "WHERE `" + pkName[0] + "` = ?"
Expand All @@ -82,6 +85,7 @@ func FetchByPKFromMaster[T any](dbname string, tableName string, pkName []string
}

// FetchByPKsFromMaster returns rows of T type with given primary key values
//
// Deprecated: use the function with context
func FetchByPKsFromMaster[T any](dbname string, tableName string, pkName string, vals []any) []*T {
if len(vals) == 0 {
Expand All @@ -99,13 +103,15 @@ func FetchByPKsFromMaster[T any](dbname string, tableName string, pkName string,
}

// Exec given query with given db info & params
//
// Deprecated: use the function with context
func Exec(dbname string, query string, params ...any) (sql.Result, error) {
mydb := getDB(dbname, DBModeWrite)
return mydb.Exec(query, params...)
}

// FindOne returns a row from given table type with where query
//
// Deprecated: use the function with context
func FindOne[T any](dbname string, tableName string, where WhereQuery) *T {
u := new(T)
Expand All @@ -131,6 +137,7 @@ func FindOne[T any](dbname string, tableName string, where WhereQuery) *T {
}

// Find returns rows from given table type with where query
//
// Deprecated: use the function with context
func Find[T any](dbname string, tableName string, where WhereQuery) ([]*T, error) {
columnsNames := GetColumnsNames[T]()
Expand All @@ -142,6 +149,7 @@ func Find[T any](dbname string, tableName string, where WhereQuery) ([]*T, error
}

// FindOneFromMaster using master DB returns a row from given table type with where query
//
// Deprecated: use the function with context
func FindOneFromMaster[T any](dbname string, tableName string, where WhereQuery) *T {
u := new(T)
Expand All @@ -167,6 +175,7 @@ func FindOneFromMaster[T any](dbname string, tableName string, where WhereQuery)
}

// FindFromMaster using master DB returns rows from given table type with where query
//
// Deprecated: use the function with context
func FindFromMaster[T any](dbname string, tableName string, where WhereQuery) ([]*T, error) {
columnsNames := GetColumnsNames[T]()
Expand All @@ -178,22 +187,25 @@ func FindFromMaster[T any](dbname string, tableName string, where WhereQuery) ([
}

// QueryInt single int result by query, handy for count(*) querys
//
// Deprecated: use the function with context
func QueryInt(dbname string, query string, params ...any) (result int, err error) {
mydb := getDB(dbname, DBModeRead)
mydb.QueryRow(query, params...).Scan(&result)
err = mydb.QueryRow(query, params...).Scan(&result)
return
}

// QueryIntFromMaster single int result by query, handy for count(*) querys
//
// Deprecated: use the function with context
func QueryIntFromMaster(dbname string, query string, params ...any) (result int, err error) {
mydb := getDB(dbname, DBModeReadFromWrite)
mydb.QueryRow(query, params...).Scan(&result)
err = mydb.QueryRow(query, params...).Scan(&result)
return
}

// Query rows from given table type with where query & params
//
// Deprecated: use the function with context
func Query[T any](dbname string, query string, params ...any) (result []*T, err error) {
mydb := getDB(dbname, DBModeRead)
Expand All @@ -217,6 +229,7 @@ func Query[T any](dbname string, query string, params ...any) (result []*T, err
}

// Query rows from master DB from given table type with where query & params
//
// Deprecated: use the function with context
func QueryFromMaster[T any](dbname string, query string, params ...any) (result []*T, err error) {
mydb := getDB(dbname, DBModeReadFromWrite)
Expand Down Expand Up @@ -244,7 +257,8 @@ func GetColumnsNames[T any]() (joinedColumnNames string) {
var o *T
t := reflect.TypeOf(o)
typeColumnNamesLock.RLock()
joinedColumnNames, ok := typeColumnNames[t]
var ok bool
joinedColumnNames, ok = typeColumnNames[t]
typeColumnNamesLock.RUnlock()
if ok {
return
Expand All @@ -269,15 +283,48 @@ func GetColumnsNames[T any]() (joinedColumnNames string) {
return
}

// StrutForScan returns value pointers of given obj
func StrutForScan[T any](u *T) (pointers []any) {
val := reflect.ValueOf(u).Elem()
pointers = make([]any, 0, val.NumField())
// GetColumnsNamesReflect returns column names joined by `,` of given type
func GetColumnsNamesReflect(o any) (joinedColumnNames string) {
t := reflect.TypeOf(o)
elemType := t.Elem()
switch t.Kind() {
case reflect.Ptr:
// 是指针,尝试获取其指向的元素类型
if elemType.Kind() == reflect.Slice {
elemType = elemType.Elem().Elem() // 如果指针指向切片,获取切片元素的类型
}
case reflect.Slice:
// 是切片,获取切片元素的类型
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem() // 切片元素是指针,获取其指向的类型
}
default:
// 既不是指针也不是切片,返回错误
panic("coredb: o is neither a pointer nor a slice")
}

typeColumnNamesLock.RLock()
var ok bool
joinedColumnNames, ok = typeColumnNames[elemType]
typeColumnNamesLock.RUnlock()
if ok {
return
}

var columnNames []string
val := reflect.New(elemType).Elem()
for i := 0; i < val.NumField(); i++ {
valueField := val.Field(i)
if f, ok := valueField.Addr().Interface().(ColumnType); ok {
pointers = append(pointers, f.GetValPointer())
columnNames = append(columnNames, "`"+f.GetColumnName()+"`")
}
}

joinedColumnNames = strings.Join(columnNames, ",")

typeColumnNamesLock.Lock()
typeColumnNames[t] = joinedColumnNames
typeColumnNamesLock.Unlock()

return
}
4 changes: 2 additions & 2 deletions coredb/engine_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ func FindFromMasterCtx[T any](ctx context.Context, dbname string, tableName stri
// QueryIntCtx single int result by query, handy for count(*) querys
func QueryIntCtx(ctx context.Context, dbname string, query string, params ...any) (result int, err error) {
mydb := getDB(dbname, DBModeRead)
mydb.QueryRowContext(ctx, query, params...).Scan(&result)
err = mydb.QueryRowContext(ctx, query, params...).Scan(&result)
return
}

// QueryIntFromMasterCtx single int result by query, handy for count(*) querys
func QueryIntFromMasterCtx(ctx context.Context, dbname string, query string, params ...any) (result int, err error) {
mydb := getDB(dbname, DBModeReadFromWrite)
mydb.QueryRowContext(ctx, query, params...).Scan(&result)
err = mydb.QueryRowContext(ctx, query, params...).Scan(&result)
return
}

Expand Down
123 changes: 123 additions & 0 deletions coredb/scan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package coredb

import (
"database/sql"
"reflect"
)

// An InvalidScanError describes an invalid argument passed to Scan.
type InvalidScanError struct {
Type reflect.Type
}

func (e *InvalidScanError) Error() string {
if e.Type == nil {
return "coredb: target is nil"
}

if e.Type.Kind() != reflect.Pointer {
return "coredb: target must be a non-nil pointer, got " + e.Type.String()
}
return "coredb: nil " + e.Type.String() + ")"
}

// RowsToStructSlice converts the rows of a SQL query result into a slice of structs.
//
// It takes a pointer to a sql.Rows object as input.
// The function also uses a generic type T, which represents the type of the struct.
//
// The function returns a slice of pointers to T structs and an error.
func RowsToStructSlice[T any](rows *sql.Rows) (result []*T, err error) {
defer rows.Close()

var u *T
for rows.Next() {
u = new(T)
data := StrutForScan(u)
err = rows.Scan(data...)
if err != nil {
return
}
result = append(result, u)
}
err = rows.Err()
return
}

// RowToStruct converts a database row into a struct.
//
// It takes a pointer to a sql.Row and returns a pointer to the converted struct and an error.
func RowToStruct[T any](row *sql.Row) (result *T, err error) {
result = new(T)
data := StrutForScan(result)
err = row.Scan(data...)
if err == sql.ErrNoRows {
err = nil
}
return
}

// StrutForScan returns value pointers of given obj
func StrutForScan(u any) (pointers []any) {
val := reflect.ValueOf(u).Elem()
pointers = make([]any, 0, val.NumField())
for i := 0; i < val.NumField(); i++ {
valueField := val.Field(i)
if f, ok := valueField.Addr().Interface().(ColumnType); ok {
pointers = append(pointers, f.GetValPointer())
}
}
return
}

func RowsToStructSliceReflect(rows *sql.Rows, out any) (err error) {
if rows == nil {
return
}
sliceValue := reflect.ValueOf(out)
if sliceValue.Kind() != reflect.Ptr || sliceValue.IsNil() {
panic(&InvalidScanError{Type: sliceValue.Type()})
}
sliceValue = sliceValue.Elem()
if sliceValue.Kind() != reflect.Slice {
panic(&InvalidScanError{Type: reflect.TypeOf(out)})
}
elementType := sliceValue.Type().Elem()
if elementType.Kind() != reflect.Ptr {
panic(&InvalidScanError{Type: reflect.TypeOf(out)})
}
elementType = elementType.Elem()

var elements []reflect.Value
for rows.Next() {
v := reflect.New(elementType)
data := StrutForScan(v.Interface())
err = rows.Scan(data...)
if err != nil {
return
}
elements = append(elements, v.Elem())
}

sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), len(elements), len(elements)))
for i, v := range elements {
sliceValue.Index(i).Set(v.Addr())
}

err = rows.Err()
yinloo-ola marked this conversation as resolved.
Show resolved Hide resolved
return
}

func RowToStructReflect(row *sql.Row, v any) (err error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Pointer || rv.IsNil() {
panic(&InvalidScanError{reflect.TypeOf(v)})
}

data := StrutForScan(v)
err = row.Scan(data...)
if err == sql.ErrNoRows {
return nil
}
return
}
Loading
Loading