Skip to content

Commit

Permalink
Merge pull request #56 from olachat/yinloo/refactor-column-interface
Browse files Browse the repository at this point in the history
split interface. add params helper and ColumnVal
  • Loading branch information
yinloo-ola authored Feb 2, 2024
2 parents 976b426 + 26a83f3 commit 9735f8c
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 7 deletions.
13 changes: 13 additions & 0 deletions coredb/column.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package coredb

type ColumnVal[T any] struct {
val T
}

func (c *ColumnVal[T]) GetValPointer() any {
return &c.val
}

func (c *ColumnVal[T]) GetVal() T {
return c.val
}
6 changes: 4 additions & 2 deletions coredb/column_type.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package coredb

// ColumnType defines the generated type of a table column
type ColumnType interface {
GetColumnName() string
type ColumnValPointer interface {
GetValPointer() any
}
type ColumnNamer interface {
GetColumnName() string
}
10 changes: 6 additions & 4 deletions coredb/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ func Setup(dbp DBProvider) {
_dbp = dbp
}

var typeColumnNames = make(map[reflect.Type]string)
var typeColumnNamesLock sync.RWMutex
var (
typeColumnNames = make(map[reflect.Type]string)
typeColumnNamesLock sync.RWMutex
)

func getDB(dbname string, mode DBMode) *sql.DB {
if _dbp == nil {
Expand Down Expand Up @@ -269,7 +271,7 @@ func GetColumnsNames[T any]() (joinedColumnNames string) {
val := reflect.ValueOf(o).Elem()
for i := 0; i < val.NumField(); i++ {
valueField := val.Field(i)
if f, ok := valueField.Addr().Interface().(ColumnType); ok {
if f, ok := valueField.Addr().Interface().(ColumnNamer); ok {
columnNames = append(columnNames, "`"+f.GetColumnName()+"`")
}
}
Expand Down Expand Up @@ -315,7 +317,7 @@ func GetColumnsNamesReflect(o any) (joinedColumnNames string) {
val := reflect.New(elemType).Elem()
for i := 0; i < val.NumField(); i++ {
valueField := val.Field(i)
if f, ok := valueField.Addr().Interface().(ColumnType); ok {
if f, ok := valueField.Addr().Interface().(ColumnNamer); ok {
columnNames = append(columnNames, "`"+f.GetColumnName()+"`")
}
}
Expand Down
40 changes: 40 additions & 0 deletions coredb/params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package coredb

import "reflect"

type Params struct {
params []any
}

func NewParams(params ...any) *Params {
return &Params{params}
}

func (p *Params) Add(params ...any) {
totalLen := len(p.params)
for i := 0; i < len(params); i++ {
if v := reflect.ValueOf(params[i]); v.Kind() == reflect.Slice {
totalLen += v.Len()
} else {
totalLen++
}
}

newSlice := make([]any, 0, totalLen)
newSlice = append(newSlice, p.params...)
for i := 0; i < len(params); i++ {
if v := reflect.ValueOf(params[i]); v.Kind() == reflect.Slice {
// append all elements in params[0] to p.params
for j := 0; j < v.Len(); j++ {
newSlice = append(newSlice, v.Index(j).Interface())
}
} else {
newSlice = append(newSlice, params[i])
}
}
p.params = newSlice
}

func (p *Params) Get() []any {
return p.params
}
41 changes: 41 additions & 0 deletions coredb/params_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package coredb

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestNewParams(t *testing.T) {
as := assert.New(t)
p := NewParams(1, "1", 1.0)
out := p.Get()
as.Equal(1, out[0])
as.Equal("1", out[1])
as.Equal(1.0, out[2])

p.Add([]string{"2", "two"}, "king", -66)
out = p.Get()
as.Equal(1, out[0])
as.Equal("1", out[1])
as.Equal(1.0, out[2])
as.Equal("2", out[3])
as.Equal("two", out[4])
as.Equal("king", out[5])
as.Equal(-66, out[6])

p.Add(0.6, 77, []int{88, 99, 100})
out = p.Get()
as.Equal(1, out[0])
as.Equal("1", out[1])
as.Equal(1.0, out[2])
as.Equal("2", out[3])
as.Equal("two", out[4])
as.Equal("king", out[5])
as.Equal(-66, out[6])
as.Equal(0.6, out[7])
as.Equal(77, out[8])
as.Equal(88, out[9])
as.Equal(99, out[10])
as.Equal(100, out[11])
}
2 changes: 1 addition & 1 deletion coredb/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func StrutForScan(u any) (pointers []any) {
pointers = make([]any, 0, val.NumField())
for i := 0; i < val.NumField(); i++ {
valueField := val.Field(i)
if f, ok := valueField.Addr().Interface().(ColumnType); ok {
if f, ok := valueField.Addr().Interface().(ColumnValPointer); ok {
pointers = append(pointers, f.GetValPointer())
}
}
Expand Down
1 change: 1 addition & 0 deletions test.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 9735f8c

Please sign in to comment.