Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] Pass DestinationKind to Column.Name #503

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clients/shared/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg
TableID: tableID,
SubQuery: subQuery,
IdempotentKey: tableData.TopicConfig().IdempotentKey,
PrimaryKeys: tableData.PrimaryKeys(dwh.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: dwh.Label()}),
PrimaryKeys: tableData.PrimaryKeys(dwh.ShouldUppercaseEscapedNames(), dwh.Label()),
Columns: tableData.ReadOnlyInMemoryCols(),
SoftDelete: tableData.TopicConfig().SoftDelete,
DestKind: dwh.Label(),
Expand Down
2 changes: 1 addition & 1 deletion clients/shared/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col
return fmt.Errorf("failed to escape default value: %w", err)
}

escapedCol := column.Name(dwh.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: dwh.Label()})
escapedCol := column.Name(dwh.ShouldUppercaseEscapedNames(), dwh.Label())

// TODO: This is added because `default` is not technically a column that requires escaping, but it is required when it's in the where clause.
// Once we escape everything by default, we can remove this patch of code.
Expand Down
2 changes: 1 addition & 1 deletion clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentif
var primaryKeysEscaped []string
for _, pk := range primaryKeys {
pkCol := columns.NewColumn(pk, typing.Invalid)
primaryKeysEscaped = append(primaryKeysEscaped, pkCol.Name(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: s.Label()}))
primaryKeysEscaped = append(primaryKeysEscaped, pkCol.Name(s.ShouldUppercaseEscapedNames(), s.Label()))
}

orderColsToIterate := primaryKeysEscaped
Expand Down
8 changes: 2 additions & 6 deletions lib/destination/ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error {
mutateCol = append(mutateCol, col)
switch a.ColumnOp {
case constants.Add:
colName := col.Name(*a.UppercaseEscNames, &columns.NameArgs{
DestKind: a.Dwh.Label(),
})
colName := col.Name(*a.UppercaseEscNames, a.Dwh.Label())

if col.PrimaryKey() && a.Mode != config.History {
// Don't create a PK for history mode because it's append-only, so the primary key should not be enforced.
Expand All @@ -115,9 +113,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error {

colSQLParts = append(colSQLParts, fmt.Sprintf(`%s %s`, colName, typing.KindToDWHType(col.KindDetails, a.Dwh.Label(), col.PrimaryKey())))
case constants.Delete:
colSQLParts = append(colSQLParts, col.Name(*a.UppercaseEscNames, &columns.NameArgs{
DestKind: a.Dwh.Label(),
}))
colSQLParts = append(colSQLParts, col.Name(*a.UppercaseEscNames, a.Dwh.Label()))
}
}

Expand Down
14 changes: 5 additions & 9 deletions lib/destination/ddl/ddl_bq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() {

assert.NoError(d.T(), alterTableArgs.AlterTable(column))
query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.Name(false, &columns.NameArgs{
DestKind: d.bigQueryStore.Label(),
})), query)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.Name(false, d.bigQueryStore.Label())), query)
callIdx += 1
}

Expand Down Expand Up @@ -150,9 +148,8 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() {

assert.NoError(d.T(), alterTableArgs.AlterTable(col))
query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.Name(false, &columns.NameArgs{
DestKind: d.bigQueryStore.Label(),
}), typing.KindToDWHType(kind, d.bigQueryStore.Label(), false)), query)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.Name(false, d.bigQueryStore.Label()),
typing.KindToDWHType(kind, d.bigQueryStore.Label(), false)), query)
callIdx += 1
}

Expand Down Expand Up @@ -211,9 +208,8 @@ func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() {

assert.NoError(d.T(), alterTableArgs.AlterTable(column))
query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.Name(false, &columns.NameArgs{
DestKind: d.bigQueryStore.Label(),
}), typing.KindToDWHType(column.KindDetails, d.bigQueryStore.Label(), false)), query)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.Name(false, d.bigQueryStore.Label()),
typing.KindToDWHType(column.KindDetails, d.bigQueryStore.Label(), false)), query)
callIdx += 1
}

Expand Down
4 changes: 2 additions & 2 deletions lib/destination/ddl/ddl_sflk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (d *DDLTestSuite) TestAlterComplexObjects() {
for i := 0; i < len(cols); i++ {
execQuery, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s add COLUMN %s %s", `shop.public."COMPLEX_COLUMNS"`,
cols[i].Name(false, &columns.NameArgs{DestKind: d.snowflakeStagesStore.Label()}),
cols[i].Name(false, d.snowflakeStagesStore.Label()),
typing.KindToDWHType(cols[i].KindDetails, d.snowflakeStagesStore.Label(), false)), execQuery)
}

