From 9ed0c94025f3c39b7a965346e4eb28665579ff0c Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Thu, 9 May 2024 17:36:38 -0700 Subject: [PATCH] [sql] Use `ParseDataTypeDefinition` in `KindForDataType` methods (#605) --- lib/sql/bigquery.go | 29 ++++++++--------------------- lib/sql/mssql.go | 6 +++++- lib/sql/redshift.go | 6 +++++- lib/sql/snowflake.go | 16 ++++++---------- lib/typing/numeric.go | 14 +------------- lib/typing/numeric_test.go | 30 +++++++++++++++--------------- 6 files changed, 40 insertions(+), 61 deletions(-) diff --git a/lib/sql/bigquery.go b/lib/sql/bigquery.go index 61196912e..10db2260b 100644 --- a/lib/sql/bigquery.go +++ b/lib/sql/bigquery.go @@ -64,43 +64,30 @@ func (BigQueryDialect) DataTypeForKind(kindDetails typing.KindDetails, _ bool) s } func (BigQueryDialect) KindForDataType(rawBqType string, _ string) (typing.KindDetails, error) { - rawBqType = strings.ToLower(rawBqType) - - bqType := rawBqType - if len(bqType) == 0 { + if len(rawBqType) == 0 { return typing.Invalid, nil } - idxStop := len(bqType) - // Trim STRING (10) to String - if idx := strings.Index(bqType, "("); idx > 0 { - idxStop = idx + bqType, parameters, err := ParseDataTypeDefinition(strings.ToLower(rawBqType)) + if err != nil { + return typing.Invalid, err } - bqType = bqType[:idxStop] - // Trim Struct to Struct - idxStop = len(bqType) + idxStop := len(bqType) if idx := strings.Index(bqType, "<"); idx > 0 { idxStop = idx } // Geography, geometry date, time, varbinary, binary are currently not supported. switch strings.TrimSpace(bqType[:idxStop]) { - case "numeric": - if rawBqType == "numeric" || rawBqType == "bignumeric" { + case "numeric", "bignumeric": + if len(parameters) == 0 { // This is a specific thing to BigQuery // A `NUMERIC` type without precision or scale specified is NUMERIC(38, 9) return typing.EDecimal, nil } - - return typing.ParseNumeric(typing.DefaultPrefix, rawBqType), nil - case "bignumeric": - if rawBqType == "bignumeric" { - return typing.EDecimal, nil - } - - return typing.ParseNumeric("bignumeric", rawBqType), nil + return typing.ParseNumeric(parameters), nil case "decimal", "float", "float64", "bigdecimal": return typing.Float, nil case "int", "integer", "int64": diff --git a/lib/sql/mssql.go b/lib/sql/mssql.go index 7669ce50d..64495f744 100644 --- a/lib/sql/mssql.go +++ b/lib/sql/mssql.go @@ -70,7 +70,11 @@ func (MSSQLDialect) KindForDataType(rawType string, stringPrecision string) (typ rawType = strings.ToLower(rawType) if strings.HasPrefix(rawType, "numeric") { - return typing.ParseNumeric(typing.DefaultPrefix, rawType), nil + _, parameters, err := ParseDataTypeDefinition(rawType) + if err != nil { + return typing.Invalid, err + } + return typing.ParseNumeric(parameters), nil } switch rawType { diff --git a/lib/sql/redshift.go b/lib/sql/redshift.go index 04f3eb174..f2acbda27 100644 --- a/lib/sql/redshift.go +++ b/lib/sql/redshift.go @@ -64,7 +64,11 @@ func (RedshiftDialect) KindForDataType(rawType string, stringPrecision string) ( rawType = strings.ToLower(rawType) // TODO: Check if there are any missing Redshift data types. if strings.HasPrefix(rawType, "numeric") { - return typing.ParseNumeric(typing.DefaultPrefix, rawType), nil + _, parameters, err := ParseDataTypeDefinition(rawType) + if err != nil { + return typing.Invalid, err + } + return typing.ParseNumeric(parameters), nil } if strings.Contains(rawType, "character varying") { diff --git a/lib/sql/snowflake.go b/lib/sql/snowflake.go index b86aba710..18717f47c 100644 --- a/lib/sql/snowflake.go +++ b/lib/sql/snowflake.go @@ -52,26 +52,22 @@ func (SnowflakeDialect) DataTypeForKind(kindDetails typing.KindDetails, _ bool) // KindForDataType converts a Snowflake type to a KindDetails. // Following this spec: https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html func (SnowflakeDialect) KindForDataType(snowflakeType string, _ string) (typing.KindDetails, error) { - snowflakeType = strings.ToLower(snowflakeType) - - // We need to strip away the variable - // For example, a Column can look like: TEXT, or Number(38, 0) or VARCHAR(255). - // We need to strip out all the content from ( ... ) if len(snowflakeType) == 0 { return typing.Invalid, nil } - dataType, parameters, err := ParseDataTypeDefinition(snowflakeType) + // We need to strip away the variable + // For example, a Column can look like: TEXT, or Number(38, 0) or VARCHAR(255). + // We need to strip out all the content from ( ... ) + dataType, parameters, err := ParseDataTypeDefinition(strings.ToLower(snowflakeType)) if err != nil { return typing.Invalid, err } // Geography, geometry date, time, varbinary, binary are currently not supported. switch dataType { - case "number": - return typing.ParseNumeric("number", snowflakeType), nil - case "numeric": - return typing.ParseNumeric(typing.DefaultPrefix, snowflakeType), nil + case "number", "numeric": + return typing.ParseNumeric(parameters), nil case "decimal": return typing.EDecimal, nil case "float", "float4", diff --git a/lib/typing/numeric.go b/lib/typing/numeric.go index c8335a717..7a196b130 100644 --- a/lib/typing/numeric.go +++ b/lib/typing/numeric.go @@ -7,19 +7,7 @@ import ( "github.com/artie-labs/transfer/lib/typing/decimal" ) -const DefaultPrefix = "numeric" - -// ParseNumeric - will prefix (since it can be NUMBER or NUMERIC) + valString in the form of: -// * NUMERIC(p, s) -// * NUMERIC(p) -func ParseNumeric(prefix, valString string) KindDetails { - if !strings.HasPrefix(valString, prefix) { - return Invalid - } - - valString = strings.TrimPrefix(valString, prefix+"(") - valString = strings.TrimSuffix(valString, ")") - parts := strings.Split(valString, ",") +func ParseNumeric(parts []string) KindDetails { if len(parts) == 0 || len(parts) > 2 { return Invalid } diff --git a/lib/typing/numeric_test.go b/lib/typing/numeric_test.go index 039654617..4960fc73a 100644 --- a/lib/typing/numeric_test.go +++ b/lib/typing/numeric_test.go @@ -10,7 +10,7 @@ import ( func TestParseNumeric(t *testing.T) { type _testCase struct { - valString string + parameters []string expectedKindDetails KindDetails expectedPrecision *int // Using a pointer to int so we can differentiate between unset (nil) and set (0 included) expectedScale int @@ -18,51 +18,51 @@ func TestParseNumeric(t *testing.T) { testCases := []_testCase{ { - valString: "numeri232321c(5,2)", + parameters: []string{}, expectedKindDetails: Invalid, }, { - valString: "numeric", + parameters: []string{"5", "a"}, expectedKindDetails: Invalid, }, { - valString: "numeric(5, a)", + parameters: []string{"b", "5"}, expectedKindDetails: Invalid, }, { - valString: "numeric(b, 5)", + parameters: []string{"a", "b"}, expectedKindDetails: Invalid, }, { - valString: "numeric(b, a)", + parameters: []string{"1", "2", "3"}, expectedKindDetails: Invalid, }, { - valString: "numeric(5, 2)", + parameters: []string{"5", " 2"}, expectedKindDetails: EDecimal, expectedPrecision: ptr.ToInt(5), expectedScale: 2, }, { - valString: "numeric(5,2)", + parameters: []string{"5", "2"}, expectedKindDetails: EDecimal, expectedPrecision: ptr.ToInt(5), expectedScale: 2, }, { - valString: "numeric(39, 6)", + parameters: []string{"39", "6"}, expectedKindDetails: EDecimal, expectedPrecision: ptr.ToInt(39), expectedScale: 6, }, { - valString: "numeric(5)", + parameters: []string{"5"}, expectedKindDetails: Integer, expectedPrecision: ptr.ToInt(5), expectedScale: 0, }, { - valString: "numeric(5, 0)", + parameters: []string{"5", "0"}, expectedKindDetails: Integer, expectedPrecision: ptr.ToInt(5), expectedScale: 0, @@ -70,13 +70,13 @@ func TestParseNumeric(t *testing.T) { } for _, testCase := range testCases { - result := ParseNumeric(DefaultPrefix, testCase.valString) - assert.Equal(t, testCase.expectedKindDetails.Kind, result.Kind, testCase.valString) + result := ParseNumeric(testCase.parameters) + assert.Equal(t, testCase.expectedKindDetails.Kind, result.Kind, testCase.parameters) if result.ExtendedDecimalDetails != nil { - assert.Equal(t, testCase.expectedScale, result.ExtendedDecimalDetails.Scale(), testCase.valString) + assert.Equal(t, testCase.expectedScale, result.ExtendedDecimalDetails.Scale(), testCase.parameters) if result.ExtendedDecimalDetails.Precision() != nil { - assert.Equal(t, *testCase.expectedPrecision, *result.ExtendedDecimalDetails.Precision(), testCase.valString) + assert.Equal(t, *testCase.expectedPrecision, *result.ExtendedDecimalDetails.Precision(), testCase.parameters) } } }