From 916b120d89054be56e0248e700d4aa5486154119 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 24 Apr 2024 15:26:33 -0700 Subject: [PATCH 1/2] Pass dest kind --- clients/shared/merge.go | 2 +- clients/shared/utils.go | 2 +- clients/snowflake/snowflake.go | 2 +- lib/destination/ddl/ddl.go | 8 ++------ lib/destination/ddl/ddl_bq_test.go | 14 +++++--------- lib/destination/ddl/ddl_sflk_test.go | 4 ++-- lib/destination/dml/merge_bigquery_test.go | 4 ++-- lib/destination/dml/merge_mssql_test.go | 2 +- lib/destination/dml/merge_parts_test.go | 8 ++------ lib/destination/dml/merge_test.go | 20 +++++++++----------- lib/destination/dml/merge_valid_test.go | 2 +- lib/optimization/table_data.go | 4 ++-- lib/typing/columns/columns.go | 16 ++++++++-------- lib/typing/columns/columns_test.go | 8 ++------ lib/typing/columns/wrapper.go | 6 ++++-- lib/typing/columns/wrapper_test.go | 19 +++---------------- 16 files changed, 46 insertions(+), 75 deletions(-) diff --git a/clients/shared/merge.go b/clients/shared/merge.go index 16bcb740b..b9247faa6 100644 --- a/clients/shared/merge.go +++ b/clients/shared/merge.go @@ -122,7 +122,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg TableID: tableID, SubQuery: subQuery, IdempotentKey: tableData.TopicConfig().IdempotentKey, - PrimaryKeys: tableData.PrimaryKeys(dwh.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: dwh.Label()}), + PrimaryKeys: tableData.PrimaryKeys(dwh.ShouldUppercaseEscapedNames(), dwh.Label()), Columns: tableData.ReadOnlyInMemoryCols(), SoftDelete: tableData.TopicConfig().SoftDelete, DestKind: dwh.Label(), diff --git a/clients/shared/utils.go b/clients/shared/utils.go index dc2d0de0c..b3b236604 100644 --- a/clients/shared/utils.go +++ b/clients/shared/utils.go @@ -30,7 +30,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col return fmt.Errorf("failed to escape default value: %w", err) } - escapedCol := column.Name(dwh.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: dwh.Label()}) + escapedCol := column.Name(dwh.ShouldUppercaseEscapedNames(), dwh.Label()) // TODO: This is added because `default` is not technically a column that requires escaping, but it is required when it's in the where clause. // Once we escape everything by default, we can remove this patch of code. diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go index 6c4105f84..312510bca 100644 --- a/clients/snowflake/snowflake.go +++ b/clients/snowflake/snowflake.go @@ -129,7 +129,7 @@ func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentif var primaryKeysEscaped []string for _, pk := range primaryKeys { pkCol := columns.NewColumn(pk, typing.Invalid) - primaryKeysEscaped = append(primaryKeysEscaped, pkCol.Name(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: s.Label()})) + primaryKeysEscaped = append(primaryKeysEscaped, pkCol.Name(s.ShouldUppercaseEscapedNames(), s.Label())) } orderColsToIterate := primaryKeysEscaped diff --git a/lib/destination/ddl/ddl.go b/lib/destination/ddl/ddl.go index 6b92fe2a2..723238ac6 100644 --- a/lib/destination/ddl/ddl.go +++ b/lib/destination/ddl/ddl.go @@ -104,9 +104,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error { mutateCol = append(mutateCol, col) switch a.ColumnOp { case constants.Add: - colName := col.Name(*a.UppercaseEscNames, &columns.NameArgs{ - DestKind: a.Dwh.Label(), - }) + colName := col.Name(*a.UppercaseEscNames, a.Dwh.Label()) if col.PrimaryKey() && a.Mode != config.History { // Don't create a PK for history mode because it's append-only, so the primary key should not be enforced. @@ -115,9 +113,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error { colSQLParts = append(colSQLParts, fmt.Sprintf(`%s %s`, colName, typing.KindToDWHType(col.KindDetails, a.Dwh.Label(), col.PrimaryKey()))) case constants.Delete: - colSQLParts = append(colSQLParts, col.Name(*a.UppercaseEscNames, &columns.NameArgs{ - DestKind: a.Dwh.Label(), - })) + colSQLParts = append(colSQLParts, col.Name(*a.UppercaseEscNames, a.Dwh.Label())) } } diff --git a/lib/destination/ddl/ddl_bq_test.go b/lib/destination/ddl/ddl_bq_test.go index bf84cb4e4..62c6a19bc 100644 --- a/lib/destination/ddl/ddl_bq_test.go +++ b/lib/destination/ddl/ddl_bq_test.go @@ -90,9 +90,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() { assert.NoError(d.T(), alterTableArgs.AlterTable(column)) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.Name(false, &columns.NameArgs{ - DestKind: d.bigQueryStore.Label(), - })), query) + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.Name(false, d.bigQueryStore.Label())), query) callIdx += 1 } @@ -150,9 +148,8 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() { assert.NoError(d.T(), alterTableArgs.AlterTable(col)) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.Name(false, &columns.NameArgs{ - DestKind: d.bigQueryStore.Label(), - }), typing.KindToDWHType(kind, d.bigQueryStore.Label(), false)), query) + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.Name(false, d.bigQueryStore.Label()), + typing.KindToDWHType(kind, d.bigQueryStore.Label(), false)), query) callIdx += 1 } @@ -211,9 +208,8 @@ func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() { assert.NoError(d.T(), alterTableArgs.AlterTable(column)) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.Name(false, &columns.NameArgs{ - DestKind: d.bigQueryStore.Label(), - }), typing.KindToDWHType(column.KindDetails, d.bigQueryStore.Label(), false)), query) + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.Name(false, d.bigQueryStore.Label()), + typing.KindToDWHType(column.KindDetails, d.bigQueryStore.Label(), false)), query) callIdx += 1 } diff --git a/lib/destination/ddl/ddl_sflk_test.go b/lib/destination/ddl/ddl_sflk_test.go index 4c4da59a3..2305929d2 100644 --- a/lib/destination/ddl/ddl_sflk_test.go +++ b/lib/destination/ddl/ddl_sflk_test.go @@ -47,7 +47,7 @@ func (d *DDLTestSuite) TestAlterComplexObjects() { for i := 0; i < len(cols); i++ { execQuery, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i) assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s add COLUMN %s %s", `shop.public."COMPLEX_COLUMNS"`, - cols[i].Name(false, &columns.NameArgs{DestKind: d.snowflakeStagesStore.Label()}), + cols[i].Name(false, d.snowflakeStagesStore.Label()), typing.KindToDWHType(cols[i].KindDetails, d.snowflakeStagesStore.Label(), false)), execQuery) } @@ -180,7 +180,7 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() { execArg, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i) assert.Equal(d.T(), execArg, fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", `shop.public."USERS"`, constants.Delete, - cols[i].Name(false, &columns.NameArgs{DestKind: d.snowflakeStagesStore.Label()}))) + cols[i].Name(false, d.snowflakeStagesStore.Label()))) } } diff --git a/lib/destination/dml/merge_bigquery_test.go b/lib/destination/dml/merge_bigquery_test.go index d5b2294ce..6fb44d488 100644 --- a/lib/destination/dml/merge_bigquery_test.go +++ b/lib/destination/dml/merge_bigquery_test.go @@ -19,7 +19,7 @@ func TestMergeStatement_TempTable(t *testing.T) { mergeArg := &MergeArgument{ TableID: MockTableIdentifier{"customers.orders"}, SubQuery: "customers.orders_tmp", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_id", typing.Invalid), false, nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_id", typing.Invalid), false, constants.BigQuery)}, Columns: &cols, DestKind: constants.BigQuery, SoftDelete: false, @@ -41,7 +41,7 @@ func TestMergeStatement_JSONKey(t *testing.T) { mergeArg := &MergeArgument{ TableID: MockTableIdentifier{"customers.orders"}, SubQuery: "customers.orders_tmp", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_oid", typing.Invalid), false, nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_oid", typing.Invalid), false, constants.BigQuery)}, Columns: &cols, DestKind: constants.BigQuery, SoftDelete: false, diff --git a/lib/destination/dml/merge_mssql_test.go b/lib/destination/dml/merge_mssql_test.go index edc58ef70..cd3f60a1c 100644 --- a/lib/destination/dml/merge_mssql_test.go +++ b/lib/destination/dml/merge_mssql_test.go @@ -44,7 +44,7 @@ func Test_GetMSSQLStatement(t *testing.T) { TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: "", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.MSSQL)}, Columns: &_cols, DestKind: constants.MSSQL, SoftDelete: false, diff --git a/lib/destination/dml/merge_parts_test.go b/lib/destination/dml/merge_parts_test.go index 1f17223ec..23f07bb94 100644 --- a/lib/destination/dml/merge_parts_test.go +++ b/lib/destination/dml/merge_parts_test.go @@ -47,14 +47,10 @@ func getBasicColumnsForTest(compositeKey bool, uppercaseEscNames bool) result { cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) var pks []columns.Wrapper - pks = append(pks, columns.NewWrapper(idCol, uppercaseEscNames, &columns.NameArgs{ - DestKind: constants.Redshift, - })) + pks = append(pks, columns.NewWrapper(idCol, uppercaseEscNames, constants.Redshift)) if compositeKey { - pks = append(pks, columns.NewWrapper(emailCol, uppercaseEscNames, &columns.NameArgs{ - DestKind: constants.Redshift, - })) + pks = append(pks, columns.NewWrapper(emailCol, uppercaseEscNames, constants.Redshift)) } return result{ diff --git a/lib/destination/dml/merge_test.go b/lib/destination/dml/merge_test.go index aa08a5c8c..962902109 100644 --- a/lib/destination/dml/merge_test.go +++ b/lib/destination/dml/merge_test.go @@ -61,7 +61,7 @@ func TestMergeStatementSoftDelete(t *testing.T) { TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: idempotentKey, - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)}, Columns: &_cols, DestKind: constants.Snowflake, SoftDelete: true, @@ -110,7 +110,7 @@ func TestMergeStatement(t *testing.T) { TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: "", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)}, Columns: &_cols, DestKind: constants.Snowflake, SoftDelete: false, @@ -158,7 +158,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) { TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: "updated_at", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)}, Columns: &_cols, DestKind: constants.Snowflake, SoftDelete: false, @@ -200,8 +200,10 @@ func TestMergeStatementCompositeKey(t *testing.T) { TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: "updated_at", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil), - columns.NewWrapper(columns.NewColumn("another_id", typing.Invalid), false, nil)}, + PrimaryKeys: []columns.Wrapper{ + columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake), + columns.NewWrapper(columns.NewColumn("another_id", typing.Invalid), false, constants.Snowflake), + }, Columns: &_cols, DestKind: constants.Snowflake, SoftDelete: false, @@ -248,12 +250,8 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) { SubQuery: subQuery, IdempotentKey: "", PrimaryKeys: []columns.Wrapper{ - columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, &columns.NameArgs{ - DestKind: constants.Snowflake, - }), - columns.NewWrapper(columns.NewColumn("group", typing.Invalid), false, &columns.NameArgs{ - DestKind: constants.Snowflake, - }), + columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake), + columns.NewWrapper(columns.NewColumn("group", typing.Invalid), false, constants.Snowflake), }, Columns: &_cols, DestKind: constants.Snowflake, diff --git a/lib/destination/dml/merge_valid_test.go b/lib/destination/dml/merge_valid_test.go index 00fe12a2c..e8223c97e 100644 --- a/lib/destination/dml/merge_valid_test.go +++ b/lib/destination/dml/merge_valid_test.go @@ -14,7 +14,7 @@ import ( func TestMergeArgument_Valid(t *testing.T) { primaryKeys := []columns.Wrapper{ - columns.NewWrapper(columns.NewColumn("id", typing.Integer), false, nil), + columns.NewWrapper(columns.NewColumn("id", typing.Integer), false, constants.Snowflake), } var cols columns.Columns diff --git a/lib/optimization/table_data.go b/lib/optimization/table_data.go index c33346a8e..89b3eb264 100644 --- a/lib/optimization/table_data.go +++ b/lib/optimization/table_data.go @@ -66,10 +66,10 @@ func (t *TableData) ContainOtherOperations() bool { return t.containOtherOperations } -func (t *TableData) PrimaryKeys(uppercaseEscNames bool, args *columns.NameArgs) []columns.Wrapper { +func (t *TableData) PrimaryKeys(uppercaseEscNames bool, destKind constants.DestinationKind) []columns.Wrapper { var pks []columns.Wrapper for _, pk := range t.primaryKeys { - pks = append(pks, columns.NewWrapper(columns.NewColumn(pk, typing.Invalid), uppercaseEscNames, args)) + pks = append(pks, columns.NewWrapper(columns.NewColumn(pk, typing.Invalid), uppercaseEscNames, destKind)) } return pks diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index 4f8a3d5dd..cf5fa6c20 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -90,12 +90,8 @@ type NameArgs struct { // Name will give you c.name // However, if you pass in escape, we will escape if the column name is part of the reserved words from destinations. // If so, it'll change from `start` => `"start"` as suggested by Snowflake. -func (c *Column) Name(uppercaseEscNames bool, args *NameArgs) string { - // TODO: Kill [NameArgs] and just pass a [DestinationKind]. - if args == nil { - return c.name - } - return sql.EscapeNameIfNecessary(c.name, uppercaseEscNames, args.DestKind) +func (c *Column) Name(uppercaseEscNames bool, destKind constants.DestinationKind) string { + return sql.EscapeNameIfNecessary(c.name, uppercaseEscNames, destKind) } type Columns struct { @@ -198,7 +194,11 @@ func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []s continue } - cols = append(cols, col.Name(uppercaseEscNames, args)) + if args == nil { + cols = append(cols, col.RawName()) + } else { + cols = append(cols, col.Name(uppercaseEscNames, args.DestKind)) + } } return cols @@ -255,7 +255,7 @@ func (c *Columns) UpdateQuery(destKind constants.DestinationKind, uppercaseEscNa continue } - colName := column.Name(uppercaseEscNames, &NameArgs{DestKind: destKind}) + colName := column.Name(uppercaseEscNames, destKind) if column.ToastColumn { if column.KindDetails == typing.Struct { cols = append(cols, processToastStructCol(colName, destKind)) diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index ab0059d8d..a67f09dab 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -169,12 +169,8 @@ func TestColumn_Name(t *testing.T) { assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName) - assert.Equal(t, testCase.expectedNameEsc, col.Name(false, &NameArgs{ - DestKind: constants.Snowflake, - }), testCase.colName) - assert.Equal(t, testCase.expectedNameEscBq, col.Name(false, &NameArgs{ - DestKind: constants.BigQuery, - }), testCase.colName) + assert.Equal(t, testCase.expectedNameEsc, col.Name(false, constants.Snowflake), testCase.colName) + assert.Equal(t, testCase.expectedNameEscBq, col.Name(false, constants.BigQuery), testCase.colName) } } diff --git a/lib/typing/columns/wrapper.go b/lib/typing/columns/wrapper.go index ecac1d4fd..cb5f37643 100644 --- a/lib/typing/columns/wrapper.go +++ b/lib/typing/columns/wrapper.go @@ -1,14 +1,16 @@ package columns +import "github.com/artie-labs/transfer/lib/config/constants" + type Wrapper struct { name string escapedName string } -func NewWrapper(col Column, uppercaseEscNames bool, args *NameArgs) Wrapper { +func NewWrapper(col Column, uppercaseEscNames bool, destKind constants.DestinationKind) Wrapper { return Wrapper{ name: col.name, - escapedName: col.Name(uppercaseEscNames, args), + escapedName: col.Name(uppercaseEscNames, destKind), } } diff --git a/lib/typing/columns/wrapper_test.go b/lib/typing/columns/wrapper_test.go index 8c65de7a2..7a624ab18 100644 --- a/lib/typing/columns/wrapper_test.go +++ b/lib/typing/columns/wrapper_test.go @@ -41,33 +41,20 @@ func TestWrapper_Complete(t *testing.T) { for _, testCase := range testCases { // Snowflake escape - w := NewWrapper(NewColumn(testCase.name, typing.Invalid), false, &NameArgs{ - DestKind: constants.Snowflake, - }) + w := NewWrapper(NewColumn(testCase.name, typing.Invalid), false, constants.Snowflake) assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) // BigQuery escape - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, &NameArgs{ - DestKind: constants.BigQuery, - }) + w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, constants.BigQuery) assert.Equal(t, testCase.expectedEscapedNameBQ, w.EscapedName(), testCase.name) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) for _, destKind := range []constants.DestinationKind{constants.Snowflake, constants.BigQuery} { - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, &NameArgs{ - DestKind: destKind, - }) + w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, destKind) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) } - - // Same if nil - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, nil) - - assert.Equal(t, testCase.expectedRawName, w.EscapedName(), testCase.name) - assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) - } } From 51a8332ee71e8665bc2ba7edd257f146b3866ed2 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 24 Apr 2024 15:41:08 -0700 Subject: [PATCH 2/2] Comment --- lib/typing/columns/columns.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index cf5fa6c20..abd6ea4a7 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -88,7 +88,7 @@ type NameArgs struct { } // Name will give you c.name -// However, if you pass in escape, we will escape if the column name is part of the reserved words from destinations. +// Plus we will escape it if the column name is part of the reserved words from destinations. // If so, it'll change from `start` => `"start"` as suggested by Snowflake. func (c *Column) Name(uppercaseEscNames bool, destKind constants.DestinationKind) string { return sql.EscapeNameIfNecessary(c.name, uppercaseEscNames, destKind)