Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass Dialect to Column.DefaultValue #536

Merged
merged 4 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clients/shared/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col
}

additionalDateFmts := cfg.SharedTransferConfig.TypingSettings.AdditionalDateFormats
defaultVal, err := column.DefaultValue(&columns.DefaultValueArgs{Escape: true, DestKind: dwh.Label()}, additionalDateFmts)
defaultVal, err := column.DefaultValue(dwh.Dialect(), additionalDateFmts)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only place we call DefaultValue so I removed the Escape argument.

if err != nil {
return fmt.Errorf("failed to escape default value: %w", err)
}
Expand Down
22 changes: 10 additions & 12 deletions lib/typing/columns/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package columns
import (
"fmt"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"

"github.com/artie-labs/transfer/lib/typing/ext"

Expand All @@ -12,29 +12,27 @@ import (
"github.com/artie-labs/transfer/lib/typing/decimal"
)

type DefaultValueArgs struct {
Escape bool
DestKind constants.DestinationKind
}

func (c *Column) RawDefaultValue() any {
return c.defaultValue
}

func (c *Column) DefaultValue(args *DefaultValueArgs, additionalDateFmts []string) (any, error) {
if args == nil || !args.Escape || c.defaultValue == nil {
func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) (any, error) {
if c.defaultValue == nil {
return c.defaultValue, nil
}

switch c.KindDetails.Kind {
case typing.Struct.Kind, typing.Array.Kind:
switch args.DestKind {
case constants.BigQuery:
switch dialect.(type) {
case sql.BigQueryDialect:
return "JSON" + stringutil.Wrap(c.defaultValue, false), nil
case constants.Redshift:
case sql.RedshiftDialect:
return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(c.defaultValue, false)), nil
case constants.Snowflake:
case sql.SnowflakeDialect:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Tang8330 it doesn't seem like we're doing anything for MS SQL here, do we need a case for it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we're not. We don't support backfill default values in MSSQL yet.

Do you want to leave a comment?

return stringutil.Wrap(c.defaultValue, false), nil
default:
// Note that we don't currently support backfills for MS SQL.
return nil, fmt.Errorf("not implemented for %v dialect", dialect)
}
case typing.ETime.Kind:
if c.KindDetails.ExtendedTimeDetails == nil {
Expand Down
80 changes: 22 additions & 58 deletions lib/typing/columns/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"testing"
"time"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"

"github.com/artie-labs/transfer/lib/typing/ext"

Expand All @@ -14,6 +14,12 @@ import (
"github.com/stretchr/testify/assert"
)

var dialects = []sql.Dialect{
sql.BigQueryDialect{},
sql.RedshiftDialect{},
sql.SnowflakeDialect{UppercaseEscNames: true},
}

func TestColumn_DefaultValue(t *testing.T) {
birthday := time.Date(2022, time.September, 6, 3, 19, 24, 942000000, time.UTC)
birthdayExtDateTime, err := ext.ParseExtendedDateTime(birthday.Format(ext.ISO8601), nil)
Expand All @@ -32,47 +38,24 @@ func TestColumn_DefaultValue(t *testing.T) {
testCases := []struct {
name string
col *Column
args *DefaultValueArgs
dialect sql.Dialect
expectedValue any
destKindToExpectedValueMap map[constants.DestinationKind]any
destKindToExpectedValueMap map[sql.Dialect]any
}{
{
name: "default value = nil",
col: &Column{
KindDetails: typing.String,
defaultValue: nil,
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: nil,
},
{
name: "escaped args (nil)",
col: &Column{
KindDetails: typing.String,
defaultValue: "abcdef",
},
expectedValue: "abcdef",
},
{
name: "escaped args (escaped = false)",
col: &Column{
KindDetails: typing.String,
defaultValue: "abcdef",
},
args: &DefaultValueArgs{},
expectedValue: "abcdef",
},
{
name: "string",
col: &Column{
KindDetails: typing.String,
defaultValue: "abcdef",
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "'abcdef'",
},
{
Expand All @@ -81,14 +64,11 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: typing.Struct,
defaultValue: "{}",
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: `{}`,
destKindToExpectedValueMap: map[constants.DestinationKind]any{
constants.BigQuery: "JSON'{}'",
constants.Redshift: `JSON_PARSE('{}')`,
constants.Snowflake: `'{}'`,
destKindToExpectedValueMap: map[sql.Dialect]any{
dialects[0]: "JSON'{}'",
dialects[1]: `JSON_PARSE('{}')`,
dialects[2]: `'{}'`,
},
},
{
Expand All @@ -97,14 +77,11 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: typing.Struct,
defaultValue: "{\"age\": 0, \"membership_level\": \"standard\"}",
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "{\"age\": 0, \"membership_level\": \"standard\"}",
destKindToExpectedValueMap: map[constants.DestinationKind]any{
constants.BigQuery: "JSON'{\"age\": 0, \"membership_level\": \"standard\"}'",
constants.Redshift: "JSON_PARSE('{\"age\": 0, \"membership_level\": \"standard\"}')",
constants.Snowflake: "'{\"age\": 0, \"membership_level\": \"standard\"}'",
destKindToExpectedValueMap: map[sql.Dialect]any{
dialects[0]: "JSON'{\"age\": 0, \"membership_level\": \"standard\"}'",
dialects[1]: "JSON_PARSE('{\"age\": 0, \"membership_level\": \"standard\"}')",
dialects[2]: "'{\"age\": 0, \"membership_level\": \"standard\"}'",
},
},
{
Expand All @@ -113,9 +90,6 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: dateKind,
defaultValue: birthdayExtDateTime,
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "'2022-09-06'",
},
{
Expand All @@ -124,9 +98,6 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: timeKind,
defaultValue: birthdayExtDateTime,
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "'03:19:24.942'",
},
{
Expand All @@ -135,29 +106,22 @@ func TestColumn_DefaultValue(t *testing.T) {
KindDetails: dateTimeKind,
defaultValue: birthdayExtDateTime,
},
args: &DefaultValueArgs{
Escape: true,
},
expectedValue: "'2022-09-06T03:19:24.942Z'",
},
}

for _, testCase := range testCases {
for _, validDest := range constants.ValidDestinations {
if testCase.args != nil {
testCase.args.DestKind = validDest
}

actualValue, actualErr := testCase.col.DefaultValue(testCase.args, nil)
assert.NoError(t, actualErr, fmt.Sprintf("%s %s", testCase.name, validDest))
for _, dialect := range dialects {
actualValue, actualErr := testCase.col.DefaultValue(dialect, nil)
assert.NoError(t, actualErr, fmt.Sprintf("%s %s", testCase.name, dialect))

expectedValue := testCase.expectedValue
if potentialValue, isOk := testCase.destKindToExpectedValueMap[validDest]; isOk {
if potentialValue, isOk := testCase.destKindToExpectedValueMap[dialect]; isOk {
// Not everything requires a destination specific value, so only use this if necessary.
expectedValue = potentialValue
}

assert.Equal(t, expectedValue, actualValue, fmt.Sprintf("%s %s", testCase.name, validDest))
assert.Equal(t, expectedValue, actualValue, fmt.Sprintf("%s %s", testCase.name, dialect))
}
}
}