Skip to content

Commit

Permalink
Move QuoteColumns function to columns (#629)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored May 13, 2024
1 parent 6098330 commit cdc050d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 31 deletions.
8 changes: 0 additions & 8 deletions lib/destination/dml/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,6 @@ import (
"github.com/artie-labs/transfer/lib/typing/columns"
)

func quoteColumns(cols []columns.Column, dialect sql.Dialect) []string {
result := make([]string, len(cols))
for i, col := range cols {
result[i] = dialect.QuoteIdentifier(col.Name())
}
return result
}

// buildColumnsUpdateFragment will parse the columns and then returns a list of strings like: cc.first_name=c.first_name,cc.last_name=c.last_name,cc.email=c.email
// NOTE: This should only be used with valid columns.
func buildColumnsUpdateFragment(columns []columns.Column, dialect sql.Dialect) string {
Expand Down
8 changes: 4 additions & 4 deletions lib/destination/dml/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ import (
)

func TestQuoteColumns(t *testing.T) {
assert.Equal(t, []string{}, quoteColumns(nil, bigQueryDialect.BigQueryDialect{}))
assert.Equal(t, []string{}, quoteColumns(nil, snowflakeDialect.SnowflakeDialect{}))
assert.Equal(t, []string{}, columns.QuoteColumns(nil, bigQueryDialect.BigQueryDialect{}))
assert.Equal(t, []string{}, columns.QuoteColumns(nil, snowflakeDialect.SnowflakeDialect{}))

cols := []columns.Column{columns.NewColumn("a", typing.Invalid), columns.NewColumn("b", typing.Invalid)}
assert.Equal(t, []string{"`a`", "`b`"}, quoteColumns(cols, bigQueryDialect.BigQueryDialect{}))
assert.Equal(t, []string{`"A"`, `"B"`}, quoteColumns(cols, snowflakeDialect.SnowflakeDialect{}))
assert.Equal(t, []string{"`a`", "`b`"}, columns.QuoteColumns(cols, bigQueryDialect.BigQueryDialect{}))
assert.Equal(t, []string{`"A"`, `"B"`}, columns.QuoteColumns(cols, snowflakeDialect.SnowflakeDialect{}))
}

func TestBuildColumnsUpdateFragment(t *testing.T) {
Expand Down
38 changes: 19 additions & 19 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ func (m *MergeArgument) redshiftEqualitySQLParts() []string {
return equalitySQLParts
}

func (m *MergeArgument) buildRedshiftInsertQuery(columns []columns.Column) string {
func (m *MergeArgument) buildRedshiftInsertQuery(cols []columns.Column) string {
return fmt.Sprintf(`INSERT INTO %s (%s) SELECT %s FROM %s AS cc LEFT JOIN %s AS c ON %s WHERE c.%s IS NULL;`,
// insert into target (col1, col2, col3)
m.TableID.FullyQualifiedName(), strings.Join(quoteColumns(columns, m.Dialect), ","),
m.TableID.FullyQualifiedName(), strings.Join(columns.QuoteColumns(cols, m.Dialect), ","),
// SELECT cc.col1, cc.col2, ... FROM staging as CC
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: quoteColumns(columns, m.Dialect),
Vals: columns.QuoteColumns(cols, m.Dialect),
Separator: ",",
Prefix: "cc.",
}), m.SubQuery,
Expand All @@ -95,7 +95,7 @@ func (m *MergeArgument) buildRedshiftInsertQuery(columns []columns.Column) strin
)
}

func (m *MergeArgument) buildRedshiftUpdateQuery(columns []columns.Column) string {
func (m *MergeArgument) buildRedshiftUpdateQuery(cols []columns.Column) string {
clauses := m.redshiftEqualitySQLParts()

if m.IdempotentKey != "" {
Expand All @@ -108,7 +108,7 @@ func (m *MergeArgument) buildRedshiftUpdateQuery(columns []columns.Column) strin

return fmt.Sprintf(`UPDATE %s AS c SET %s FROM %s AS cc WHERE %s;`,
// UPDATE table set col1 = cc. col1
m.TableID.FullyQualifiedName(), buildColumnsUpdateFragment(columns, m.Dialect),
m.TableID.FullyQualifiedName(), buildColumnsUpdateFragment(cols, m.Dialect),
// FROM staging WHERE join on PK(s)
m.SubQuery, strings.Join(clauses, " AND "),
)
Expand All @@ -117,10 +117,10 @@ func (m *MergeArgument) buildRedshiftUpdateQuery(columns []columns.Column) strin
func (m *MergeArgument) buildRedshiftDeleteQuery() string {
return fmt.Sprintf(`DELETE FROM %s WHERE (%s) IN (SELECT %s FROM %s AS cc WHERE cc.%s = true);`,
// DELETE from table where (pk_1, pk_2)
m.TableID.FullyQualifiedName(), strings.Join(quoteColumns(m.PrimaryKeys, m.Dialect), ","),
m.TableID.FullyQualifiedName(), strings.Join(columns.QuoteColumns(m.PrimaryKeys, m.Dialect), ","),
// IN (cc.pk_1, cc.pk_2) FROM staging
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: quoteColumns(m.PrimaryKeys, m.Dialect),
Vals: columns.QuoteColumns(m.PrimaryKeys, m.Dialect),
Separator: ",",
Prefix: "cc.",
}), m.SubQuery, m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
Expand Down Expand Up @@ -211,16 +211,16 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
// Update + Soft Deletion
idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(quoteColumns(m.Columns, m.Dialect), ","),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(m.Columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: quoteColumns(m.Columns, m.Dialect),
Vals: columns.QuoteColumns(m.Columns, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
}

// We also need to remove __artie flags since it does not exist in the destination table
columns, removed := columns.RemoveDeleteColumnMarker(m.Columns)
cols, removed := columns.RemoveDeleteColumnMarker(m.Columns)
if !removed {
return "", errors.New("artie delete flag doesn't exist")
}
Expand All @@ -234,11 +234,11 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
// Delete
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
// Update
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, buildColumnsUpdateFragment(columns, m.Dialect),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, buildColumnsUpdateFragment(cols, m.Dialect),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(quoteColumns(columns, m.Dialect), ","),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(cols, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: quoteColumns(columns, m.Dialect),
Vals: columns.QuoteColumns(cols, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
Expand Down Expand Up @@ -268,16 +268,16 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`,
// Update + Soft Deletion
idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(quoteColumns(m.Columns, m.Dialect), ","),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(m.Columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: quoteColumns(m.Columns, m.Dialect),
Vals: columns.QuoteColumns(m.Columns, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
}

// We also need to remove __artie flags since it does not exist in the destination table
columns, removed := columns.RemoveDeleteColumnMarker(m.Columns)
cols, removed := columns.RemoveDeleteColumnMarker(m.Columns)
if !removed {
return "", errors.New("artie delete flag doesn't exist")
}
Expand All @@ -292,11 +292,11 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`,
// Delete
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
// Update
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, buildColumnsUpdateFragment(columns, m.Dialect),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, buildColumnsUpdateFragment(cols, m.Dialect),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(quoteColumns(columns, m.Dialect), ","),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(cols, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: quoteColumns(columns, m.Dialect),
Vals: columns.QuoteColumns(cols, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
Expand Down
9 changes: 9 additions & 0 deletions lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
)
Expand Down Expand Up @@ -248,6 +249,14 @@ func (c *Columns) DeleteColumn(name string) {
}
}

func QuoteColumns(cols []Column, dialect sql.Dialect) []string {
result := make([]string, len(cols))
for i, col := range cols {
result[i] = dialect.QuoteIdentifier(col.Name())
}
return result
}

// RemoveDeleteColumnMarker removes the deleted column marker from a slice (if present) returning a new slice and whether or not it was removed.
func RemoveDeleteColumnMarker(cols []Column) ([]Column, bool) {
origLength := len(cols)
Expand Down

0 comments on commit cdc050d

Please sign in to comment.