Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] Rename Column.Name #546

Merged
merged 4 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clients/shared/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col
return fmt.Errorf("failed to escape default value: %w", err)
}

escapedCol := column.Name(dwh.Dialect())
escapedCol := column.EscapedName(dwh.Dialect())

// TODO: This is added because `default` is not technically a column that requires escaping, but it is required when it's in the where clause.
// Once we escape everything by default, we can remove this patch of code.
Expand Down
4 changes: 2 additions & 2 deletions lib/destination/ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error {
mutateCol = append(mutateCol, col)
switch a.ColumnOp {
case constants.Add:
colName := col.Name(a.Dwh.Dialect())
colName := col.EscapedName(a.Dwh.Dialect())

if col.PrimaryKey() && a.Mode != config.History {
// Don't create a PK for history mode because it's append-only, so the primary key should not be enforced.
Expand All @@ -108,7 +108,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error {

colSQLParts = append(colSQLParts, fmt.Sprintf(`%s %s`, colName, typing.KindToDWHType(col.KindDetails, a.Dwh.Label(), col.PrimaryKey())))
case constants.Delete:
colSQLParts = append(colSQLParts, col.Name(a.Dwh.Dialect()))
colSQLParts = append(colSQLParts, col.EscapedName(a.Dwh.Dialect()))
}
}

Expand Down
6 changes: 3 additions & 3 deletions lib/destination/ddl/ddl_bq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() {

assert.NoError(d.T(), alterTableArgs.AlterTable(column))
query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.Name(d.bigQueryStore.Dialect())), query)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.EscapedName(d.bigQueryStore.Dialect())), query)
callIdx += 1
}

Expand Down Expand Up @@ -143,7 +143,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() {

assert.NoError(d.T(), alterTableArgs.AlterTable(col))
query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.Name(d.bigQueryStore.Dialect()),
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.EscapedName(d.bigQueryStore.Dialect()),
typing.KindToDWHType(kind, d.bigQueryStore.Label(), false)), query)
callIdx += 1
}
Expand Down Expand Up @@ -202,7 +202,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() {

assert.NoError(d.T(), alterTableArgs.AlterTable(column))
query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.Name(d.bigQueryStore.Dialect()),
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.EscapedName(d.bigQueryStore.Dialect()),
typing.KindToDWHType(column.KindDetails, d.bigQueryStore.Label(), false)), query)
callIdx += 1
}
Expand Down
4 changes: 2 additions & 2 deletions lib/destination/ddl/ddl_sflk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (d *DDLTestSuite) TestAlterComplexObjects() {
for i := 0; i < len(cols); i++ {
execQuery, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i)
assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s add COLUMN %s %s", `shop.public."COMPLEX_COLUMNS"`,
cols[i].Name(d.snowflakeStagesStore.Dialect()),
cols[i].EscapedName(d.snowflakeStagesStore.Dialect()),
typing.KindToDWHType(cols[i].KindDetails, d.snowflakeStagesStore.Label(), false)), execQuery)
}

Expand Down Expand Up @@ -172,7 +172,7 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() {

execArg, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i)
assert.Equal(d.T(), execArg, fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", `shop.public."USERS"`, constants.Delete,
cols[i].Name(d.snowflakeStagesStore.Dialect())))
cols[i].EscapedName(d.snowflakeStagesStore.Dialect())))
}
}

Expand Down
14 changes: 7 additions & 7 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
var equalitySQLParts []string
for _, primaryKey := range m.PrimaryKeys {
// We'll need to escape the primary key as well.
equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.Name(m.Dialect), primaryKey.Name(m.Dialect))
equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect))
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

Expand All @@ -122,7 +122,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// LEFT JOIN table on pk(s)
m.TableID.FullyQualifiedName(), strings.Join(equalitySQLParts, " and "),
// Where PK is NULL (we only need to specify one primary key since it's covered with equalitySQL parts)
m.PrimaryKeys[0].Name(m.Dialect)),
m.PrimaryKeys[0].EscapedName(m.Dialect)),
// UPDATE
fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s;`,
// UPDATE table set col1 = cc. col1
Expand All @@ -142,7 +142,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {

var pks []string
for _, pk := range m.PrimaryKeys {
pks = append(pks, pk.Name(m.Dialect))
pks = append(pks, pk.EscapedName(m.Dialect))
}

parts := []string{
Expand All @@ -159,7 +159,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// LEFT JOIN table on pk(s)
m.TableID.FullyQualifiedName(), strings.Join(equalitySQLParts, " and "),
// Where PK is NULL (we only need to specify one primary key since it's covered with equalitySQL parts)
m.PrimaryKeys[0].Name(m.Dialect)),
m.PrimaryKeys[0].EscapedName(m.Dialect)),
// UPDATE
fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s AND COALESCE(cc.%s, false) = false;`,
// UPDATE table set col1 = cc. col1
Expand Down Expand Up @@ -207,15 +207,15 @@ func (m *MergeArgument) GetStatement() (string, error) {
var equalitySQLParts []string
for _, primaryKey := range m.PrimaryKeys {
// We'll need to escape the primary key as well.
equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.Name(m.Dialect), primaryKey.Name(m.Dialect))
equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect))
pkCol, isOk := m.Columns.GetColumn(primaryKey.RawName())
if !isOk {
return "", fmt.Errorf("column: %s does not exist in columnToType: %v", primaryKey.RawName(), m.Columns)
}

if m.DestKind == constants.BigQuery && pkCol.KindDetails.Kind == typing.Struct.Kind {
// BigQuery requires special casting to compare two JSON objects.
equalitySQL = fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", primaryKey.Name(m.Dialect), primaryKey.Name(m.Dialect))
equalitySQL = fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect))
}

equalitySQLParts = append(equalitySQLParts, equalitySQL)
Expand Down Expand Up @@ -288,7 +288,7 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) {
var equalitySQLParts []string
for _, primaryKey := range m.PrimaryKeys {
// We'll need to escape the primary key as well.
equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.Name(m.Dialect), primaryKey.Name(m.Dialect))
equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey.EscapedName(m.Dialect), primaryKey.EscapedName(m.Dialect))
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

Expand Down
5 changes: 2 additions & 3 deletions lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ func (c *Column) RawName() string {
return c.name
}

// Name will give you c.name and escape it if necessary.
func (c *Column) Name(dialect sql.Dialect) string {
func (c *Column) EscapedName(dialect sql.Dialect) string {
return dialect.QuoteIdentifier(c.name)
}

Expand Down Expand Up @@ -245,7 +244,7 @@ func (c *Columns) UpdateQuery(dialect sql.Dialect, skipDeleteCol bool) string {
continue
}

colName := column.Name(dialect)
colName := column.EscapedName(dialect)
if column.ToastColumn {
if column.KindDetails == typing.Struct {
cols = append(cols, processToastStructCol(colName, dialect))
Expand Down
4 changes: 2 additions & 2 deletions lib/typing/columns/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ func TestColumn_Name(t *testing.T) {

assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName)

assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{}), testCase.colName)
assert.Equal(t, testCase.expectedNameEscBq, col.Name(sql.BigQueryDialect{}), testCase.colName)
assert.Equal(t, testCase.expectedNameEsc, col.EscapedName(sql.SnowflakeDialect{}), testCase.colName)
assert.Equal(t, testCase.expectedNameEscBq, col.EscapedName(sql.BigQueryDialect{}), testCase.colName)
}
}

Expand Down