Skip to content

Commit

Permalink
[PECO-1048] Add example for parameterized queries (databricks#168)
Browse files Browse the repository at this point in the history
We don't currently have an example for parameterized queries; this change
adds one.
  • Loading branch information
nithinkdb authored Sep 29, 2023
2 parents acdb8ba + 1ba2c5e commit e95dd4a
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 93 deletions.
6 changes: 3 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,9 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati

func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
var err error
if dbsqlParam, ok := nv.Value.(DBSqlParam); ok {
nv.Name = dbsqlParam.Name
dbsqlParam.Value, err = driver.DefaultParameterConverter.ConvertValue(dbsqlParam.Value)
if parameter, ok := nv.Value.(Parameter); ok {
nv.Name = parameter.Name
parameter.Value, err = driver.DefaultParameterConverter.ConvertValue(parameter.Value)
return err
}

Expand Down
68 changes: 68 additions & 0 deletions examples/parameters/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package main

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"os"
	"strconv"

	dbsql "github.com/databricks/databricks-sql-go"
	"github.com/joho/godotenv"
)

// main demonstrates running a parameterized query against Databricks SQL
// using named parameter markers (":name"), binding values either as plain
// Go types (SQL type inferred) or as dbsql.Parameter with an explicit type.
func main() {
	// Load connection settings from a .env file into the environment.
	err := godotenv.Load()
	if err != nil {
		log.Fatal(err.Error())
	}
	port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
	if err != nil {
		log.Fatal(err.Error())
	}
	// Building a connector does not attempt to connect; a failure here is
	// a configuration/initialization error, not a connection error.
	connector, err := dbsql.NewConnector(
		dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
		dbsql.WithPort(port),
		dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
		dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
	)
	if err != nil {
		log.Fatal(err)
	}
	db := sql.OpenDB(connector)
	defer db.Close()

	ctx := context.Background()

	var (
		colBool   bool
		colInt    int
		colDouble float64
		colFloat  float32
		colDate   string
	)
	// p_bool and p_int rely on type inference from the Go value; the other
	// three pass string values with an explicit SQL type.
	err = db.QueryRowContext(ctx, `SELECT
	:p_bool AS col_bool,
	:p_int AS col_int,
	:p_double AS col_double,
	:p_float AS col_float,
	:p_date AS col_date`,
		dbsql.Parameter{Name: "p_bool", Value: true},
		dbsql.Parameter{Name: "p_int", Value: int(1234)},
		dbsql.Parameter{Name: "p_double", Type: dbsql.SqlDouble, Value: "3.14"},
		dbsql.Parameter{Name: "p_float", Type: dbsql.SqlFloat, Value: "3.14"},
		dbsql.Parameter{Name: "p_date", Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"},
	).Scan(&colBool, &colInt, &colDouble, &colFloat, &colDate)
	// Use errors.Is rather than == so wrapped driver errors still match.
	switch {
	case errors.Is(err, sql.ErrNoRows):
		fmt.Println("not found")
	case err != nil:
		fmt.Printf("err: %v\n", err)
	}
}
4 changes: 2 additions & 2 deletions parameter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

func TestParameter_Inference(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: DBSqlParam{Value: "6.2", Type: Decimal}}}
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: Parameter{Value: "6.2", Type: SqlDecimal}}}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
assert.NotNil(t, parameters[1].Value.StringValue)
Expand All @@ -25,7 +25,7 @@ func TestParameter_Inference(t *testing.T) {
}
func TestParameters_Names(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: DBSqlParam{Name: "2", Type: Decimal, Value: "6.2"}}}
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: Parameter{Name: "2", Type: SqlDecimal, Value: "6.2"}}}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, string("1"), *parameters[0].Name)
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
Expand Down
182 changes: 94 additions & 88 deletions parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,141 +10,147 @@ import (
"github.com/databricks/databricks-sql-go/internal/cli_service"
)

// Parameter is a named query parameter with an optional explicit SQL type.
// When Type is left as the zero value (SqlUnkown) the driver infers the
// SQL type from the Go value at bind time.
type Parameter struct {
	// Name matches a ":name" marker in the query text; may be empty for
	// ordinal parameters.
	Name string
	// Type is the Databricks SQL type the value should be sent as.
	Type SqlType
	// Value is converted to its string wire representation before sending.
	Value any
}

