Skip to content

Commit

Permalink
#108: Convert argument values to expected type
Browse files Browse the repository at this point in the history
  • Loading branch information
kaklakariada committed Jun 18, 2024
1 parent 872e051 commit 0bd89d5
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 5 deletions.
123 changes: 123 additions & 0 deletions itest/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,129 @@ func (suite *IntegrationTestSuite) TestPreparedStatement() {
suite.assertSingleValueResult(rows, "15")
}

var dereferenceString = func(v any) any { return *(v.(*string)) }
var dereferenceFloat64 = func(v any) any { return *(v.(*float64)) }
var dereferenceInt64 = func(v any) any { return *(v.(*int64)) }
var dereferenceInt = func(v any) any { return *(v.(*int)) }
var dereferenceBool = func(v any) any { return *(v.(*bool)) }
var dereferenceTime = func(v any) any { return *(v.(*time.Time)) }

Check failure on line 232 in itest/integration_test.go

View workflow job for this annotation

GitHub Actions / Build with go version 1.21 and db 7.1.26

var `dereferenceTime` is unused (unused)

Check failure on line 232 in itest/integration_test.go

View workflow job for this annotation

GitHub Actions / Build with go version 1.21 and db 8.27.0

var `dereferenceTime` is unused (unused)

Check failure on line 232 in itest/integration_test.go

View workflow job for this annotation

GitHub Actions / Build with go version 1.22 and db 7.1.26

var `dereferenceTime` is unused (unused)

func (suite *IntegrationTestSuite) TestQueryDataTypesCast() {
for i, testCase := range []struct {
testDescription string
sqlValue string
sqlType string
scanDest any
expectedValue any
dereference func(any) any
}{
{"decimal to int64", "1", "DECIMAL(18,0)", new(int64), int64(1), dereferenceInt64},
{"decimal to int", "1", "DECIMAL(18,0)", new(int), 1, dereferenceInt},
{"decimal to float", "1", "DECIMAL(18,0)", new(float64), 1.0, dereferenceFloat64},
{"decimal to string", "1", "DECIMAL(18,0)", new(string), "1", dereferenceString},
{"decimal to float64", "2.2", "DECIMAL(18,2)", new(float64), 2.2, dereferenceFloat64},
{"decimal to string", "2.2", "DECIMAL(18,2)", new(string), "2.2", dereferenceString},
{"double to float64", "3.3", "DOUBLE PRECISION", new(float64), 3.3, dereferenceFloat64},
{"double to string", "3.3", "DOUBLE PRECISION", new(string), "3.3", dereferenceString},
{"varchar to string", "'text'", "VARCHAR(10)", new(string), "text", dereferenceString},
{"char to string", "'text'", "CHAR(10)", new(string), "text ", dereferenceString},
{"date to string", "'2024-06-18'", "DATE", new(string), "2024-06-18", dereferenceString},
{"timestamp to string", "'2024-06-18 17:22:13.123456'", "TIMESTAMP", new(string), "2024-06-18 17:22:13.123000", dereferenceString},
{"timestamp with local time zone to string", "'2024-06-18 17:22:13.123456'", "TIMESTAMP WITH LOCAL TIME ZONE", new(string), "2024-06-18 17:22:13.123000", dereferenceString},
{"geometry to string", "'point(1 2)'", "GEOMETRY", new(string), "POINT (1 2)", dereferenceString},
{"interval ytm to string", "'5-3'", "INTERVAL YEAR TO MONTH", new(string), "+05-03", dereferenceString},
{"interval dts to string", "'2 12:50:10.123'", "INTERVAL DAY TO SECOND", new(string), "+02 12:50:10.123", dereferenceString},
{"hashtype to string", "'550e8400-e29b-11d4-a716-446655440000'", "HASHTYPE", new(string), "550e8400e29b11d4a716446655440000", dereferenceString},
{"bool to bool", "true", "BOOLEAN", new(bool), true, dereferenceBool},
{"bool to string", "false", "BOOLEAN", new(string), "false", dereferenceString},
} {
database := suite.openConnection(suite.createDefaultConfig())
defer database.Close()
suite.Run(fmt.Sprintf("Cast Test %02d %s: %s", i, testCase.testDescription, testCase.sqlType), func() {
rows, err := database.Query(fmt.Sprintf("SELECT CAST(%s AS %s)", testCase.sqlValue, testCase.sqlType))
onError(err)
defer rows.Close()
suite.True(rows.Next(), "should have one row")
onError(rows.Scan(testCase.scanDest))
val := testCase.scanDest
suite.Equal(testCase.expectedValue, testCase.dereference(val))
})
}
}

