diff --git a/clients/shared/merge.go b/clients/shared/merge.go index 970fc3652..88c560efa 100644 --- a/clients/shared/merge.go +++ b/clients/shared/merge.go @@ -125,7 +125,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg TableID: tableID, SubQuery: subQuery, IdempotentKey: tableData.TopicConfig().IdempotentKey, - PrimaryKeys: tableData.PrimaryKeys(dwh.ShouldUppercaseEscapedNames(), dwh.Label()), + PrimaryKeys: tableData.PrimaryKeys(dwh.Dialect()), Columns: tableData.ReadOnlyInMemoryCols(), SoftDelete: tableData.TopicConfig().SoftDelete, DestKind: dwh.Label(), diff --git a/clients/shared/utils.go b/clients/shared/utils.go index b3b236604..1382abf0e 100644 --- a/clients/shared/utils.go +++ b/clients/shared/utils.go @@ -30,7 +30,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col return fmt.Errorf("failed to escape default value: %w", err) } - escapedCol := column.Name(dwh.ShouldUppercaseEscapedNames(), dwh.Label()) + escapedCol := column.Name(dwh.Dialect()) // TODO: This is added because `default` is not technically a column that requires escaping, but it is required when it's in the where clause. // Once we escape everything by default, we can remove this patch of code. diff --git a/clients/snowflake/staging.go b/clients/snowflake/staging.go index 91e07e286..2c761ebf9 100644 --- a/clients/snowflake/staging.go +++ b/clients/snowflake/staging.go @@ -85,7 +85,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo // COPY the CSV file (in Snowflake) into a table copyCommand := fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)", tempTableID.FullyQualifiedName(), - strings.Join(tableData.ReadOnlyInMemoryCols().GetEscapedColumnsToUpdate(s.ShouldUppercaseEscapedNames(), s.Label()), ","), + strings.Join(tableData.ReadOnlyInMemoryCols().GetEscapedColumnsToUpdate(s.Dialect()), ","), escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableID, "%")) if additionalSettings.AdditionalCopyClause != "" { diff --git a/lib/destination/ddl/ddl.go b/lib/destination/ddl/ddl.go index 723238ac6..ac2803fe6 100644 --- a/lib/destination/ddl/ddl.go +++ b/lib/destination/ddl/ddl.go @@ -104,7 +104,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error { mutateCol = append(mutateCol, col) switch a.ColumnOp { case constants.Add: - colName := col.Name(*a.UppercaseEscNames, a.Dwh.Label()) + colName := col.Name(a.Dwh.Dialect()) if col.PrimaryKey() && a.Mode != config.History { // Don't create a PK for history mode because it's append-only, so the primary key should not be enforced. @@ -113,7 +113,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error { colSQLParts = append(colSQLParts, fmt.Sprintf(`%s %s`, colName, typing.KindToDWHType(col.KindDetails, a.Dwh.Label(), col.PrimaryKey()))) case constants.Delete: - colSQLParts = append(colSQLParts, col.Name(*a.UppercaseEscNames, a.Dwh.Label())) + colSQLParts = append(colSQLParts, col.Name(a.Dwh.Dialect())) } } diff --git a/lib/destination/ddl/ddl_bq_test.go b/lib/destination/ddl/ddl_bq_test.go index 62c6a19bc..5985db3cb 100644 --- a/lib/destination/ddl/ddl_bq_test.go +++ b/lib/destination/ddl/ddl_bq_test.go @@ -90,7 +90,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() { assert.NoError(d.T(), alterTableArgs.AlterTable(column)) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.Name(false, d.bigQueryStore.Label())), query) + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.Name(d.bigQueryStore.Dialect())), query) callIdx += 1 } @@ -148,7 +148,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() { assert.NoError(d.T(), alterTableArgs.AlterTable(col)) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.Name(false, d.bigQueryStore.Label()), + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.Name(d.bigQueryStore.Dialect()), typing.KindToDWHType(kind, d.bigQueryStore.Label(), false)), query) callIdx += 1 } @@ -208,7 +208,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() { assert.NoError(d.T(), alterTableArgs.AlterTable(column)) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.Name(false, d.bigQueryStore.Label()), + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.Name(d.bigQueryStore.Dialect()), typing.KindToDWHType(column.KindDetails, d.bigQueryStore.Label(), false)), query) callIdx += 1 } diff --git a/lib/destination/ddl/ddl_sflk_test.go b/lib/destination/ddl/ddl_sflk_test.go index 2305929d2..de84b2dda 100644 --- a/lib/destination/ddl/ddl_sflk_test.go +++ b/lib/destination/ddl/ddl_sflk_test.go @@ -47,7 +47,7 @@ func (d *DDLTestSuite) TestAlterComplexObjects() { for i := 0; i < len(cols); i++ { execQuery, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i) assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s add COLUMN %s %s", `shop.public."COMPLEX_COLUMNS"`, - cols[i].Name(false, d.snowflakeStagesStore.Label()), + cols[i].Name(d.snowflakeStagesStore.Dialect()), typing.KindToDWHType(cols[i].KindDetails, d.snowflakeStagesStore.Label(), false)), execQuery) } @@ -180,7 +180,7 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() { execArg, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i) assert.Equal(d.T(), execArg, fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", `shop.public."USERS"`, constants.Delete, - cols[i].Name(false, d.snowflakeStagesStore.Label()))) + cols[i].Name(d.snowflakeStagesStore.Dialect()))) } } diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index 715f7ab5a..0c69a0d82 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -103,7 +103,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind) + cols := m.Columns.GetEscapedColumnsToUpdate(m.Dialect) if m.SoftDelete { return []string{ @@ -124,7 +124,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { // UPDATE fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s;`, // UPDATE table set col1 = cc. col1 - m.TableID.FullyQualifiedName(), m.Columns.UpdateQuery(m.DestKind, *m.UppercaseEscNames, false), + m.TableID.FullyQualifiedName(), m.Columns.UpdateQuery(m.DestKind, m.Dialect, false), // FROM table (temp) WHERE join on PK(s) m.SubQuery, strings.Join(equalitySQLParts, " and "), idempotentClause, ), @@ -168,7 +168,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { // UPDATE fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s AND COALESCE(cc.%s, false) = false;`, // UPDATE table set col1 = cc. col1 - m.TableID.FullyQualifiedName(), m.Columns.UpdateQuery(m.DestKind, *m.UppercaseEscNames, true), + m.TableID.FullyQualifiedName(), m.Columns.UpdateQuery(m.DestKind, m.Dialect, true), // FROM staging WHERE join on PK(s) m.SubQuery, strings.Join(equalitySQLParts, " and "), idempotentClause, constants.DeleteColumnMarker, ), @@ -235,7 +235,7 @@ func (m *MergeArgument) GetStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, m.AdditionalEqualityStrings...) } - cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind) + cols := m.Columns.GetEscapedColumnsToUpdate(m.Dialect) if m.SoftDelete { return fmt.Sprintf(` @@ -244,7 +244,7 @@ WHEN MATCHED %sTHEN UPDATE SET %s WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`, m.TableID.FullyQualifiedName(), subQuery, strings.Join(equalitySQLParts, " and "), // Update + Soft Deletion - idempotentClause, m.Columns.UpdateQuery(m.DestKind, *m.UppercaseEscNames, false), + idempotentClause, m.Columns.UpdateQuery(m.DestKind, m.Dialect, false), // Insert constants.DeleteColumnMarker, strings.Join(cols, ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ @@ -277,7 +277,7 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);` // Delete constants.DeleteColumnMarker, // Update - constants.DeleteColumnMarker, idempotentClause, m.Columns.UpdateQuery(m.DestKind, *m.UppercaseEscNames, true), + constants.DeleteColumnMarker, idempotentClause, m.Columns.UpdateQuery(m.DestKind, m.Dialect, true), // Insert constants.DeleteColumnMarker, strings.Join(cols, ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ @@ -304,7 +304,7 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind) + cols := m.Columns.GetEscapedColumnsToUpdate(m.Dialect) if m.SoftDelete { return fmt.Sprintf(` @@ -314,7 +314,7 @@ WHEN MATCHED %sTHEN UPDATE SET %s WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`, m.TableID.FullyQualifiedName(), m.SubQuery, strings.Join(equalitySQLParts, " and "), // Update + Soft Deletion - idempotentClause, m.Columns.UpdateQuery(m.DestKind, *m.UppercaseEscNames, false), + idempotentClause, m.Columns.UpdateQuery(m.DestKind, m.Dialect, false), // Insert constants.DeleteColumnMarker, strings.Join(cols, ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ @@ -348,7 +348,7 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`, // Delete constants.DeleteColumnMarker, // Update - constants.DeleteColumnMarker, idempotentClause, m.Columns.UpdateQuery(m.DestKind, *m.UppercaseEscNames, true), + constants.DeleteColumnMarker, idempotentClause, m.Columns.UpdateQuery(m.DestKind, m.Dialect, true), // Insert constants.DeleteColumnMarker, strings.Join(cols, ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ diff --git a/lib/destination/dml/merge_bigquery_test.go b/lib/destination/dml/merge_bigquery_test.go index 38524abb9..0f383d35c 100644 --- a/lib/destination/dml/merge_bigquery_test.go +++ b/lib/destination/dml/merge_bigquery_test.go @@ -20,7 +20,7 @@ func TestMergeStatement_TempTable(t *testing.T) { mergeArg := &MergeArgument{ TableID: MockTableIdentifier{"customers.orders"}, SubQuery: "customers.orders_tmp", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_id", typing.Invalid), false, constants.BigQuery)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_id", typing.Invalid), sql.BigQueryDialect{})}, Columns: &cols, DestKind: constants.BigQuery, Dialect: sql.BigQueryDialect{}, @@ -43,7 +43,7 @@ func TestMergeStatement_JSONKey(t *testing.T) { mergeArg := &MergeArgument{ TableID: MockTableIdentifier{"customers.orders"}, SubQuery: "customers.orders_tmp", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_oid", typing.Invalid), false, constants.BigQuery)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_oid", typing.Invalid), sql.BigQueryDialect{})}, Columns: &cols, DestKind: constants.BigQuery, Dialect: sql.BigQueryDialect{}, diff --git a/lib/destination/dml/merge_mssql_test.go b/lib/destination/dml/merge_mssql_test.go index b64455140..163eab3a1 100644 --- a/lib/destination/dml/merge_mssql_test.go +++ b/lib/destination/dml/merge_mssql_test.go @@ -45,7 +45,7 @@ func Test_GetMSSQLStatement(t *testing.T) { TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: "", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.MSSQL)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), sql.DefaultDialect{})}, Columns: &_cols, DestKind: constants.MSSQL, Dialect: sql.DefaultDialect{}, diff --git a/lib/destination/dml/merge_parts_test.go b/lib/destination/dml/merge_parts_test.go index d6be3ce5c..0be37e169 100644 --- a/lib/destination/dml/merge_parts_test.go +++ b/lib/destination/dml/merge_parts_test.go @@ -48,10 +48,10 @@ func getBasicColumnsForTest(compositeKey bool) result { cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) var pks []columns.Wrapper - pks = append(pks, columns.NewWrapper(idCol, false, constants.Redshift)) + pks = append(pks, columns.NewWrapper(idCol, sql.DefaultDialect{})) if compositeKey { - pks = append(pks, columns.NewWrapper(emailCol, false, constants.Redshift)) + pks = append(pks, columns.NewWrapper(emailCol, sql.DefaultDialect{})) } return result{ diff --git a/lib/destination/dml/merge_test.go b/lib/destination/dml/merge_test.go index 37ed53257..5ec119f58 100644 --- a/lib/destination/dml/merge_test.go +++ b/lib/destination/dml/merge_test.go @@ -57,15 +57,16 @@ func TestMergeStatementSoftDelete(t *testing.T) { _cols.AddColumn(columns.NewColumn("id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) + dialect := sql.SnowflakeDialect{UppercaseEscNames: true} for _, idempotentKey := range []string{"", "updated_at"} { mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: idempotentKey, - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), true, constants.Snowflake)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), dialect)}, Columns: &_cols, DestKind: constants.Snowflake, - Dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, + Dialect: dialect, SoftDelete: true, UppercaseEscNames: ptr.ToBool(true), } @@ -108,14 +109,15 @@ func TestMergeStatement(t *testing.T) { subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) + dialect := sql.SnowflakeDialect{UppercaseEscNames: true} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: "", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), true, constants.Snowflake)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), dialect)}, Columns: &_cols, DestKind: constants.Snowflake, - Dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, + Dialect: dialect, SoftDelete: false, UppercaseEscNames: ptr.ToBool(true), } @@ -157,14 +159,15 @@ func TestMergeStatementIdempotentKey(t *testing.T) { _cols.AddColumn(columns.NewColumn("id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) + dialect := sql.SnowflakeDialect{UppercaseEscNames: true} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: "updated_at", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), true, constants.Snowflake)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), dialect)}, Columns: &_cols, DestKind: constants.Snowflake, - Dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, + Dialect: dialect, SoftDelete: false, UppercaseEscNames: ptr.ToBool(true), } @@ -200,17 +203,18 @@ func TestMergeStatementCompositeKey(t *testing.T) { _cols.AddColumn(columns.NewColumn("another_id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) + dialect := sql.SnowflakeDialect{UppercaseEscNames: true} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: "updated_at", PrimaryKeys: []columns.Wrapper{ - columns.NewWrapper(columns.NewColumn("id", typing.Invalid), true, constants.Snowflake), - columns.NewWrapper(columns.NewColumn("another_id", typing.Invalid), true, constants.Snowflake), + columns.NewWrapper(columns.NewColumn("id", typing.Invalid), dialect), + columns.NewWrapper(columns.NewColumn("another_id", typing.Invalid), dialect), }, Columns: &_cols, DestKind: constants.Snowflake, - Dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, + Dialect: dialect, SoftDelete: false, UppercaseEscNames: ptr.ToBool(true), } @@ -250,17 +254,18 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) { subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) + dialect := sql.SnowflakeDialect{UppercaseEscNames: true} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, IdempotentKey: "", PrimaryKeys: []columns.Wrapper{ - columns.NewWrapper(columns.NewColumn("id", typing.Invalid), true, constants.Snowflake), - columns.NewWrapper(columns.NewColumn("group", typing.Invalid), true, constants.Snowflake), + columns.NewWrapper(columns.NewColumn("id", typing.Invalid), dialect), + columns.NewWrapper(columns.NewColumn("group", typing.Invalid), dialect), }, Columns: &_cols, DestKind: constants.Snowflake, - Dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, + Dialect: dialect, SoftDelete: false, UppercaseEscNames: ptr.ToBool(true), } diff --git a/lib/destination/dml/merge_valid_test.go b/lib/destination/dml/merge_valid_test.go index 93760e089..9a3e109a0 100644 --- a/lib/destination/dml/merge_valid_test.go +++ b/lib/destination/dml/merge_valid_test.go @@ -15,7 +15,7 @@ import ( func TestMergeArgument_Valid(t *testing.T) { primaryKeys := []columns.Wrapper{ - columns.NewWrapper(columns.NewColumn("id", typing.Integer), true, constants.Snowflake), + columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{UppercaseEscNames: true}), } var cols columns.Columns diff --git a/lib/optimization/table_data.go b/lib/optimization/table_data.go index e47839528..39402b00e 100644 --- a/lib/optimization/table_data.go +++ b/lib/optimization/table_data.go @@ -10,6 +10,7 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/kafkalib" "github.com/artie-labs/transfer/lib/size" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/stringutil" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" @@ -66,10 +67,10 @@ func (t *TableData) ContainOtherOperations() bool { return t.containOtherOperations } -func (t *TableData) PrimaryKeys(uppercaseEscNames bool, destKind constants.DestinationKind) []columns.Wrapper { +func (t *TableData) PrimaryKeys(dialect sql.Dialect) []columns.Wrapper { var pks []columns.Wrapper for _, pk := range t.primaryKeys { - pks = append(pks, columns.NewWrapper(columns.NewColumn(pk, typing.Invalid), uppercaseEscNames, destKind)) + pks = append(pks, columns.NewWrapper(columns.NewColumn(pk, typing.Invalid), dialect)) } return pks diff --git a/lib/sql/escape.go b/lib/sql/escape.go index 8ba1bf847..60e330f41 100644 --- a/lib/sql/escape.go +++ b/lib/sql/escape.go @@ -1,35 +1,8 @@ package sql -import ( - "github.com/artie-labs/transfer/lib/config/constants" -) - -func EscapeNameIfNecessary(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string { - // TODO: Switch all calls of [EscapeNameIfNecessary] to [EscapeNameIfNecessaryUsingDialect] and kill this. - var dialect = dialectFor(destKind, uppercaseEscNames) - - if destKind != constants.S3 && dialect.NeedsEscaping(name) { - return dialect.QuoteIdentifier(name) - } - return name -} - func EscapeNameIfNecessaryUsingDialect(name string, dialect Dialect) string { if dialect.NeedsEscaping(name) { return dialect.QuoteIdentifier(name) } return name } - -func dialectFor(destKind constants.DestinationKind, uppercaseEscNames bool) Dialect { - switch destKind { - case constants.BigQuery: - return BigQueryDialect{} - case constants.Snowflake: - return SnowflakeDialect{UppercaseEscNames: uppercaseEscNames} - case constants.Redshift: - return RedshiftDialect{} - default: - return DefaultDialect{} - } -} diff --git a/lib/sql/escape_test.go b/lib/sql/escape_test.go deleted file mode 100644 index 1d5a83911..000000000 --- a/lib/sql/escape_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package sql - -import ( - "testing" - - "github.com/artie-labs/transfer/lib/config/constants" - "github.com/stretchr/testify/assert" -) - -func TestEscapeNameIfNecessary(t *testing.T) { - type _testCase struct { - name string - nameToEscape string - destKind constants.DestinationKind - expectedName string - expectedNameWhenUpperCfg string - } - - testCases := []_testCase{ - { - name: "snowflake", - destKind: constants.Snowflake, - nameToEscape: "order", - expectedName: `"order"`, - expectedNameWhenUpperCfg: `"ORDER"`, - }, - { - name: "snowflake #2", - destKind: constants.Snowflake, - nameToEscape: "hello", - expectedName: `hello`, - expectedNameWhenUpperCfg: `"HELLO"`, - }, - { - name: "redshift", - destKind: constants.Redshift, - nameToEscape: "order", - expectedName: `"order"`, - expectedNameWhenUpperCfg: `"order"`, - }, - { - name: "redshift #2", - destKind: constants.Redshift, - nameToEscape: "hello", - expectedName: `"hello"`, - expectedNameWhenUpperCfg: `"hello"`, - }, - { - name: "bigquery", - destKind: constants.BigQuery, - nameToEscape: "order", - expectedName: "`order`", - expectedNameWhenUpperCfg: "`order`", - }, - { - name: "bigquery, #2", - destKind: constants.BigQuery, - nameToEscape: "hello", - expectedName: "`hello`", - expectedNameWhenUpperCfg: "`hello`", - }, - { - name: "redshift, #1 (delta)", - destKind: constants.Redshift, - nameToEscape: "delta", - expectedName: `"delta"`, - expectedNameWhenUpperCfg: `"delta"`, - }, - { - name: "snowflake, #1 (delta)", - destKind: constants.Snowflake, - nameToEscape: "delta", - expectedName: `delta`, - expectedNameWhenUpperCfg: `"DELTA"`, - }, - { - name: "redshift, symbols", - destKind: constants.Redshift, - nameToEscape: "receivedat:__", - expectedName: `"receivedat:__"`, - expectedNameWhenUpperCfg: `"receivedat:__"`, - }, - { - name: "redshift, numbers", - destKind: constants.Redshift, - nameToEscape: "0", - expectedName: `"0"`, - expectedNameWhenUpperCfg: `"0"`, - }, - } - - for _, testCase := range testCases { - actualName := EscapeNameIfNecessary(testCase.nameToEscape, false, testCase.destKind) - assert.Equal(t, testCase.expectedName, actualName, testCase.name) - - actualUpperName := EscapeNameIfNecessary(testCase.nameToEscape, true, testCase.destKind) - assert.Equal(t, testCase.expectedNameWhenUpperCfg, actualUpperName, testCase.name) - } -} diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index 2ed85ff2f..78fad212f 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -83,11 +83,11 @@ func (c *Column) RawName() string { return c.name } -// Name will give you c.name +// Name will give you c.name and escape it if necessary. // Plus we will escape it if the column name is part of the reserved words from destinations. // If so, it'll change from `start` => `"start"` as suggested by Snowflake. -func (c *Column) Name(uppercaseEscNames bool, destKind constants.DestinationKind) string { - return sql.EscapeNameIfNecessary(c.name, uppercaseEscNames, destKind) +func (c *Column) Name(dialect sql.Dialect) string { + return sql.EscapeNameIfNecessaryUsingDialect(c.name, dialect) } type Columns struct { @@ -198,7 +198,7 @@ func (c *Columns) GetColumnsToUpdate() []string { // GetEscapedColumnsToUpdate will filter all the `Invalid` columns so that we do not update it. // It will escape the returned columns. -func (c *Columns) GetEscapedColumnsToUpdate(uppercaseEscNames bool, destKind constants.DestinationKind) []string { +func (c *Columns) GetEscapedColumnsToUpdate(dialect sql.Dialect) []string { if c == nil { return []string{} } @@ -212,7 +212,7 @@ func (c *Columns) GetEscapedColumnsToUpdate(uppercaseEscNames bool, destKind con continue } - cols = append(cols, col.Name(uppercaseEscNames, destKind)) + cols = append(cols, col.Name(dialect)) } return cols @@ -257,7 +257,7 @@ func (c *Columns) DeleteColumn(name string) { } // UpdateQuery will parse the columns and then returns a list of strings like: cc.first_name=c.first_name,cc.last_name=c.last_name,cc.email=c.email -func (c *Columns) UpdateQuery(destKind constants.DestinationKind, uppercaseEscNames bool, skipDeleteCol bool) string { +func (c *Columns) UpdateQuery(destKind constants.DestinationKind, dialect sql.Dialect, skipDeleteCol bool) string { var cols []string for _, column := range c.GetColumns() { if column.ShouldSkip() { @@ -269,7 +269,7 @@ func (c *Columns) UpdateQuery(destKind constants.DestinationKind, uppercaseEscNa continue } - colName := column.Name(uppercaseEscNames, destKind) + colName := column.Name(dialect) if column.ToastColumn { if column.KindDetails == typing.Struct { cols = append(cols, processToastStructCol(colName, destKind)) diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index 98af9f113..a4d162687 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -6,6 +6,7 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" "github.com/stretchr/testify/assert" ) @@ -169,8 +170,8 @@ func TestColumn_Name(t *testing.T) { assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName) - assert.Equal(t, testCase.expectedNameEsc, col.Name(true, constants.Snowflake), testCase.colName) - assert.Equal(t, testCase.expectedNameEscBq, col.Name(false, constants.BigQuery), testCase.colName) + assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.colName) + assert.Equal(t, testCase.expectedNameEscBq, col.Name(sql.BigQueryDialect{}), testCase.colName) } } @@ -281,8 +282,8 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) { columns: testCase.cols, } - assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(true, constants.Snowflake), testCase.name) - assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(false, constants.BigQuery), testCase.name) + assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.name) + assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(sql.BigQueryDialect{}), testCase.name) } } @@ -397,6 +398,7 @@ func TestColumnsUpdateQuery(t *testing.T) { columns Columns expectedString string destKind constants.DestinationKind + dialect sql.Dialect skipDeleteCol bool } @@ -480,30 +482,35 @@ func TestColumnsUpdateQuery(t *testing.T) { name: "happy path", columns: happyPathCols, destKind: constants.Redshift, + dialect: sql.RedshiftDialect{}, expectedString: `"foo"=cc."foo","bar"=cc."bar"`, }, { name: "string and toast", columns: stringAndToastCols, destKind: constants.Snowflake, + dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, expectedString: `"FOO"= CASE WHEN COALESCE(cc."FOO" != '__debezium_unavailable_value', true) THEN cc."FOO" ELSE c."FOO" END,"BAR"=cc."BAR"`, }, { name: "struct, string and toast string", columns: lastCaseColTypes, destKind: constants.Redshift, + dialect: sql.RedshiftDialect{}, expectedString: `"a1"= CASE WHEN COALESCE(cc."a1" != JSON_PARSE('{"key":"__debezium_unavailable_value"}'), true) THEN cc."a1" ELSE c."a1" END,"b2"= CASE WHEN COALESCE(cc."b2" != '__debezium_unavailable_value', true) THEN cc."b2" ELSE c."b2" END,"c3"=cc."c3"`, }, { name: "struct, string and toast string (bigquery)", columns: lastCaseColTypes, destKind: constants.BigQuery, + dialect: sql.BigQueryDialect{}, expectedString: "`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '{\"key\":\"__debezium_unavailable_value\"}', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`", }, { name: "struct, string and toast string (bigquery) w/ reserved keywords", columns: lastCaseEscapeTypes, destKind: constants.BigQuery, + dialect: sql.BigQueryDialect{}, expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s", key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`"), skipDeleteCol: true, @@ -512,6 +519,7 @@ func TestColumnsUpdateQuery(t *testing.T) { name: "struct, string and toast string (bigquery) w/ reserved keywords", columns: lastCaseEscapeTypes, destKind: constants.BigQuery, + dialect: sql.BigQueryDialect{}, expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s", key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`,`__artie_delete`=cc.`__artie_delete`"), skipDeleteCol: false, @@ -519,7 +527,7 @@ func TestColumnsUpdateQuery(t *testing.T) { } for _, _testCase := range testCases { - actualQuery := _testCase.columns.UpdateQuery(_testCase.destKind, _testCase.destKind == constants.Snowflake, _testCase.skipDeleteCol) + actualQuery := _testCase.columns.UpdateQuery(_testCase.destKind, _testCase.dialect, _testCase.skipDeleteCol) assert.Equal(t, _testCase.expectedString, actualQuery, _testCase.name) } } diff --git a/lib/typing/columns/wrapper.go b/lib/typing/columns/wrapper.go index cb5f37643..2d79d4845 100644 --- a/lib/typing/columns/wrapper.go +++ b/lib/typing/columns/wrapper.go @@ -1,16 +1,18 @@ package columns -import "github.com/artie-labs/transfer/lib/config/constants" +import ( + "github.com/artie-labs/transfer/lib/sql" +) type Wrapper struct { name string escapedName string } -func NewWrapper(col Column, uppercaseEscNames bool, destKind constants.DestinationKind) Wrapper { +func NewWrapper(col Column, dialect sql.Dialect) Wrapper { return Wrapper{ name: col.name, - escapedName: col.Name(uppercaseEscNames, destKind), + escapedName: col.Name(dialect), } } diff --git a/lib/typing/columns/wrapper_test.go b/lib/typing/columns/wrapper_test.go index 38da079fb..8f7e1f317 100644 --- a/lib/typing/columns/wrapper_test.go +++ b/lib/typing/columns/wrapper_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" ) @@ -41,20 +41,25 @@ func TestWrapper_Complete(t *testing.T) { for _, testCase := range testCases { // Snowflake escape - w := NewWrapper(NewColumn(testCase.name, typing.Invalid), true, constants.Snowflake) + w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true}) assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) // BigQuery escape - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, constants.BigQuery) + w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.BigQueryDialect{}) assert.Equal(t, testCase.expectedEscapedNameBQ, w.EscapedName(), testCase.name) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) - for _, destKind := range []constants.DestinationKind{constants.Snowflake, constants.BigQuery} { - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), false, destKind) + { + w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true}) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) } + { + w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.BigQueryDialect{}) + assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) + } + } }