Expand Down Expand Up @@ -180,7 +180,7 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() {

execArg, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i)
assert.Equal(d.T(), execArg, fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", `shop.public."USERS"`, constants.Delete,
cols[i].Name(false, &columns.NameArgs{DestKind: d.snowflakeStagesStore.Label()})))
cols[i].Name(false, d.snowflakeStagesStore.Label())))
}
}

Expand Down
4 changes: 2 additions & 2 deletions lib/destination/dml/merge_bigquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestMergeStatement_TempTable(t *testing.T) {
mergeArg := &MergeArgument{
TableID: MockTableIdentifier{"customers.orders"},
SubQuery: "customers.orders_tmp",
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_id", typing.Invalid), false, nil)},
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_id", typing.Invalid), false, constants.BigQuery)},
Columns: &cols,
DestKind: constants.BigQuery,
SoftDelete: false,
Expand All @@ -41,7 +41,7 @@ func TestMergeStatement_JSONKey(t *testing.T) {
mergeArg := &MergeArgument{
TableID: MockTableIdentifier{"customers.orders"},
SubQuery: "customers.orders_tmp",
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_oid", typing.Invalid), false, nil)},
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_oid", typing.Invalid), false, constants.BigQuery)},
Columns: &cols,
DestKind: constants.BigQuery,
SoftDelete: false,
Expand Down
2 changes: 1 addition & 1 deletion lib/destination/dml/merge_mssql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Test_GetMSSQLStatement(t *testing.T) {
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
IdempotentKey: "",
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil)},
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.MSSQL)},
Columns: &_cols,
DestKind: constants.MSSQL,
SoftDelete: false,
Expand Down
8 changes: 2 additions & 6 deletions lib/destination/dml/merge_parts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,10 @@ func getBasicColumnsForTest(compositeKey bool, uppercaseEscNames bool) result {
cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

var pks []columns.Wrapper
pks = append(pks, columns.NewWrapper(idCol, uppercaseEscNames, &columns.NameArgs{
DestKind: constants.Redshift,
}))
pks = append(pks, columns.NewWrapper(idCol, uppercaseEscNames, constants.Redshift))

if compositeKey {
pks = append(pks, columns.NewWrapper(emailCol, uppercaseEscNames, &columns.NameArgs{
DestKind: constants.Redshift,
}))
pks = append(pks, columns.NewWrapper(emailCol, uppercaseEscNames, constants.Redshift))
}

return result{
Expand Down
20 changes: 9 additions & 11 deletions lib/destination/dml/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestMergeStatementSoftDelete(t *testing.T) {
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
IdempotentKey: idempotentKey,
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil)},
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)},
Columns: &_cols,
DestKind: constants.Snowflake,
SoftDelete: true,
Expand Down Expand Up @@ -110,7 +110,7 @@ func TestMergeStatement(t *testing.T) {
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
IdempotentKey: "",
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil)},
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)},
Columns: &_cols,
DestKind: constants.Snowflake,
SoftDelete: false,
Expand Down Expand Up @@ -158,7 +158,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) {
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
IdempotentKey: "updated_at",
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil)},
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)},
Columns: &_cols,
DestKind: constants.Snowflake,
SoftDelete: false,
Expand Down Expand Up @@ -200,8 +200,10 @@ func TestMergeStatementCompositeKey(t *testing.T) {
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
IdempotentKey: "updated_at",
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, nil),
columns.NewWrapper(columns.NewColumn("another_id", typing.Invalid), false, nil)},
PrimaryKeys: []columns.Wrapper{
columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake),
columns.NewWrapper(columns.NewColumn("another_id", typing.Invalid), false, constants.Snowflake),
},
Columns: &_cols,
DestKind: constants.Snowflake,
SoftDelete: false,
Expand Down Expand Up @@ -248,12 +250,8 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) {
SubQuery: subQuery,
IdempotentKey: "",
PrimaryKeys: []columns.Wrapper{
columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, &columns.NameArgs{
DestKind: constants.Snowflake,
}),
columns.NewWrapper(columns.NewColumn("group", typing.Invalid), false, &columns.NameArgs{
DestKind: constants.Snowflake,
}),
columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake),
columns.NewWrapper(columns.NewColumn("group", typing.Invalid), false, constants.Snowflake),
},
Columns: &_cols,
DestKind: constants.Snowflake,
Expand Down
2 changes: 1 addition & 1 deletion lib/destination/dml/merge_valid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func TestMergeArgument_Valid(t *testing.T) {
primaryKeys := []columns.Wrapper{
columns.NewWrapper(columns.NewColumn("id", typing.Integer), false, nil),
columns.NewWrapper(columns.NewColumn("id", typing.Integer), false, constants.Snowflake),
}

var cols columns.Columns
Expand Down
4 changes: 2 additions & 2 deletions lib/optimization/table_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ func (t *TableData) ContainOtherOperations() bool {
return t.containOtherOperations
}

func (t *TableData) PrimaryKeys(uppercaseEscNames bool, args *columns.NameArgs) []columns.Wrapper {
func (t *TableData) PrimaryKeys(uppercaseEscNames bool, destKind constants.DestinationKind) []columns.Wrapper {
var pks []columns.Wrapper
for _, pk := range t.primaryKeys {
pks = append(pks, columns.NewWrapper(columns.NewColumn(pk, typing.Invalid), uppercaseEscNames, args))
pks = append(pks, columns.NewWrapper(columns.NewColumn(pk, typing.Invalid), uppercaseEscNames, destKind))
}

return pks
Expand Down
18 changes: 9 additions & 9 deletions lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,10 @@ type NameArgs struct {
}

// Name will give you c.name
// However, if you pass in escape, we will escape if the column name is part of the reserved words from destinations.
// Plus we will escape it if the column name is part of the reserved words from destinations.
// If so, it'll change from `start` => `"start"` as suggested by Snowflake.
func (c *Column) Name(uppercaseEscNames bool, args *NameArgs) string {
// TODO: Kill [NameArgs] and just pass a [DestinationKind].
if args == nil {
return c.name
}
return sql.EscapeNameIfNecessary(c.name, uppercaseEscNames, args.DestKind)
func (c *Column) Name(uppercaseEscNames bool, destKind constants.DestinationKind) string {
return sql.EscapeNameIfNecessary(c.name, uppercaseEscNames, destKind)
}

type Columns struct {
Expand Down Expand Up @@ -198,7 +194,11 @@ func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []s
continue
}

cols = append(cols, col.Name(uppercaseEscNames, args))
if args == nil {
Copy link
Contributor Author

@nathan-artie nathan-artie Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a stopgap, will look at passing a *constants.DestinationKind pointer to this method.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we create another function instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetColumnsToUpdateRaw or something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, I'll look into it in the next PR, I don't want to grow the scope of this PR.

cols = append(cols, col.RawName())
} else {
cols = append(cols, col.Name(uppercaseEscNames, args.DestKind))
}
}

return cols
Expand Down Expand Up @@ -255,7 +255,7 @@ func (c *Columns) UpdateQuery(destKind constants.DestinationKind, uppercaseEscNa
continue
}

colName := column.Name(uppercaseEscNames, &NameArgs{DestKind: destKind})
colName := column.Name(uppercaseEscNames, destKind)
if column.ToastColumn {
if column.KindDetails == typing.Struct {
cols = append(cols, processToastStructCol(colName, destKind))
Expand Down
8 changes: 2 additions & 6 deletions lib/typing/columns/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,8 @@ func TestColumn_Name(t *testing.T) {

assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName)

assert.Equal(t, testCase.expectedNameEsc, col.Name(false, &NameArgs{
DestKind: constants.Snowflake,
}), testCase.colName)
assert.Equal(t, testCase.expectedNameEscBq, col.Name(false, &NameArgs{
DestKind: constants.BigQuery,
}), testCase.colName)
assert.Equal(t, testCase.expectedNameEsc, col.Name(false, constants.Snowflake), testCase.colName)
assert.Equal(t, testCase.expectedNameEscBq, col.Name(false, constants.BigQuery), testCase.colName)
}
}

Expand Down
6 changes: 4 additions & 2 deletions lib/typing/columns/wrapper.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package columns

import "github.com/artie-labs/transfer/lib/config/constants"

type Wrapper struct {
name string
escapedName string
}

func NewWrapper(col Column, uppercaseEscNames bool, args *NameArgs) Wrapper {
func NewWrapper(col Column, uppercaseEscNames bool, destKind constants.DestinationKind) Wrapper {
return Wrapper{
name: col.name,
escapedName: col.Name(uppercaseEscNames, args),
escapedName: col.Name(uppercaseEscNames, destKind),
}
}

Expand Down
19 changes: 3 additions & 16 deletions lib/typing/columns/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,33 +41,20 @@ func TestWrapper_Complete(t *testing.T) {

for _, testCase := range testCases {
// Snowflake escape
w := NewWrapper(NewColumn(testCase.name, typing.Invalid), false, &NameArgs{
DestKind: constants.Snowflake,
})
w := NewWrapper(NewColumn(testCase.name, typing.Invalid), false, constants.Snowflake)

assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name)
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)

// BigQuery escape
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, &NameArgs{
DestKind: constants.BigQuery,
})
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, constants.BigQuery)

assert.Equal(t, testCase.expectedEscapedNameBQ, w.EscapedName(), testCase.name)
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)

for _, destKind := range []constants.DestinationKind{constants.Snowflake, constants.BigQuery} {
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, &NameArgs{
DestKind: destKind,
})
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, destKind)
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)
}

// Same if nil
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, nil)

assert.Equal(t, testCase.expectedRawName, w.EscapedName(), testCase.name)
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)

}
}