diff --git a/itest/integration_test.go b/itest/integration_test.go index 90c4d7f..04defaa 100644 --- a/itest/integration_test.go +++ b/itest/integration_test.go @@ -224,6 +224,129 @@ func (suite *IntegrationTestSuite) TestPreparedStatement() { suite.assertSingleValueResult(rows, "15") } +var dereferenceString = func(v any) any { return *(v.(*string)) } +var dereferenceFloat64 = func(v any) any { return *(v.(*float64)) } +var dereferenceInt64 = func(v any) any { return *(v.(*int64)) } +var dereferenceInt = func(v any) any { return *(v.(*int)) } +var dereferenceBool = func(v any) any { return *(v.(*bool)) } +var dereferenceTime = func(v any) any { return *(v.(*time.Time)) } + +func (suite *IntegrationTestSuite) TestQueryDataTypesCast() { + for i, testCase := range []struct { + testDescription string + sqlValue string + sqlType string + scanDest any + expectedValue any + dereference func(any) any + }{ + {"decimal to int64", "1", "DECIMAL(18,0)", new(int64), int64(1), dereferenceInt64}, + {"decimal to int", "1", "DECIMAL(18,0)", new(int), 1, dereferenceInt}, + {"decimal to float", "1", "DECIMAL(18,0)", new(float64), 1.0, dereferenceFloat64}, + {"decimal to string", "1", "DECIMAL(18,0)", new(string), "1", dereferenceString}, + {"decimal to float64", "2.2", "DECIMAL(18,2)", new(float64), 2.2, dereferenceFloat64}, + {"decimal to string", "2.2", "DECIMAL(18,2)", new(string), "2.2", dereferenceString}, + {"double to float64", "3.3", "DOUBLE PRECISION", new(float64), 3.3, dereferenceFloat64}, + {"double to string", "3.3", "DOUBLE PRECISION", new(string), "3.3", dereferenceString}, + {"varchar to string", "'text'", "VARCHAR(10)", new(string), "text", dereferenceString}, + {"char to string", "'text'", "CHAR(10)", new(string), "text ", dereferenceString}, + {"date to string", "'2024-06-18'", "DATE", new(string), "2024-06-18", dereferenceString}, + {"timestamp to string", "'2024-06-18 17:22:13.123456'", "TIMESTAMP", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, + {"timestamp with local time zone to string", "'2024-06-18 17:22:13.123456'", "TIMESTAMP WITH LOCAL TIME ZONE", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, + {"geometry to string", "'point(1 2)'", "GEOMETRY", new(string), "POINT (1 2)", dereferenceString}, + {"interval ytm to string", "'5-3'", "INTERVAL YEAR TO MONTH", new(string), "+05-03", dereferenceString}, + {"interval dts to string", "'2 12:50:10.123'", "INTERVAL DAY TO SECOND", new(string), "+02 12:50:10.123", dereferenceString}, + {"hashtype to string", "'550e8400-e29b-11d4-a716-446655440000'", "HASHTYPE", new(string), "550e8400e29b11d4a716446655440000", dereferenceString}, + {"bool to bool", "true", "BOOLEAN", new(bool), true, dereferenceBool}, + {"bool to string", "false", "BOOLEAN", new(string), "false", dereferenceString}, + } { + database := suite.openConnection(suite.createDefaultConfig()) + defer database.Close() + suite.Run(fmt.Sprintf("Cast Test %02d %s: %s", i, testCase.testDescription, testCase.sqlType), func() { + rows, err := database.Query(fmt.Sprintf("SELECT CAST(%s AS %s)", testCase.sqlValue, testCase.sqlType)) + onError(err) + defer rows.Close() + suite.True(rows.Next(), "should have one row") + onError(rows.Scan(testCase.scanDest)) + val := testCase.scanDest + suite.Equal(testCase.expectedValue, testCase.dereference(val)) + }) + } +} + +func (suite *IntegrationTestSuite) TestQueryDataTypesPreparedStatement() { + for i, testCase := range []struct { + testDescription string + sqlValue any + sqlType string + scanDest any + expectedValue any + dereference func(any) any + }{ + {"decimal to int64", 1, "DECIMAL(18,0)", new(int64), int64(1), dereferenceInt64}, + {"decimal to int", 1, "DECIMAL(18,0)", new(int), 1, dereferenceInt}, + {"decimal to float", 1, "DECIMAL(18,0)", new(float64), 1.0, dereferenceFloat64}, + {"decimal to float64", 2.2, "DECIMAL(18,2)", new(float64), 2.2, dereferenceFloat64}, + {"double to float64", 3.3, "DOUBLE PRECISION", new(float64), 3.3, dereferenceFloat64}, + {"varchar to string", "text", "VARCHAR(10)", new(string), "text", dereferenceString}, + {"char to string", "text", "CHAR(10)", new(string), "text ", dereferenceString}, + {"date to string", "2024-06-18", "DATE", new(string), "2024-06-18", dereferenceString}, + {"date to string", time.Date(2024, time.June, 18, 0, 0, 0, 0, time.UTC), "DATE", new(string), "2024-06-18", dereferenceString}, + {"timestamp to string", "2024-06-18 17:22:13.123456", "TIMESTAMP", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, + {"timestamp to string", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, + //{"timestamp to timestamp", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP", new(time.Time), time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), dereferenceTime}, + {"timestamp with local time zone to string", "2024-06-18 17:22:13.123456", "TIMESTAMP WITH LOCAL TIME ZONE", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, + {"timestamp with local time zone to string", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP WITH LOCAL TIME ZONE", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, + //{"timestamp with local time zone to string", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP WITH LOCAL TIME ZONE", new(time.Time), time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), dereferenceTime}, + {"geometry to string", "point(1 2)", "GEOMETRY", new(string), "POINT (1 2)", dereferenceString}, + {"interval ytm to string", "5-3", "INTERVAL YEAR TO MONTH", new(string), "+05-03", dereferenceString}, + {"interval dts to string", "2 12:50:10.123", "INTERVAL DAY TO SECOND", new(string), "+02 12:50:10.123", dereferenceString}, + {"hashtype to string", "550e8400-e29b-11d4-a716-446655440000", "HASHTYPE", new(string), "550e8400e29b11d4a716446655440000", dereferenceString}, + {"bool to bool", true, "BOOLEAN", new(bool), true, dereferenceBool}, + } { + database := suite.openConnection(suite.createDefaultConfig().Autocommit(false)) + schemaName := "DATATYPE_TEST" + _, err := database.Exec("CREATE SCHEMA " + schemaName) + onError(err) + defer suite.cleanup(database, schemaName) + + suite.Run(fmt.Sprintf("Prepared statement Test %02d %s: %s", i, testCase.testDescription, testCase.sqlType), func() { + tableName := fmt.Sprintf("%s.TAB_%d", schemaName, i) + _, err = database.Exec(fmt.Sprintf("CREATE TABLE %s (col %s)", tableName, testCase.sqlType)) + onError(err) + stmt, err := database.Prepare(fmt.Sprintf("insert into %s values (?)", tableName)) + onError(err) + _, err = stmt.Exec(testCase.sqlValue) + onError(err) + rows, err := database.Query(fmt.Sprintf("select * from %s", tableName)) + onError(err) + defer rows.Close() + suite.True(rows.Next(), "should have one row") + onError(rows.Scan(testCase.scanDest)) + val := testCase.scanDest + suite.Equal(testCase.expectedValue, testCase.dereference(val)) + }) + } +} + +// https://github.com/exasol/exasol-driver-go/issues/108 +func (suite *IntegrationTestSuite) TestPreparedStatementIntConvertedToFloat() { + database := suite.openConnection(suite.createDefaultConfig()) + schemaName := "TEST_SCHEMA_3" + _, err := database.Exec("CREATE SCHEMA " + schemaName) + onError(err) + _, err = database.Exec(fmt.Sprintf("create or replace table %s.dummy (a integer, b float)", schemaName)) + onError(err) + defer suite.cleanup(database, schemaName) + stmt, err := database.Prepare(fmt.Sprintf("insert into %s.dummy values(?,?)", schemaName)) + onError(err) + _, err = stmt.Exec(1, 2) + onError(err) + rows, err := database.Query(fmt.Sprintf("select a || ':' || b from %s.dummy", schemaName)) + onError(err) + suite.assertSingleValueResult(rows, "1:2") +} + func (suite *IntegrationTestSuite) TestQueryWithValuesAndContext() { database := suite.openConnection(suite.createDefaultConfig()) schemaName := "TEST_SCHEMA_3_2" diff --git a/pkg/connection/converter.go b/pkg/connection/converter.go new file mode 100644 index 0000000..ae5d092 --- /dev/null +++ b/pkg/connection/converter.go @@ -0,0 +1,87 @@ +package connection + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "time" + + "github.com/exasol/exasol-driver-go/pkg/errors" + "github.com/exasol/exasol-driver-go/pkg/types" +) + +func convertArg(arg driver.Value, colType types.SqlQueryColumnType) (interface{}, error) { + dataType := colType.Type + if dataType == "DOUBLE" { + if intArg, ok := arg.(int64); ok { + return jsonDoubleValue(float64(intArg)), nil + } + if floatArg, ok := arg.(float64); ok { + return jsonDoubleValue(floatArg), nil + } + return nil, errors.NewInvalidArgType(arg, dataType) + } + if dataType == "TIMESTAMP" || dataType == "TIMESTAMP WITH LOCAL TIME ZONE" { + if timeArg, ok := arg.(time.Time); ok { + return jsonTimestampValue(timeArg), nil + } + if stringArg, ok := arg.(string); ok { + return stringArg, nil + } + return nil, errors.NewInvalidArgType(arg, dataType) + } + if dataType == "DATE" { + if timeArg, ok := arg.(time.Time); ok { + return jsonDateValue(timeArg), nil + } + if stringArg, ok := arg.(string); ok { + return stringArg, nil + } + return nil, errors.NewInvalidArgType(arg, dataType) + } + if dataType == "BOOLEAN" { + if boolArg, ok := arg.(bool); ok { + return boolArg, nil + } + return nil, errors.NewInvalidArgType(arg, dataType) + } + return arg, nil +} + +func jsonDoubleValue(value float64) json.Marshaler { + return &jsonDoubleValueStruct{value: value} +} + +type jsonDoubleValueStruct struct { + value float64 +} + +func (j *jsonDoubleValueStruct) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf("%f", j.value)), nil +} + +func jsonTimestampValue(value time.Time) json.Marshaler { + return &jsonTimestampValueStruct{value: value} +} + +type jsonTimestampValueStruct struct { + value time.Time +} + +func (j *jsonTimestampValueStruct) MarshalJSON() ([]byte, error) { + // Exasol expects format YYYY-MM-DD HH24:MI:SS.FF6 + return []byte(fmt.Sprintf(`"%s"`, j.value.Format("2006-01-02 15:04:05.000000"))), nil +} + +func jsonDateValue(value time.Time) json.Marshaler { + return &jsonDateValueStruct{value: value} +} + +type jsonDateValueStruct struct { + value time.Time +} + +func (j *jsonDateValueStruct) MarshalJSON() ([]byte, error) { + // Exasol expects format YYYY-MM-DD + return []byte(fmt.Sprintf(`"%s"`, j.value.Format("2006-01-02"))), nil +} diff --git a/pkg/connection/result_set.go b/pkg/connection/result_set.go index 8186bc8..f920a97 100644 --- a/pkg/connection/result_set.go +++ b/pkg/connection/result_set.go @@ -101,7 +101,9 @@ func (results *QueryResults) Next(dest []driver.Value) error { } for i := range dest { - dest[i] = results.data.Data[i][results.rowPointer] + dataType := results.data.Columns[i].DataType + value := results.data.Data[i][results.rowPointer] + dest[i] = convertResultSetValue(dataType, value) } results.rowPointer = results.rowPointer + 1 @@ -109,3 +111,7 @@ func (results *QueryResults) Next(dest []driver.Value) error { return nil } + +func convertResultSetValue(dataType types.SqlQueryColumnType, value any) driver.Value { + return value +} diff --git a/pkg/connection/statement.go b/pkg/connection/statement.go index 0b0d9e2..8549d8e 100644 --- a/pkg/connection/statement.go +++ b/pkg/connection/statement.go @@ -83,10 +83,16 @@ func (s *Statement) executePreparedStatement(ctx context.Context, args []driver. data := make([][]interface{}, len(columns)) for i, arg := range args { - if data[i%len(columns)] == nil { - data[i%len(columns)] = make([]interface{}, 0) + col := i % len(columns) + colType := columns[col] + if data[col] == nil { + data[col] = make([]interface{}, 0) } - data[i%len(columns)] = append(data[i%len(columns)], arg) + convertedArg, err := convertArg(arg, colType.DataType) + if err != nil { + return nil, err + } + data[col] = append(data[col], convertedArg) } command := &types.ExecutePreparedStatementCommand{ diff --git a/pkg/connection/websocket.go b/pkg/connection/websocket.go index 3bcfa69..8547061 100644 --- a/pkg/connection/websocket.go +++ b/pkg/connection/websocket.go @@ -91,6 +91,7 @@ func (c *Connection) asyncSend(request interface{}) (func(interface{}) error, er logger.ErrorLogger.Print(errors.NewMarshallingError(request, err)) return nil, driver.ErrBadConn } + logger.TraceLogger.Printf("Sending message: %s", message) messageType := websocket.TextMessage if c.Config.Compression { diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index a2fe8a5..fc37fa2 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -1,6 +1,7 @@ package errors import ( + "fmt" "net/url" exaerror "github.com/exasol/error-reporting-go" @@ -147,6 +148,14 @@ func NewWebsocketNotConnected(request interface{}) DriverErr { Parameter("request", request)) } +func NewInvalidArgType(value interface{}, expectedType string) DriverErr { + return NewDriverErr(exaerror.New("E-EGOD-30"). + Message("cannot convert argument {{value}} of type {{type}} to {{expected type}} type"). + Parameter("value", value). + Parameter("type", fmt.Sprintf("%T", value)). + Parameter(("expected type"), expectedType)) +} + // DriverErr This type represents an error that can occur when working with a database connection. type DriverErr string diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 6672c53..12a79b6 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -7,7 +7,8 @@ import ( "github.com/exasol/exasol-driver-go/pkg/errors" ) -var ErrorLogger = Logger(log.New(os.Stderr, "[exasol] ", log.LstdFlags|log.Lshortfile)) +var ErrorLogger Logger = log.New(os.Stderr, "[exasol] ", log.LstdFlags|log.Lshortfile) +var TraceLogger Logger = noOpLogger{} // Logger is used to log critical error messages. type Logger interface { @@ -24,3 +25,19 @@ func SetLogger(logger Logger) error { ErrorLogger = logger return nil } + +// SetTraceLogger is used to set the logger for tracing. +// The initial logger is a no-op logger. Please note that this will generate a lot of output. +// Set the logger to nil to disable tracing. +func SetTraceLogger(logger Logger) { + if logger == nil { + TraceLogger = noOpLogger{} + } else { + TraceLogger = logger + } +} + +type noOpLogger struct{} + +func (noOpLogger) Print(v ...interface{}) { /* no-op */ } +func (noOpLogger) Printf(format string, v ...interface{}) { /* no-op */ }