func (suite *IntegrationTestSuite) TestQueryDataTypesPreparedStatement() {
for i, testCase := range []struct {
testDescription string
sqlValue any
sqlType string
scanDest any
expectedValue any
dereference func(any) any
}{
{"decimal to int64", 1, "DECIMAL(18,0)", new(int64), int64(1), dereferenceInt64},
{"decimal to int", 1, "DECIMAL(18,0)", new(int), 1, dereferenceInt},
{"decimal to float", 1, "DECIMAL(18,0)", new(float64), 1.0, dereferenceFloat64},
{"decimal to float64", 2.2, "DECIMAL(18,2)", new(float64), 2.2, dereferenceFloat64},
{"double to float64", 3.3, "DOUBLE PRECISION", new(float64), 3.3, dereferenceFloat64},
{"varchar to string", "text", "VARCHAR(10)", new(string), "text", dereferenceString},
{"char to string", "text", "CHAR(10)", new(string), "text ", dereferenceString},
{"date to string", "2024-06-18", "DATE", new(string), "2024-06-18", dereferenceString},
{"date to string", time.Date(2024, time.June, 18, 0, 0, 0, 0, time.UTC), "DATE", new(string), "2024-06-18", dereferenceString},
{"timestamp to string", "2024-06-18 17:22:13.123456", "TIMESTAMP", new(string), "2024-06-18 17:22:13.123000", dereferenceString},
{"timestamp to string", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP", new(string), "2024-06-18 17:22:13.123000", dereferenceString},
//{"timestamp to timestamp", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP", new(time.Time), time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), dereferenceTime},
{"timestamp with local time zone to string", "2024-06-18 17:22:13.123456", "TIMESTAMP WITH LOCAL TIME ZONE", new(string), "2024-06-18 17:22:13.123000", dereferenceString},
{"timestamp with local time zone to string", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP WITH LOCAL TIME ZONE", new(string), "2024-06-18 17:22:13.123000", dereferenceString},
//{"timestamp with local time zone to string", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP WITH LOCAL TIME ZONE", new(time.Time), time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), dereferenceTime},
{"geometry to string", "point(1 2)", "GEOMETRY", new(string), "POINT (1 2)", dereferenceString},
{"interval ytm to string", "5-3", "INTERVAL YEAR TO MONTH", new(string), "+05-03", dereferenceString},
{"interval dts to string", "2 12:50:10.123", "INTERVAL DAY TO SECOND", new(string), "+02 12:50:10.123", dereferenceString},
{"hashtype to string", "550e8400-e29b-11d4-a716-446655440000", "HASHTYPE", new(string), "550e8400e29b11d4a716446655440000", dereferenceString},
{"bool to bool", true, "BOOLEAN", new(bool), true, dereferenceBool},
} {
database := suite.openConnection(suite.createDefaultConfig().Autocommit(false))
schemaName := "DATATYPE_TEST"
_, err := database.Exec("CREATE SCHEMA " + schemaName)
onError(err)
defer suite.cleanup(database, schemaName)

suite.Run(fmt.Sprintf("Prepared statement Test %02d %s: %s", i, testCase.testDescription, testCase.sqlType), func() {
tableName := fmt.Sprintf("%s.TAB_%d", schemaName, i)
_, err = database.Exec(fmt.Sprintf("CREATE TABLE %s (col %s)", tableName, testCase.sqlType))
onError(err)
stmt, err := database.Prepare(fmt.Sprintf("insert into %s values (?)", tableName))
onError(err)
_, err = stmt.Exec(testCase.sqlValue)
onError(err)
rows, err := database.Query(fmt.Sprintf("select * from %s", tableName))
onError(err)
defer rows.Close()
suite.True(rows.Next(), "should have one row")
onError(rows.Scan(testCase.scanDest))
val := testCase.scanDest
suite.Equal(testCase.expectedValue, testCase.dereference(val))
})
}
}

// https://github.com/exasol/exasol-driver-go/issues/108
func (suite *IntegrationTestSuite) TestPreparedStatementIntConvertedToFloat() {
database := suite.openConnection(suite.createDefaultConfig())
schemaName := "TEST_SCHEMA_3"
_, err := database.Exec("CREATE SCHEMA " + schemaName)
onError(err)
_, err = database.Exec(fmt.Sprintf("create or replace table %s.dummy (a integer, b float)", schemaName))
onError(err)
defer suite.cleanup(database, schemaName)
stmt, err := database.Prepare(fmt.Sprintf("insert into %s.dummy values(?,?)", schemaName))
onError(err)
_, err = stmt.Exec(1, 2)
onError(err)
rows, err := database.Query(fmt.Sprintf("select a || ':' || b from %s.dummy", schemaName))
onError(err)
suite.assertSingleValueResult(rows, "1:2")
}

func (suite *IntegrationTestSuite) TestQueryWithValuesAndContext() {
database := suite.openConnection(suite.createDefaultConfig())
schemaName := "TEST_SCHEMA_3_2"
Expand Down
87 changes: 87 additions & 0 deletions pkg/connection/converter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package connection

import (
"database/sql/driver"
"encoding/json"
"fmt"
"time"

"github.com/exasol/exasol-driver-go/pkg/errors"
"github.com/exasol/exasol-driver-go/pkg/types"
)

func convertArg(arg driver.Value, colType types.SqlQueryColumnType) (interface{}, error) {
dataType := colType.Type
if dataType == "DOUBLE" {
if intArg, ok := arg.(int64); ok {
return jsonDoubleValue(float64(intArg)), nil
}
if floatArg, ok := arg.(float64); ok {
return jsonDoubleValue(floatArg), nil
}
return nil, errors.NewInvalidArgType(arg, dataType)
}
if dataType == "TIMESTAMP" || dataType == "TIMESTAMP WITH LOCAL TIME ZONE" {
if timeArg, ok := arg.(time.Time); ok {
return jsonTimestampValue(timeArg), nil
}
if stringArg, ok := arg.(string); ok {
return stringArg, nil
}
return nil, errors.NewInvalidArgType(arg, dataType)
}
if dataType == "DATE" {
if timeArg, ok := arg.(time.Time); ok {
return jsonDateValue(timeArg), nil
}
if stringArg, ok := arg.(string); ok {
return stringArg, nil
}
return nil, errors.NewInvalidArgType(arg, dataType)
}
if dataType == "BOOLEAN" {
if boolArg, ok := arg.(bool); ok {
return boolArg, nil
}
return nil, errors.NewInvalidArgType(arg, dataType)
}
return arg, nil
}

func jsonDoubleValue(value float64) json.Marshaler {
return &jsonDoubleValueStruct{value: value}
}

type jsonDoubleValueStruct struct {
value float64
}

func (j *jsonDoubleValueStruct) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("%f", j.value)), nil
}

func jsonTimestampValue(value time.Time) json.Marshaler {
return &jsonTimestampValueStruct{value: value}
}

type jsonTimestampValueStruct struct {
value time.Time
}

func (j *jsonTimestampValueStruct) MarshalJSON() ([]byte, error) {
// Exasol expects format YYYY-MM-DD HH24:MI:SS.FF6
return []byte(fmt.Sprintf(`"%s"`, j.value.Format("2006-01-02 15:04:05.000000"))), nil
}

func jsonDateValue(value time.Time) json.Marshaler {
return &jsonDateValueStruct{value: value}
}

type jsonDateValueStruct struct {
value time.Time
}

func (j *jsonDateValueStruct) MarshalJSON() ([]byte, error) {
// Exasol expects format YYYY-MM-DD
return []byte(fmt.Sprintf(`"%s"`, j.value.Format("2006-01-02"))), nil
}
8 changes: 7 additions & 1 deletion pkg/connection/result_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,17 @@ func (results *QueryResults) Next(dest []driver.Value) error {
}

for i := range dest {
dest[i] = results.data.Data[i][results.rowPointer]
dataType := results.data.Columns[i].DataType
value := results.data.Data[i][results.rowPointer]
dest[i] = convertResultSetValue(dataType, value)
}

results.rowPointer = results.rowPointer + 1
results.totalRowPointer = results.totalRowPointer + 1

return nil
}

