Skip to content

Commit

Permalink
Inline
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed May 2, 2024
1 parent 11a602e commit 2f5638e
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 22 deletions.
2 changes: 1 addition & 1 deletion clients/shared/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col
return fmt.Errorf("failed to escape default value: %w", err)
}

escapedCol := column.EscapedName(dwh.Dialect())
escapedCol := dwh.Dialect().QuoteIdentifier(column.Name())

// TODO: This is added because `default` is not technically a column that requires escaping, but it is required when it's in the where clause.
// Once we escape everything by default, we can remove this patch of code.
Expand Down
4 changes: 2 additions & 2 deletions lib/destination/ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error {
mutateCol = append(mutateCol, col)
switch a.ColumnOp {
case constants.Add:
colName := col.EscapedName(a.Dwh.Dialect())
colName := a.Dwh.Dialect().QuoteIdentifier(col.Name())

if col.PrimaryKey() && a.Mode != config.History {
// Don't create a PK for history mode because it's append-only, so the primary key should not be enforced.
Expand All @@ -108,7 +108,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error {

colSQLParts = append(colSQLParts, fmt.Sprintf(`%s %s`, colName, typing.KindToDWHType(col.KindDetails, a.Dwh.Label(), col.PrimaryKey())))
case constants.Delete:
colSQLParts = append(colSQLParts, col.EscapedName(a.Dwh.Dialect()))
colSQLParts = append(colSQLParts, a.Dwh.Dialect().QuoteIdentifier(col.Name()))
}
}

Expand Down
6 changes: 3 additions & 3 deletions lib/destination/ddl/ddl_bq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() {

assert.NoError(d.T(), alterTableArgs.AlterTable(column))
query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.EscapedName(d.bigQueryStore.Dialect())), query)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, d.bigQueryStore.Dialect().QuoteIdentifier(column.Name())), query)
callIdx += 1
}

Expand Down Expand Up @@ -143,7 +143,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() {

assert.NoError(d.T(), alterTableArgs.AlterTable(col))
query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.EscapedName(d.bigQueryStore.Dialect()),
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, d.bigQueryStore.Dialect().QuoteIdentifier(col.Name()),
typing.KindToDWHType(kind, d.bigQueryStore.Label(), false)), query)
callIdx += 1
}
Expand Down Expand Up @@ -202,7 +202,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() {

assert.NoError(d.T(), alterTableArgs.AlterTable(column))
query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.EscapedName(d.bigQueryStore.Dialect()),
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, d.bigQueryStore.Dialect().QuoteIdentifier(column.Name()),
typing.KindToDWHType(column.KindDetails, d.bigQueryStore.Label(), false)), query)
callIdx += 1
}
Expand Down
5 changes: 3 additions & 2 deletions lib/destination/ddl/ddl_sflk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (d *DDLTestSuite) TestAlterComplexObjects() {
for i := 0; i < len(cols); i++ {
execQuery, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s add COLUMN %s %s", `shop.public."COMPLEX_COLUMNS"`,
cols[i].EscapedName(d.snowflakeStagesStore.Dialect()),
d.snowflakeStagesStore.Dialect().QuoteIdentifier(cols[i].Name()),
typing.KindToDWHType(cols[i].KindDetails, d.snowflakeStagesStore.Label(), false)), execQuery)
}

Expand Down Expand Up @@ -172,7 +172,8 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() {

execArg, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i)
assert.Equal(d.T(), execArg, fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", `shop.public."USERS"`, constants.Delete,
cols[i].EscapedName(d.snowflakeStagesStore.Dialect())))
d.snowflakeStagesStore.Dialect().QuoteIdentifier(cols[i].Name()),
))
}
}

Expand Down
20 changes: 13 additions & 7 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ func (m *MergeArgument) GetParts() ([]string, error) {
var equalitySQLParts []string
for _, primaryKey := range m.PrimaryKeys {
// We'll need to escape the primary key as well.
equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect))
quotedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.Name())
equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey)
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

