From 6f229f3b85bb71efc9d9807147835e3e203daa9e Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Thu, 9 May 2024 18:12:14 -0700 Subject: [PATCH] [sql] Add tests for `Dialect.KindForDataType` --- lib/sql/bigquery_test.go | 28 +++++++++++++++++--- lib/sql/mssql_test.go | 56 +++++++++++++++++++++++++++++++++++++++ lib/sql/redshift_test.go | 12 ++++++++- lib/sql/snowflake_test.go | 23 +++++++++++----- 4 files changed, 108 insertions(+), 11 deletions(-) create mode 100644 lib/sql/mssql_test.go diff --git a/lib/sql/bigquery_test.go b/lib/sql/bigquery_test.go index 33872f7a7..5f1864e95 100644 --- a/lib/sql/bigquery_test.go +++ b/lib/sql/bigquery_test.go @@ -10,6 +10,8 @@ import ( ) func TestBigQueryDialect_KindForDataType(t *testing.T) { + dialect := BigQueryDialect{} + bqColToExpectedKind := map[string]typing.KindDetails{ // Number "numeric": typing.EDecimal, @@ -44,16 +46,34 @@ func TestBigQueryDialect_KindForDataType(t *testing.T) { "time": typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimeKindType), "date": typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType), //Invalid - "foo": typing.Invalid, - "foofoo": typing.Invalid, - "": typing.Invalid, + "foo": typing.Invalid, + "foofoo": typing.Invalid, + "": typing.Invalid, + "numeric(1,2,3)": typing.Invalid, } for bqCol, expectedKind := range bqColToExpectedKind { - kd, err := BigQueryDialect{}.KindForDataType(bqCol, "") + kd, err := dialect.KindForDataType(bqCol, "") assert.NoError(t, err) assert.Equal(t, expectedKind.Kind, kd.Kind, bqCol) } + + { + _, err := dialect.KindForDataType("numeric(5", "") + assert.ErrorContains(t, err, "missing closing parenthesis") + } + { + kd, err := dialect.KindForDataType("numeric(5, 2)", "") + assert.NoError(t, err) + assert.Equal(t, 5, *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, 2, kd.ExtendedDecimalDetails.Scale()) + } + { + kd, err := dialect.KindForDataType("bignumeric(5, 2)", "") + assert.NoError(t, err) + assert.Equal(t, 5, *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, 2, kd.ExtendedDecimalDetails.Scale()) + } } func TestBigQueryDialect_KindForDataType_NoDataLoss(t *testing.T) { diff --git a/lib/sql/mssql_test.go b/lib/sql/mssql_test.go new file mode 100644 index 000000000..1bdca8bab --- /dev/null +++ b/lib/sql/mssql_test.go @@ -0,0 +1,56 @@ +package sql + +import ( + "testing" + + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/ext" + "github.com/stretchr/testify/assert" +) + +func TestMSSQLDialect_KindForDataType(t *testing.T) { + dialect := MSSQLDialect{} + + colToExpectedKind := map[string]typing.KindDetails{ + "char": typing.String, + "varchar": typing.String, + "nchar": typing.String, + "nvarchar": typing.String, + "ntext": typing.String, + "text": typing.String, + "smallint": typing.Integer, + "tinyint": typing.Integer, + "int": typing.Integer, + "float": typing.Float, + "real": typing.Float, + "bit": typing.Boolean, + "date": typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType), + "time": typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimeKindType), + "datetime": typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateTimeKindType), + "datetime2": typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateTimeKindType), + } + + for col, expectedKind := range colToExpectedKind { + kd, err := dialect.KindForDataType(col, "") + assert.NoError(t, err) + assert.Equal(t, expectedKind.Kind, kd.Kind, col) + } + + { + _, err := dialect.KindForDataType("numeric(5", "") + assert.ErrorContains(t, err, "missing closing parenthesis") + } + { + kd, err := dialect.KindForDataType("numeric(5, 2)", "") + assert.NoError(t, err) + assert.Equal(t, typing.EDecimal.Kind, kd.Kind) + assert.Equal(t, 5, *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, 2, kd.ExtendedDecimalDetails.Scale()) + } + { + kd, err := dialect.KindForDataType("char", "5") + assert.NoError(t, err) + assert.Equal(t, typing.String.Kind, kd.Kind) + assert.Equal(t, 5, *kd.OptionalStringPrecision) + } +} diff --git a/lib/sql/redshift_test.go b/lib/sql/redshift_test.go index eacbf1c8d..c705cb61e 100644 --- a/lib/sql/redshift_test.go +++ b/lib/sql/redshift_test.go @@ -10,6 +10,8 @@ import ( ) func TestRedshiftDialect_KindForDataType(t *testing.T) { + dialect := RedshiftDialect{} + type rawTypeAndPrecision struct { rawType string precision string @@ -93,7 +95,7 @@ func TestRedshiftDialect_KindForDataType(t *testing.T) { for _, testCase := range testCases { for _, rawTypeAndPrec := range testCase.rawTypes { - kd, err := RedshiftDialect{}.KindForDataType(rawTypeAndPrec.rawType, rawTypeAndPrec.precision) + kd, err := dialect.KindForDataType(rawTypeAndPrec.rawType, rawTypeAndPrec.precision) assert.NoError(t, err) assert.Equal(t, testCase.expectedKd.Kind, kd.Kind, testCase.name) @@ -104,4 +106,12 @@ func TestRedshiftDialect_KindForDataType(t *testing.T) { } } } + + { + kd, err := dialect.KindForDataType("numeric(5,2)", "") + assert.NoError(t, err) + assert.Equal(t, typing.EDecimal.Kind, kd.Kind) + assert.Equal(t, 5, *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, 2, kd.ExtendedDecimalDetails.Scale()) + } } diff --git a/lib/sql/snowflake_test.go b/lib/sql/snowflake_test.go index fa484793f..ffee024ad 100644 --- a/lib/sql/snowflake_test.go +++ b/lib/sql/snowflake_test.go @@ -44,12 +44,23 @@ func TestSnowflakeDialect_KindForDataType_Floats(t *testing.T) { assert.Equal(t, typing.Invalid, kd) } { - expectedNumerics := []string{"NUMERIC(38, 2)", "NUMBER(38, 2)", "DECIMAL"} - for _, expectedNumeric := range expectedNumerics { - kd, err := SnowflakeDialect{}.KindForDataType(expectedNumeric, "") - assert.NoError(t, err) - assert.Equal(t, typing.EDecimal.Kind, kd.Kind, expectedNumeric) - } + kd, err := SnowflakeDialect{}.KindForDataType("NUMERIC(38, 2)", "") + assert.NoError(t, err) + assert.Equal(t, typing.EDecimal.Kind, kd.Kind) + assert.Equal(t, 38, *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, 2, kd.ExtendedDecimalDetails.Scale()) + } + { + kd, err := SnowflakeDialect{}.KindForDataType("NUMBER(38, 2)", "") + assert.NoError(t, err) + assert.Equal(t, typing.EDecimal.Kind, kd.Kind) + assert.Equal(t, 38, *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, 2, kd.ExtendedDecimalDetails.Scale()) + } + { + kd, err := SnowflakeDialect{}.KindForDataType("DECIMAL", "") + assert.NoError(t, err) + assert.Equal(t, typing.EDecimal.Kind, kd.Kind) } }