// SqlType enumerates the Databricks SQL types a Parameter may declare.
type SqlType int

const (
	// SqlUnkown is the zero value and means "infer the type from the Go
	// value". NOTE(review): name is misspelled ("Unkown") but is exported,
	// so it is kept for backward compatibility.
	SqlUnkown SqlType = iota
	SqlString
	SqlDate
	SqlTimestamp
	SqlFloat
	SqlDecimal
	SqlDouble
	SqlInteger
	SqlBigInt
	SqlSmallInt
	SqlTinyInt
	SqlBoolean
	SqlIntervalMonth
	SqlIntervalDay
)

// String returns the SQL type name sent in the Spark parameter protocol,
// or "unknown" for SqlUnkown and any out-of-range value.
func (s SqlType) String() string {
	switch s {
	case SqlString:
		return "STRING"
	case SqlDate:
		return "DATE"
	case SqlTimestamp:
		return "TIMESTAMP"
	case SqlFloat:
		return "FLOAT"
	case SqlDecimal:
		return "DECIMAL"
	case SqlDouble:
		return "DOUBLE"
	case SqlInteger:
		return "INTEGER"
	case SqlBigInt:
		return "BIGINT"
	case SqlSmallInt:
		return "SMALLINT"
	case SqlTinyInt:
		return "TINYINT"
	case SqlBoolean:
		return "BOOLEAN"
	case SqlIntervalMonth:
		return "INTERVAL MONTH"
	case SqlIntervalDay:
		return "INTERVAL DAY"
	}
	return "unknown"
}

func valuesToDBSQLParams(namedValues []driver.NamedValue) []DBSqlParam {
var params []DBSqlParam
func valuesToParameters(namedValues []driver.NamedValue) []Parameter {
var params []Parameter
for i := range namedValues {
newParam := *new(Parameter)
namedValue := namedValues[i]
param := *new(DBSqlParam)
param.Name = namedValue.Name
param.Value = namedValue.Value
params = append(params, param)
param, ok := namedValue.Value.(Parameter)
if ok {
newParam.Name = param.Name
newParam.Value = param.Value
newParam.Type = param.Type
} else {
newParam.Name = namedValue.Name
newParam.Value = namedValue.Value
}
params = append(params, newParam)
}
return params
}

func inferTypes(params []DBSqlParam) {
func inferTypes(params []Parameter) {
for i := range params {
param := &params[i]
switch value := param.Value.(type) {
case bool:
param.Value = strconv.FormatBool(value)
param.Type = Boolean
case string:
param.Value = value
param.Type = String
case int:
param.Value = strconv.Itoa(value)
param.Type = Integer
case uint:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int8:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint8:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int16:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint16:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int32:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint32:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int64:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint64:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case float32:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
param.Type = Float
case time.Time:
param.Value = value.String()
param.Type = Timestamp
case DBSqlParam:
param.Name = value.Name
param.Value = value.Value
param.Type = value.Type
default:
s := fmt.Sprintf("%s", value)
param.Value = s
param.Type = String
if param.Type == SqlUnkown {
switch value := param.Value.(type) {
case bool:
param.Value = strconv.FormatBool(value)
param.Type = SqlBoolean
case string:
param.Value = value
param.Type = SqlString
case int:
param.Value = strconv.Itoa(value)
param.Type = SqlInteger
case uint:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int8:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint8:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int16:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint16:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int32:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint32:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int64:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint64:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case float32:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
param.Type = SqlFloat
case time.Time:
param.Value = value.String()
param.Type = SqlTimestamp
default:
s := fmt.Sprintf("%s", param.Value)
param.Value = s
param.Type = SqlString
}
}
}
}
func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
var sparkParams []*cli_service.TSparkParameter

sqlParams := valuesToDBSQLParams(values)
sqlParams := valuesToParameters(values)
inferTypes(sqlParams)
for i := range sqlParams {
sqlParam := sqlParams[i]
sparkParamValue := sqlParam.Value.(string)
var sparkParamType string
if sqlParam.Type == Decimal {
if sqlParam.Type == SqlDecimal {
sparkParamType = inferDecimalType(sparkParamValue)
} else {
sparkParamType = sqlParam.Type.String()
Expand Down

0 comments on commit e95dd4a

Please sign in to comment.