Expand All @@ -122,7 +123,8 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// LEFT JOIN table on pk(s)
m.TableID.FullyQualifiedName(), strings.Join(equalitySQLParts, " and "),
// Where PK is NULL (we only need to specify one primary key since it's covered with equalitySQL parts)
m.PrimaryKeys[0].EscapedName(m.Dialect)),
m.Dialect.QuoteIdentifier(m.PrimaryKeys[0].Name()),
),
// UPDATE
fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s;`,
// UPDATE table set col1 = cc. col1
Expand All @@ -142,7 +144,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {

var pks []string
for _, pk := range m.PrimaryKeys {
pks = append(pks, pk.EscapedName(m.Dialect))
pks = append(pks, m.Dialect.QuoteIdentifier(pk.Name()))
}

parts := []string{
Expand All @@ -159,7 +161,8 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// LEFT JOIN table on pk(s)
m.TableID.FullyQualifiedName(), strings.Join(equalitySQLParts, " and "),
// Where PK is NULL (we only need to specify one primary key since it's covered with equalitySQL parts)
m.PrimaryKeys[0].EscapedName(m.Dialect)),
m.Dialect.QuoteIdentifier(m.PrimaryKeys[0].Name()),
),
// UPDATE
fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s AND COALESCE(cc.%s, false) = false;`,
// UPDATE table set col1 = cc. col1
Expand Down Expand Up @@ -207,15 +210,17 @@ func (m *MergeArgument) GetStatement() (string, error) {
var equalitySQLParts []string
for _, primaryKey := range m.PrimaryKeys {
// We'll need to escape the primary key as well.
equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect))
quotedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.Name())

equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey)
pkCol, isOk := m.Columns.GetColumn(primaryKey.Name())
if !isOk {
return "", fmt.Errorf("column: %s does not exist in columnToType: %v", primaryKey.Name(), m.Columns)
}

if m.DestKind == constants.BigQuery && pkCol.KindDetails.Kind == typing.Struct.Kind {
// BigQuery requires special casting to compare two JSON objects.
equalitySQL = fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect))
equalitySQL = fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", quotedPrimaryKey, quotedPrimaryKey)
}

equalitySQLParts = append(equalitySQLParts, equalitySQL)
Expand Down Expand Up @@ -288,7 +293,8 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) {
var equalitySQLParts []string
for _, primaryKey := range m.PrimaryKeys {
// We'll need to escape the primary key as well.
equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect))
quotedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.Name())
equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey)
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

Expand Down
6 changes: 1 addition & 5 deletions lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ func (c *Column) Name() string {
return c.name
}

func (c *Column) EscapedName(dialect sql.Dialect) string {
return dialect.QuoteIdentifier(c.name)
}

type Columns struct {
columns []Column
sync.RWMutex
Expand Down Expand Up @@ -244,7 +240,7 @@ func (c *Columns) UpdateQuery(dialect sql.Dialect, skipDeleteCol bool) string {
continue
}

colName := column.EscapedName(dialect)
colName := dialect.QuoteIdentifier(column.Name())
if column.ToastColumn {
if column.KindDetails == typing.Struct {
cols = append(cols, processToastStructCol(colName, dialect))
Expand Down
6 changes: 4 additions & 2 deletions lib/typing/columns/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,17 @@ func TestColumn_Name(t *testing.T) {
},
}

snowflakeDialect := sql.SnowflakeDialect{}
bqDialect := sql.BigQueryDialect{}
for _, testCase := range testCases {
col := &Column{
name: testCase.colName,
}

assert.Equal(t, testCase.expectedName, col.Name(), testCase.colName)

assert.Equal(t, testCase.expectedNameEsc, col.EscapedName(sql.SnowflakeDialect{}), testCase.colName)
assert.Equal(t, testCase.expectedNameEscBq, col.EscapedName(sql.BigQueryDialect{}), testCase.colName)
assert.Equal(t, testCase.expectedNameEsc, snowflakeDialect.QuoteIdentifier(col.Name()), testCase.colName)
assert.Equal(t, testCase.expectedNameEscBq, bqDialect.QuoteIdentifier(col.Name()), testCase.colName)
}
}

Expand Down

0 comments on commit 2f5638e

Please sign in to comment.