func convertResultSetValue(dataType types.SqlQueryColumnType, value any) driver.Value {
return value
}
12 changes: 9 additions & 3 deletions pkg/connection/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,16 @@ func (s *Statement) executePreparedStatement(ctx context.Context, args []driver.

data := make([][]interface{}, len(columns))
for i, arg := range args {
if data[i%len(columns)] == nil {
data[i%len(columns)] = make([]interface{}, 0)
col := i % len(columns)
colType := columns[col]
if data[col] == nil {
data[col] = make([]interface{}, 0)
}
data[i%len(columns)] = append(data[i%len(columns)], arg)
convertedArg, err := convertArg(arg, colType.DataType)
if err != nil {
return nil, err
}
data[col] = append(data[col], convertedArg)
}

command := &types.ExecutePreparedStatementCommand{
Expand Down
1 change: 1 addition & 0 deletions pkg/connection/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func (c *Connection) asyncSend(request interface{}) (func(interface{}) error, er
logger.ErrorLogger.Print(errors.NewMarshallingError(request, err))
return nil, driver.ErrBadConn
}
logger.TraceLogger.Printf("Sending message: %s", message)

messageType := websocket.TextMessage
if c.Config.Compression {
Expand Down
9 changes: 9 additions & 0 deletions pkg/errors/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package errors

import (
"fmt"
"net/url"

exaerror "github.com/exasol/error-reporting-go"
Expand Down Expand Up @@ -147,6 +148,14 @@ func NewWebsocketNotConnected(request interface{}) DriverErr {
Parameter("request", request))
}

func NewInvalidArgType(value interface{}, expectedType string) DriverErr {
return NewDriverErr(exaerror.New("E-EGOD-30").
Message("cannot convert argument {{value}} of type {{type}} to {{expected type}} type").
Parameter("value", value).
Parameter("type", fmt.Sprintf("%T", value)).
Parameter(("expected type"), expectedType))
}

// DriverErr This type represents an error that can occur when working with a database connection.
type DriverErr string

Expand Down
19 changes: 18 additions & 1 deletion pkg/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import (
"github.com/exasol/exasol-driver-go/pkg/errors"
)

var ErrorLogger = Logger(log.New(os.Stderr, "[exasol] ", log.LstdFlags|log.Lshortfile))
var ErrorLogger Logger = log.New(os.Stderr, "[exasol] ", log.LstdFlags|log.Lshortfile)
var TraceLogger Logger = noOpLogger{}

// Logger is used to log critical error messages.
type Logger interface {
Expand All @@ -24,3 +25,19 @@ func SetLogger(logger Logger) error {
ErrorLogger = logger
return nil
}

// SetTraceLogger is used to set the logger for tracing.
// The initial logger is a no-op logger. Please note that this will generate a lot of output.
// Set the logger to nil to disable tracing.
func SetTraceLogger(logger Logger) {
if logger == nil {
TraceLogger = noOpLogger{}
} else {
TraceLogger = logger
}
}

type noOpLogger struct{}

func (noOpLogger) Print(v ...interface{}) { /* no-op */ }
func (noOpLogger) Printf(format string, v ...interface{}) { /* no-op */ }

0 comments on commit 0bd89d5

Please sign in to comment.