From ea31911c4c5f8a54122c5772eb812b6ed5de3442 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 17:13:00 -0700 Subject: [PATCH] Pass `Dialog` to default function --- clients/shared/utils.go | 2 +- lib/typing/columns/default.go | 13 ++--- lib/typing/columns/default_test.go | 81 +++++++++--------------------- 3 files changed, 31 insertions(+), 65 deletions(-) diff --git a/clients/shared/utils.go b/clients/shared/utils.go index 1382abf0e..664fcf7ff 100644 --- a/clients/shared/utils.go +++ b/clients/shared/utils.go @@ -25,7 +25,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col } additionalDateFmts := cfg.SharedTransferConfig.TypingSettings.AdditionalDateFormats - defaultVal, err := column.DefaultValue(&columns.DefaultValueArgs{Escape: true, DestKind: dwh.Label()}, additionalDateFmts) + defaultVal, err := column.DefaultValue(dwh.Dialect(), additionalDateFmts) if err != nil { return fmt.Errorf("failed to escape default value: %w", err) } diff --git a/lib/typing/columns/default.go b/lib/typing/columns/default.go index 45d2fa8f6..0191e58f0 100644 --- a/lib/typing/columns/default.go +++ b/lib/typing/columns/default.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing/ext" @@ -21,19 +22,19 @@ func (c *Column) RawDefaultValue() any { return c.defaultValue } -func (c *Column) DefaultValue(args *DefaultValueArgs, additionalDateFmts []string) (any, error) { - if args == nil || !args.Escape || c.defaultValue == nil { +func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) (any, error) { + if c.defaultValue == nil { return c.defaultValue, nil } switch c.KindDetails.Kind { case typing.Struct.Kind, typing.Array.Kind: - switch args.DestKind { - case constants.BigQuery: + switch dialect.(type) { + case sql.BigQueryDialect: return "JSON" + stringutil.Wrap(c.defaultValue, false), nil - case constants.Redshift: + case sql.RedshiftDialect: return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(c.defaultValue, false)), nil - case constants.Snowflake: + case sql.SnowflakeDialect: return stringutil.Wrap(c.defaultValue, false), nil } case typing.ETime.Kind: diff --git a/lib/typing/columns/default_test.go b/lib/typing/columns/default_test.go index 8c277b4df..999ee9abb 100644 --- a/lib/typing/columns/default_test.go +++ b/lib/typing/columns/default_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing/ext" @@ -14,6 +14,13 @@ import ( "github.com/stretchr/testify/assert" ) +var dialects = []sql.Dialect{ + sql.BigQueryDialect{}, + sql.MSSQLDialect{}, + sql.RedshiftDialect{}, + sql.SnowflakeDialect{UppercaseEscNames: true}, +} + func TestColumn_DefaultValue(t *testing.T) { birthday := time.Date(2022, time.September, 6, 3, 19, 24, 942000000, time.UTC) birthdayExtDateTime, err := ext.ParseExtendedDateTime(birthday.Format(ext.ISO8601), nil) @@ -32,9 +39,9 @@ func TestColumn_DefaultValue(t *testing.T) { testCases := []struct { name string col *Column - args *DefaultValueArgs + dialect sql.Dialect expectedValue any - destKindToExpectedValueMap map[constants.DestinationKind]any + destKindToExpectedValueMap map[sql.Dialect]any }{ { name: "default value = nil", @@ -42,37 +49,14 @@ func TestColumn_DefaultValue(t *testing.T) { KindDetails: typing.String, defaultValue: nil, }, - args: &DefaultValueArgs{ - Escape: true, - }, expectedValue: nil, }, - { - name: "escaped args (nil)", - col: &Column{ - KindDetails: typing.String, - defaultValue: "abcdef", - }, - expectedValue: "abcdef", - }, - { - name: "escaped args (escaped = false)", - col: &Column{ - KindDetails: typing.String, - defaultValue: "abcdef", - }, - args: &DefaultValueArgs{}, - expectedValue: "abcdef", - }, { name: "string", col: &Column{ KindDetails: typing.String, defaultValue: "abcdef", }, - args: &DefaultValueArgs{ - Escape: true, - }, expectedValue: "'abcdef'", }, { @@ -81,14 +65,11 @@ func TestColumn_DefaultValue(t *testing.T) { KindDetails: typing.Struct, defaultValue: "{}", }, - args: &DefaultValueArgs{ - Escape: true, - }, expectedValue: `{}`, - destKindToExpectedValueMap: map[constants.DestinationKind]any{ - constants.BigQuery: "JSON'{}'", - constants.Redshift: `JSON_PARSE('{}')`, - constants.Snowflake: `'{}'`, + destKindToExpectedValueMap: map[sql.Dialect]any{ + dialects[0]: "JSON'{}'", + dialects[2]: `JSON_PARSE('{}')`, + dialects[3]: `'{}'`, }, }, { @@ -97,14 +78,11 @@ func TestColumn_DefaultValue(t *testing.T) { KindDetails: typing.Struct, defaultValue: "{\"age\": 0, \"membership_level\": \"standard\"}", }, - args: &DefaultValueArgs{ - Escape: true, - }, expectedValue: "{\"age\": 0, \"membership_level\": \"standard\"}", - destKindToExpectedValueMap: map[constants.DestinationKind]any{ - constants.BigQuery: "JSON'{\"age\": 0, \"membership_level\": \"standard\"}'", - constants.Redshift: "JSON_PARSE('{\"age\": 0, \"membership_level\": \"standard\"}')", - constants.Snowflake: "'{\"age\": 0, \"membership_level\": \"standard\"}'", + destKindToExpectedValueMap: map[sql.Dialect]any{ + dialects[0]: "JSON'{\"age\": 0, \"membership_level\": \"standard\"}'", + dialects[2]: "JSON_PARSE('{\"age\": 0, \"membership_level\": \"standard\"}')", + dialects[3]: "'{\"age\": 0, \"membership_level\": \"standard\"}'", }, }, { @@ -113,9 +91,6 @@ func TestColumn_DefaultValue(t *testing.T) { KindDetails: dateKind, defaultValue: birthdayExtDateTime, }, - args: &DefaultValueArgs{ - Escape: true, - }, expectedValue: "'2022-09-06'", }, { @@ -124,9 +99,6 @@ func TestColumn_DefaultValue(t *testing.T) { KindDetails: timeKind, defaultValue: birthdayExtDateTime, }, - args: &DefaultValueArgs{ - Escape: true, - }, expectedValue: "'03:19:24.942'", }, { @@ -135,29 +107,22 @@ func TestColumn_DefaultValue(t *testing.T) { KindDetails: dateTimeKind, defaultValue: birthdayExtDateTime, }, - args: &DefaultValueArgs{ - Escape: true, - }, expectedValue: "'2022-09-06T03:19:24.942Z'", }, } for _, testCase := range testCases { - for _, validDest := range constants.ValidDestinations { - if testCase.args != nil { - testCase.args.DestKind = validDest - } - - actualValue, actualErr := testCase.col.DefaultValue(testCase.args, nil) - assert.NoError(t, actualErr, fmt.Sprintf("%s %s", testCase.name, validDest)) + for _, dialect := range dialects { + actualValue, actualErr := testCase.col.DefaultValue(dialect, nil) + assert.NoError(t, actualErr, fmt.Sprintf("%s %s", testCase.name, dialect)) expectedValue := testCase.expectedValue - if potentialValue, isOk := testCase.destKindToExpectedValueMap[validDest]; isOk { + if potentialValue, isOk := testCase.destKindToExpectedValueMap[dialect]; isOk { // Not everything requires a destination specific value, so only use this if necessary. expectedValue = potentialValue } - assert.Equal(t, expectedValue, actualValue, fmt.Sprintf("%s %s", testCase.name, validDest)) + assert.Equal(t, expectedValue, actualValue, fmt.Sprintf("%s %s", testCase.name, dialect)) } } }