From 9026cbe066697f67253fb223cef928401f792e0f Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 13:59:47 -0700 Subject: [PATCH] Add `Dialect` to `MergeArgument` --- clients/shared/merge.go | 1 + clients/snowflake/snowflake.go | 4 ++-- lib/destination/dml/merge.go | 11 ++++++++--- lib/destination/dml/merge_bigquery_test.go | 3 +++ lib/destination/dml/merge_mssql_test.go | 2 ++ lib/destination/dml/merge_parts_test.go | 8 ++++++++ lib/destination/dml/merge_test.go | 6 ++++++ lib/destination/dml/merge_valid_test.go | 14 ++++++++++++++ lib/sql/escape.go | 8 ++++++++ 9 files changed, 52 insertions(+), 5 deletions(-) diff --git a/clients/shared/merge.go b/clients/shared/merge.go index 47ce3c6d3..970fc3652 100644 --- a/clients/shared/merge.go +++ b/clients/shared/merge.go @@ -129,6 +129,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg Columns: tableData.ReadOnlyInMemoryCols(), SoftDelete: tableData.TopicConfig().SoftDelete, DestKind: dwh.Label(), + Dialect: dwh.Dialect(), UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()), ContainsHardDeletes: ptr.ToBool(tableData.ContainsHardDeletes()), } diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go index 14d859b06..eec3b4dda 100644 --- a/clients/snowflake/snowflake.go +++ b/clients/snowflake/snowflake.go @@ -134,12 +134,12 @@ func (s *Store) reestablishConnection() error { func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string { var primaryKeysEscaped []string for _, pk := range primaryKeys { - primaryKeysEscaped = append(primaryKeysEscaped, sql.EscapeNameIfNecessary(pk, s.ShouldUppercaseEscapedNames(), s.Label())) + primaryKeysEscaped = append(primaryKeysEscaped, sql.EscapeNameIfNecessaryUsingDialect(pk, s.Dialect())) } orderColsToIterate := primaryKeysEscaped if topicConfig.IncludeArtieUpdatedAt { - orderColsToIterate = append(orderColsToIterate, sql.EscapeNameIfNecessary(constants.UpdateColumnMarker, s.ShouldUppercaseEscapedNames(), s.Label())) + orderColsToIterate = append(orderColsToIterate, sql.EscapeNameIfNecessaryUsingDialect(constants.UpdateColumnMarker, s.Dialect())) } var orderByCols []string diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index 03166d205..715f7ab5a 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -31,6 +31,7 @@ type MergeArgument struct { // where we do not issue a DELETE statement if there are no hard deletes in the batch ContainsHardDeletes *bool UppercaseEscNames *bool + Dialect sql.Dialect } func (m *MergeArgument) Valid() error { @@ -62,6 +63,10 @@ func (m *MergeArgument) Valid() error { return fmt.Errorf("invalid destination: %s", m.DestKind) } + if m.Dialect == nil { + return fmt.Errorf("dialect cannot be nil") + } + return nil } @@ -129,7 +134,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { // We also need to remove __artie flags since it does not exist in the destination table var removed bool for idx, col := range cols { - if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, *m.UppercaseEscNames, m.DestKind) { + if col == sql.EscapeNameIfNecessaryUsingDialect(constants.DeleteColumnMarker, m.Dialect) { cols = append(cols[:idx], cols[idx+1:]...) removed = true break @@ -252,7 +257,7 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);` // We also need to remove __artie flags since it does not exist in the destination table var removed bool for idx, col := range cols { - if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, *m.UppercaseEscNames, m.DestKind) { + if col == sql.EscapeNameIfNecessaryUsingDialect(constants.DeleteColumnMarker, m.Dialect) { cols = append(cols[:idx], cols[idx+1:]...) removed = true break @@ -322,7 +327,7 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`, // We also need to remove __artie flags since it does not exist in the destination table var removed bool for idx, col := range cols { - if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, *m.UppercaseEscNames, m.DestKind) { + if col == sql.EscapeNameIfNecessaryUsingDialect(constants.DeleteColumnMarker, m.Dialect) { cols = append(cols[:idx], cols[idx+1:]...) removed = true break diff --git a/lib/destination/dml/merge_bigquery_test.go b/lib/destination/dml/merge_bigquery_test.go index 24cd40398..38524abb9 100644 --- a/lib/destination/dml/merge_bigquery_test.go +++ b/lib/destination/dml/merge_bigquery_test.go @@ -5,6 +5,7 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" "github.com/stretchr/testify/assert" @@ -22,6 +23,7 @@ func TestMergeStatement_TempTable(t *testing.T) { PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_id", typing.Invalid), false, constants.BigQuery)}, Columns: &cols, DestKind: constants.BigQuery, + Dialect: sql.BigQueryDialect{}, SoftDelete: false, UppercaseEscNames: ptr.ToBool(false), } @@ -44,6 +46,7 @@ func TestMergeStatement_JSONKey(t *testing.T) { PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_oid", typing.Invalid), false, constants.BigQuery)}, Columns: &cols, DestKind: constants.BigQuery, + Dialect: sql.BigQueryDialect{}, SoftDelete: false, UppercaseEscNames: ptr.ToBool(false), } diff --git a/lib/destination/dml/merge_mssql_test.go b/lib/destination/dml/merge_mssql_test.go index 385d4ace5..b64455140 100644 --- a/lib/destination/dml/merge_mssql_test.go +++ b/lib/destination/dml/merge_mssql_test.go @@ -8,6 +8,7 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" "github.com/stretchr/testify/assert" @@ -47,6 +48,7 @@ func Test_GetMSSQLStatement(t *testing.T) { PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.MSSQL)}, Columns: &_cols, DestKind: constants.MSSQL, + Dialect: sql.DefaultDialect{}, SoftDelete: false, UppercaseEscNames: ptr.ToBool(false), } diff --git a/lib/destination/dml/merge_parts_test.go b/lib/destination/dml/merge_parts_test.go index ed586272c..d6be3ce5c 100644 --- a/lib/destination/dml/merge_parts_test.go +++ b/lib/destination/dml/merge_parts_test.go @@ -5,6 +5,7 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" @@ -72,6 +73,7 @@ func TestMergeStatementParts_SkipDelete(t *testing.T) { PrimaryKeys: res.PrimaryKeys, Columns: &res.ColumnsToTypes, DestKind: constants.Redshift, + Dialect: sql.RedshiftDialect{}, ContainsHardDeletes: ptr.ToBool(false), UppercaseEscNames: ptr.ToBool(false), } @@ -99,6 +101,7 @@ func TestMergeStatementPartsSoftDelete(t *testing.T) { PrimaryKeys: res.PrimaryKeys, Columns: &res.ColumnsToTypes, DestKind: constants.Redshift, + Dialect: sql.RedshiftDialect{}, SoftDelete: true, UppercaseEscNames: ptr.ToBool(false), ContainsHardDeletes: ptr.ToBool(false), @@ -139,6 +142,7 @@ func TestMergeStatementPartsSoftDeleteComposite(t *testing.T) { PrimaryKeys: res.PrimaryKeys, Columns: &res.ColumnsToTypes, DestKind: constants.Redshift, + Dialect: sql.RedshiftDialect{}, SoftDelete: true, UppercaseEscNames: ptr.ToBool(false), ContainsHardDeletes: ptr.ToBool(false), @@ -182,6 +186,7 @@ func TestMergeStatementParts(t *testing.T) { PrimaryKeys: res.PrimaryKeys, Columns: &res.ColumnsToTypes, DestKind: constants.Redshift, + Dialect: sql.RedshiftDialect{}, ContainsHardDeletes: ptr.ToBool(true), UppercaseEscNames: ptr.ToBool(false), } @@ -208,6 +213,7 @@ func TestMergeStatementParts(t *testing.T) { PrimaryKeys: res.PrimaryKeys, Columns: &res.ColumnsToTypes, DestKind: constants.Redshift, + Dialect: sql.RedshiftDialect{}, IdempotentKey: "created_at", ContainsHardDeletes: ptr.ToBool(true), UppercaseEscNames: ptr.ToBool(false), @@ -240,6 +246,7 @@ func TestMergeStatementPartsCompositeKey(t *testing.T) { PrimaryKeys: res.PrimaryKeys, Columns: &res.ColumnsToTypes, DestKind: constants.Redshift, + Dialect: sql.RedshiftDialect{}, ContainsHardDeletes: ptr.ToBool(true), UppercaseEscNames: ptr.ToBool(false), } @@ -266,6 +273,7 @@ func TestMergeStatementPartsCompositeKey(t *testing.T) { PrimaryKeys: res.PrimaryKeys, Columns: &res.ColumnsToTypes, DestKind: constants.Redshift, + Dialect: sql.RedshiftDialect{}, ContainsHardDeletes: ptr.ToBool(true), IdempotentKey: "created_at", UppercaseEscNames: ptr.ToBool(false), diff --git a/lib/destination/dml/merge_test.go b/lib/destination/dml/merge_test.go index 0b17ab3b7..5057ffad7 100644 --- a/lib/destination/dml/merge_test.go +++ b/lib/destination/dml/merge_test.go @@ -11,6 +11,7 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/destination/types" "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" ) @@ -64,6 +65,7 @@ func TestMergeStatementSoftDelete(t *testing.T) { PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)}, Columns: &_cols, DestKind: constants.Snowflake, + Dialect: sql.SnowflakeDialect{UppercaseEscNames: false}, SoftDelete: true, UppercaseEscNames: ptr.ToBool(false), } @@ -113,6 +115,7 @@ func TestMergeStatement(t *testing.T) { PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), true, constants.Snowflake)}, Columns: &_cols, DestKind: constants.Snowflake, + Dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, SoftDelete: false, UppercaseEscNames: ptr.ToBool(true), } @@ -161,6 +164,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) { PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)}, Columns: &_cols, DestKind: constants.Snowflake, + Dialect: sql.SnowflakeDialect{UppercaseEscNames: false}, SoftDelete: false, UppercaseEscNames: ptr.ToBool(false), } @@ -206,6 +210,7 @@ func TestMergeStatementCompositeKey(t *testing.T) { }, Columns: &_cols, DestKind: constants.Snowflake, + Dialect: sql.SnowflakeDialect{UppercaseEscNames: false}, SoftDelete: false, UppercaseEscNames: ptr.ToBool(false), } @@ -255,6 +260,7 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) { }, Columns: &_cols, DestKind: constants.Snowflake, + Dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, SoftDelete: false, UppercaseEscNames: ptr.ToBool(true), } diff --git a/lib/destination/dml/merge_valid_test.go b/lib/destination/dml/merge_valid_test.go index e8223c97e..ae2a335e3 100644 --- a/lib/destination/dml/merge_valid_test.go +++ b/lib/destination/dml/merge_valid_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing" @@ -90,6 +91,18 @@ func TestMergeArgument_Valid(t *testing.T) { }, expectedErr: "invalid destination", }, + { + name: "missing dialect kind", + mergeArg: &MergeArgument{ + PrimaryKeys: primaryKeys, + Columns: &cols, + SubQuery: "schema.tableName", + TableID: MockTableIdentifier{"schema.tableName"}, + UppercaseEscNames: ptr.ToBool(false), + DestKind: constants.BigQuery, + }, + expectedErr: "dialect cannot be nil", + }, { name: "everything exists", mergeArg: &MergeArgument{ @@ -99,6 +112,7 @@ func TestMergeArgument_Valid(t *testing.T) { TableID: MockTableIdentifier{"schema.tableName"}, UppercaseEscNames: ptr.ToBool(false), DestKind: constants.BigQuery, + Dialect: sql.BigQueryDialect{}, }, }, } diff --git a/lib/sql/escape.go b/lib/sql/escape.go index 1794da513..8ba1bf847 100644 --- a/lib/sql/escape.go +++ b/lib/sql/escape.go @@ -5,6 +5,7 @@ import ( ) func EscapeNameIfNecessary(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string { + // TODO: Switch all calls of [EscapeNameIfNecessary] to [EscapeNameIfNecessaryUsingDialect] and kill this. var dialect = dialectFor(destKind, uppercaseEscNames) if destKind != constants.S3 && dialect.NeedsEscaping(name) { @@ -13,6 +14,13 @@ func EscapeNameIfNecessary(name string, uppercaseEscNames bool, destKind constan return name } +func EscapeNameIfNecessaryUsingDialect(name string, dialect Dialect) string { + if dialect.NeedsEscaping(name) { + return dialect.QuoteIdentifier(name) + } + return name +} + func dialectFor(destKind constants.DestinationKind, uppercaseEscNames bool) Dialect { switch destKind { case constants.BigQuery: