Skip to content

Commit

Permalink
Fixed bugs and added example
Browse files Browse the repository at this point in the history
Signed-off-by: nithinkdb <[email protected]>
  • Loading branch information
nithinkdb committed Sep 29, 2023
1 parent 0c5844e commit e338383
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 90 deletions.
15 changes: 10 additions & 5 deletions examples/parameters/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,22 @@ func main() {
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
ctx := context.Background()
var res string
var p_bool bool
var p_int int
var p_double float64
var p_float float32
var p_date string
err1 := db.QueryRowContext(ctx, `SELECT
:p_bool AS col_bool,
:p_int AS col_int,
:p_double AS col_double,
:p_floatfloat AS col_float`,
:p_float AS col_float,
:p_date AS col_date`,
dbsql.DBSqlParam{Name: "p_bool", Value: true},
dbsql.DBSqlParam{Name: "p_int", Value: int(1234)},
dbsql.DBSqlParam{Name: "p_double", Type: dbsql.Double, Value: 3.14},
dbsql.DBSqlParam{Name: "p_float", Type: dbsql.Float, Value: 3.14}).Scan(&res)
dbsql.DBSqlParam{Name: "p_double", Type: dbsql.SqlDouble, Value: "3.14"},
dbsql.DBSqlParam{Name: "p_float", Type: dbsql.SqlFloat, Value: "3.14"},
dbsql.DBSqlParam{Name: "p_date", Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"}).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)

if err1 != nil {
if err1 == sql.ErrNoRows {
Expand All @@ -58,6 +64,5 @@ func main() {
fmt.Printf("err: %v\n", err1)
}
}
fmt.Println(res)

}
4 changes: 2 additions & 2 deletions parameter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

func TestParameter_Inference(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: DBSqlParam{Value: "6.2", Type: Decimal}}}
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: DBSqlParam{Value: "6.2", Type: SqlDecimal}}}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
assert.NotNil(t, parameters[1].Value.StringValue)
Expand All @@ -25,7 +25,7 @@ func TestParameter_Inference(t *testing.T) {
}
func TestParameters_Names(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: DBSqlParam{Name: "2", Type: Decimal, Value: "6.2"}}}
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: DBSqlParam{Name: "2", Type: SqlDecimal, Value: "6.2"}}}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, string("1"), *parameters[0].Name)
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
Expand Down
172 changes: 89 additions & 83 deletions parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,51 +16,52 @@ type DBSqlParam struct {
Value any
}

type SqlType int64
type SqlType int

const (
String SqlType = iota
Date
Timestamp
Float
Decimal
Double
Integer
BigInt
SmallInt
TinyInt
Boolean
IntervalMonth
IntervalDay
SqlUnkown SqlType = iota
SqlString
SqlDate
SqlTimestamp
SqlFloat
SqlDecimal
SqlDouble
SqlInteger
SqlBigInt
SqlSmallInt
SqlTinyInt
SqlBoolean
SqlIntervalMonth
SqlIntervalDay
)

func (s SqlType) String() string {
switch s {
case String:
case SqlString:
return "STRING"
case Date:
case SqlDate:
return "DATE"
case Timestamp:
case SqlTimestamp:
return "TIMESTAMP"
case Float:
case SqlFloat:
return "FLOAT"
case Decimal:
case SqlDecimal:
return "DECIMAL"
case Double:
case SqlDouble:
return "DOUBLE"
case Integer:
case SqlInteger:
return "INTEGER"
case BigInt:
case SqlBigInt:
return "BIGINT"
case SmallInt:
case SqlSmallInt:
return "SMALLINT"
case TinyInt:
case SqlTinyInt:
return "TINYINT"
case Boolean:
case SqlBoolean:
return "BOOLEAN"
case IntervalMonth:
case SqlIntervalMonth:
return "INTERVAL MONTH"
case IntervalDay:
case SqlIntervalDay:
return "INTERVAL DAY"
}
return "unknown"
Expand All @@ -69,69 +70,74 @@ func (s SqlType) String() string {
func valuesToDBSQLParams(namedValues []driver.NamedValue) []DBSqlParam {
var params []DBSqlParam
for i := range namedValues {
newParam := *new(DBSqlParam)
namedValue := namedValues[i]
param := *new(DBSqlParam)
param.Name = namedValue.Name
param.Value = namedValue.Value
params = append(params, param)
param, ok := namedValue.Value.(DBSqlParam)
if ok {
newParam.Name = param.Name
newParam.Value = param.Value
newParam.Type = param.Type
} else {
newParam.Name = namedValue.Name
newParam.Value = namedValue.Value
}
params = append(params, newParam)
}
return params
}

func inferTypes(params []DBSqlParam) {
for i := range params {
param := &params[i]
switch value := param.Value.(type) {
case bool:
param.Value = strconv.FormatBool(value)
param.Type = Boolean
case string:
param.Value = value
param.Type = String
case int:
param.Value = strconv.Itoa(value)
param.Type = Integer
case uint:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int8:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint8:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int16:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint16:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int32:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint32:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int64:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint64:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case float32:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
param.Type = Float
case time.Time:
param.Value = value.String()
param.Type = Timestamp
case DBSqlParam:
param.Name = value.Name
param.Value = value.Value
param.Type = value.Type
default:
s := fmt.Sprintf("%s", value)
param.Value = s
param.Type = String
if param.Type == SqlUnkown {
switch value := param.Value.(type) {
case bool:
param.Value = strconv.FormatBool(value)
param.Type = SqlBoolean
case string:
param.Value = value
param.Type = SqlString
case int:
param.Value = strconv.Itoa(value)
param.Type = SqlInteger
case uint:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int8:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint8:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int16:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint16:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int32:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint32:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int64:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint64:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case float32:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
param.Type = SqlFloat
case time.Time:
param.Value = value.String()
param.Type = SqlTimestamp
default:
s := fmt.Sprintf("%s", param.Value)
param.Value = s
param.Type = SqlString
}
}
}
}
Expand All @@ -144,7 +150,7 @@ func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.
sqlParam := sqlParams[i]
sparkParamValue := sqlParam.Value.(string)
var sparkParamType string
if sqlParam.Type == Decimal {
if sqlParam.Type == SqlDecimal {
sparkParamType = inferDecimalType(sparkParamValue)
} else {
sparkParamType = sqlParam.Type.String()
Expand Down

0 comments on commit e338383

Please sign in to comment.