diff --git a/clients/shared/utils.go b/clients/shared/utils.go index 64f9faa82..45e971ccf 100644 --- a/clients/shared/utils.go +++ b/clients/shared/utils.go @@ -30,7 +30,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col return fmt.Errorf("failed to escape default value: %w", err) } - escapedCol := column.EscapedName(dwh.Dialect()) + escapedCol := dwh.Dialect().QuoteIdentifier(column.Name()) // TODO: This is added because `default` is not technically a column that requires escaping, but it is required when it's in the where clause. // Once we escape everything by default, we can remove this patch of code. diff --git a/lib/destination/ddl/ddl.go b/lib/destination/ddl/ddl.go index c8f0b864b..675dc5941 100644 --- a/lib/destination/ddl/ddl.go +++ b/lib/destination/ddl/ddl.go @@ -99,7 +99,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error { mutateCol = append(mutateCol, col) switch a.ColumnOp { case constants.Add: - colName := col.EscapedName(a.Dwh.Dialect()) + colName := a.Dwh.Dialect().QuoteIdentifier(col.Name()) if col.PrimaryKey() && a.Mode != config.History { // Don't create a PK for history mode because it's append-only, so the primary key should not be enforced. @@ -108,7 +108,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error { colSQLParts = append(colSQLParts, fmt.Sprintf(`%s %s`, colName, typing.KindToDWHType(col.KindDetails, a.Dwh.Label(), col.PrimaryKey()))) case constants.Delete: - colSQLParts = append(colSQLParts, col.EscapedName(a.Dwh.Dialect())) + colSQLParts = append(colSQLParts, a.Dwh.Dialect().QuoteIdentifier(col.Name())) } } diff --git a/lib/destination/ddl/ddl_bq_test.go b/lib/destination/ddl/ddl_bq_test.go index 123fcc86c..babea1977 100644 --- a/lib/destination/ddl/ddl_bq_test.go +++ b/lib/destination/ddl/ddl_bq_test.go @@ -86,7 +86,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() { assert.NoError(d.T(), alterTableArgs.AlterTable(column)) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.EscapedName(d.bigQueryStore.Dialect())), query) + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, d.bigQueryStore.Dialect().QuoteIdentifier(column.Name())), query) callIdx += 1 } @@ -143,7 +143,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() { assert.NoError(d.T(), alterTableArgs.AlterTable(col)) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.EscapedName(d.bigQueryStore.Dialect()), + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, d.bigQueryStore.Dialect().QuoteIdentifier(col.Name()), typing.KindToDWHType(kind, d.bigQueryStore.Label(), false)), query) callIdx += 1 } @@ -202,7 +202,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() { assert.NoError(d.T(), alterTableArgs.AlterTable(column)) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.EscapedName(d.bigQueryStore.Dialect()), + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, d.bigQueryStore.Dialect().QuoteIdentifier(column.Name()), typing.KindToDWHType(column.KindDetails, d.bigQueryStore.Label(), false)), query) callIdx += 1 } diff --git a/lib/destination/ddl/ddl_sflk_test.go b/lib/destination/ddl/ddl_sflk_test.go index 352efba2a..deae9dfd8 100644 --- a/lib/destination/ddl/ddl_sflk_test.go +++ b/lib/destination/ddl/ddl_sflk_test.go @@ -42,7 +42,7 @@ func (d *DDLTestSuite) TestAlterComplexObjects() { for i := 0; i < len(cols); i++ { execQuery, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i) assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s add COLUMN %s %s", `shop.public."COMPLEX_COLUMNS"`, - cols[i].EscapedName(d.snowflakeStagesStore.Dialect()), + d.snowflakeStagesStore.Dialect().QuoteIdentifier(cols[i].Name()), typing.KindToDWHType(cols[i].KindDetails, d.snowflakeStagesStore.Label(), false)), execQuery) } @@ -172,7 +172,8 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() { execArg, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i) assert.Equal(d.T(), execArg, fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", `shop.public."USERS"`, constants.Delete, - cols[i].EscapedName(d.snowflakeStagesStore.Dialect()))) + d.snowflakeStagesStore.Dialect().QuoteIdentifier(cols[i].Name()), + )) } } diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index e40f1d890..2c5b67eb2 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -101,7 +101,8 @@ func (m *MergeArgument) GetParts() ([]string, error) { var equalitySQLParts []string for _, primaryKey := range m.PrimaryKeys { // We'll need to escape the primary key as well. - equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect)) + quotedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.Name()) + equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey) equalitySQLParts = append(equalitySQLParts, equalitySQL) } @@ -122,7 +123,8 @@ func (m *MergeArgument) GetParts() ([]string, error) { // LEFT JOIN table on pk(s) m.TableID.FullyQualifiedName(), strings.Join(equalitySQLParts, " and "), // Where PK is NULL (we only need to specify one primary key since it's covered with equalitySQL parts) - m.PrimaryKeys[0].EscapedName(m.Dialect)), + m.Dialect.QuoteIdentifier(m.PrimaryKeys[0].Name()), + ), // UPDATE fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s;`, // UPDATE table set col1 = cc. col1 @@ -142,7 +144,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { var pks []string for _, pk := range m.PrimaryKeys { - pks = append(pks, pk.EscapedName(m.Dialect)) + pks = append(pks, m.Dialect.QuoteIdentifier(pk.Name())) } parts := []string{ @@ -159,7 +161,8 @@ func (m *MergeArgument) GetParts() ([]string, error) { // LEFT JOIN table on pk(s) m.TableID.FullyQualifiedName(), strings.Join(equalitySQLParts, " and "), // Where PK is NULL (we only need to specify one primary key since it's covered with equalitySQL parts) - m.PrimaryKeys[0].EscapedName(m.Dialect)), + m.Dialect.QuoteIdentifier(m.PrimaryKeys[0].Name()), + ), // UPDATE fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s AND COALESCE(cc.%s, false) = false;`, // UPDATE table set col1 = cc. col1 @@ -207,7 +210,9 @@ func (m *MergeArgument) GetStatement() (string, error) { var equalitySQLParts []string for _, primaryKey := range m.PrimaryKeys { // We'll need to escape the primary key as well. - equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect)) + quotedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.Name()) + + equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey) pkCol, isOk := m.Columns.GetColumn(primaryKey.Name()) if !isOk { return "", fmt.Errorf("column: %s does not exist in columnToType: %v", primaryKey.Name(), m.Columns) @@ -215,7 +220,7 @@ func (m *MergeArgument) GetStatement() (string, error) { if m.DestKind == constants.BigQuery && pkCol.KindDetails.Kind == typing.Struct.Kind { // BigQuery requires special casting to compare two JSON objects. - equalitySQL = fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect)) + equalitySQL = fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", quotedPrimaryKey, quotedPrimaryKey) } equalitySQLParts = append(equalitySQLParts, equalitySQL) @@ -288,7 +293,8 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) { var equalitySQLParts []string for _, primaryKey := range m.PrimaryKeys { // We'll need to escape the primary key as well. - equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect)) + quotedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.Name()) + equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey) equalitySQLParts = append(equalitySQLParts, equalitySQL) } diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index c7c71053a..19ffef0c3 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -83,10 +83,6 @@ func (c *Column) Name() string { return c.name } -func (c *Column) EscapedName(dialect sql.Dialect) string { - return dialect.QuoteIdentifier(c.name) -} - type Columns struct { columns []Column sync.RWMutex @@ -244,7 +240,7 @@ func (c *Columns) UpdateQuery(dialect sql.Dialect, skipDeleteCol bool) string { continue } - colName := column.EscapedName(dialect) + colName := dialect.QuoteIdentifier(column.Name()) if column.ToastColumn { if column.KindDetails == typing.Struct { cols = append(cols, processToastStructCol(colName, dialect)) diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index 525ee12e3..3d7b0e58e 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -163,6 +163,8 @@ func TestColumn_Name(t *testing.T) { }, } + snowflakeDialect := sql.SnowflakeDialect{} + bqDialect := sql.BigQueryDialect{} for _, testCase := range testCases { col := &Column{ name: testCase.colName, @@ -170,8 +172,8 @@ func TestColumn_Name(t *testing.T) { assert.Equal(t, testCase.expectedName, col.Name(), testCase.colName) - assert.Equal(t, testCase.expectedNameEsc, col.EscapedName(sql.SnowflakeDialect{}), testCase.colName) - assert.Equal(t, testCase.expectedNameEscBq, col.EscapedName(sql.BigQueryDialect{}), testCase.colName) + assert.Equal(t, testCase.expectedNameEsc, snowflakeDialect.QuoteIdentifier(col.Name()), testCase.colName) + assert.Equal(t, testCase.expectedNameEscBq, bqDialect.QuoteIdentifier(col.Name()), testCase.colName) } }