Skip to content

Commit

Permalink
Pass Dialog to default function
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed May 2, 2024
1 parent 72d8011 commit ea31911
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 65 deletions.
2 changes: 1 addition & 1 deletion clients/shared/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col
}

additionalDateFmts := cfg.SharedTransferConfig.TypingSettings.AdditionalDateFormats
defaultVal, err := column.DefaultValue(&columns.DefaultValueArgs{Escape: true, DestKind: dwh.Label()}, additionalDateFmts)
defaultVal, err := column.DefaultValue(dwh.Dialect(), additionalDateFmts)
if err != nil {
return fmt.Errorf("failed to escape default value: %w", err)
}
Expand Down
13 changes: 7 additions & 6 deletions lib/typing/columns/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"

"github.com/artie-labs/transfer/lib/typing/ext"

Expand All @@ -21,19 +22,19 @@ func (c *Column) RawDefaultValue() any {
return c.defaultValue
}

func (c *Column) DefaultValue(args *DefaultValueArgs, additionalDateFmts []string) (any, error) {
if args == nil || !args.Escape || c.defaultValue == nil {
func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) (any, error) {
if c.defaultValue == nil {
return c.defaultValue, nil
}

switch c.KindDetails.Kind {
case typing.Struct.Kind, typing.Array.Kind:
switch args.DestKind {
case constants.BigQuery:
switch dialect.(type) {
case sql.BigQueryDialect:
return "JSON" + stringutil.Wrap(c.defaultValue, false), nil
case constants.Redshift:
case sql.RedshiftDialect:
return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(c.defaultValue, false)), nil
case constants.Snowflake:
case sql.SnowflakeDialect:
return stringutil.Wrap(c.defaultValue, false), nil
}
case typing.ETime.Kind:
Expand Down
81 changes: 23 additions & 58 deletions lib/typing/columns/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"testing"
"time"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"

"github.com/artie-labs/transfer/lib/typing/ext"

Expand All @@ -14,6 +14,13 @@ import (
"github.com/stretchr/testify/assert"
)

var dialects = []sql.Dialect{
sql.BigQueryDialect{},
sql.MSSQLDialect{},
sql.RedshiftDialect{},
sql.SnowflakeDialect{UppercaseEscNames: true},
}

func TestColumn_DefaultValue(t *testing.T) {
birthday := time.Date(2022, time.September, 6, 3, 19, 24, 942000000, time.UTC)
birthdayExtDateTime, err := ext.ParseExtendedDateTime(birthday.Format(ext.ISO8601), nil)
Expand All @@ -32,47 +39,24 @@ func TestColumn_DefaultValue(t *testing.T) {
testCases := []struct {
name string
col *Column
args *DefaultValueArgs
dialect sql.Dialect
expectedValue any
destKindToExpectedValueMap map[constants.DestinationKind]any
destKindToExpectedValueMap map[sql.Dialect]any
}{
{
name: "default value = nil",
col: &Column{
KindDetails: typing.String,
defaultValue: nil,
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: nil,
},
{
name: "escaped args (nil)",
col: &Column{
KindDetails: typing.String,
defaultValue: "abcdef",
},
expectedValue: "abcdef",
},
{
name: "escaped args (escaped = false)",
col: &Column{
KindDetails: typing.String,
defaultValue: "abcdef",
},
args: &DefaultValueArgs{},
expectedValue: "abcdef",
},
{
name: "string",
col: &Column{
KindDetails: typing.String,
defaultValue: "abcdef",
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "'abcdef'",
},
{
Expand All @@ -81,14 +65,11 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: typing.Struct,
defaultValue: "{}",
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: `{}`,
destKindToExpectedValueMap: map[constants.DestinationKind]any{
constants.BigQuery: "JSON'{}'",
constants.Redshift: `JSON_PARSE('{}')`,
constants.Snowflake: `'{}'`,
destKindToExpectedValueMap: map[sql.Dialect]any{
dialects[0]: "JSON'{}'",
dialects[2]: `JSON_PARSE('{}')`,
dialects[3]: `'{}'`,
},
},
{
Expand All @@ -97,14 +78,11 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: typing.Struct,
defaultValue: "{\"age\": 0, \"membership_level\": \"standard\"}",
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "{\"age\": 0, \"membership_level\": \"standard\"}",
destKindToExpectedValueMap: map[constants.DestinationKind]any{
constants.BigQuery: "JSON'{\"age\": 0, \"membership_level\": \"standard\"}'",
constants.Redshift: "JSON_PARSE('{\"age\": 0, \"membership_level\": \"standard\"}')",
constants.Snowflake: "'{\"age\": 0, \"membership_level\": \"standard\"}'",
destKindToExpectedValueMap: map[sql.Dialect]any{
dialects[0]: "JSON'{\"age\": 0, \"membership_level\": \"standard\"}'",
dialects[2]: "JSON_PARSE('{\"age\": 0, \"membership_level\": \"standard\"}')",
dialects[3]: "'{\"age\": 0, \"membership_level\": \"standard\"}'",
},
},
{
Expand All @@ -113,9 +91,6 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: dateKind,
defaultValue: birthdayExtDateTime,
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "'2022-09-06'",
},
{
Expand All @@ -124,9 +99,6 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: timeKind,
defaultValue: birthdayExtDateTime,
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "'03:19:24.942'",
},
{
Expand All @@ -135,29 +107,22 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: dateTimeKind,
defaultValue: birthdayExtDateTime,
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "'2022-09-06T03:19:24.942Z'",
},
}

for _, testCase := range testCases {
for _, validDest := range constants.ValidDestinations {
if testCase.args != nil {
testCase.args.DestKind = validDest
}

actualValue, actualErr := testCase.col.DefaultValue(testCase.args, nil)
assert.NoError(t, actualErr, fmt.Sprintf("%s %s", testCase.name, validDest))
for _, dialect := range dialects {
actualValue, actualErr := testCase.col.DefaultValue(dialect, nil)
assert.NoError(t, actualErr, fmt.Sprintf("%s %s", testCase.name, dialect))

expectedValue := testCase.expectedValue
if potentialValue, isOk := testCase.destKindToExpectedValueMap[validDest]; isOk {
if potentialValue, isOk := testCase.destKindToExpectedValueMap[dialect]; isOk {
// Not everything requires a destination specific value, so only use this if necessary.
expectedValue = potentialValue
}

assert.Equal(t, expectedValue, actualValue, fmt.Sprintf("%s %s", testCase.name, validDest))
assert.Equal(t, expectedValue, actualValue, fmt.Sprintf("%s %s", testCase.name, dialect))
}
}
}

0 comments on commit ea31911

Please sign in to comment.