Skip to content

Commit

Permalink
Move typing.UpdateQuery to dml (#550)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored May 3, 2024
1 parent 098fce8 commit 6f690e5
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 206 deletions.
72 changes: 72 additions & 0 deletions lib/destination/dml/columns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package dml

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)

// buildColumnsUpdateFragment will parse the columns and then returns a list of strings like: cc.first_name=c.first_name,cc.last_name=c.last_name,cc.email=c.email
func buildColumnsUpdateFragment(c *columns.Columns, dialect sql.Dialect, skipDeleteCol bool) string {
var cols []string
for _, column := range c.GetColumns() {
if column.ShouldSkip() {
continue
}

// skipDeleteCol is useful because we don't want to copy the deleted column over to the source table if we're doing a hard row delete.
if skipDeleteCol && column.Name() == constants.DeleteColumnMarker {
continue
}

colName := dialect.QuoteIdentifier(column.Name())
if column.ToastColumn {
if column.KindDetails == typing.Struct {
cols = append(cols, processToastStructCol(colName, dialect))
} else {
cols = append(cols, processToastCol(colName, dialect))
}

} else {
// This is to make it look like: objCol = cc.objCol
cols = append(cols, fmt.Sprintf("%s=cc.%s", colName, colName))
}
}

return strings.Join(cols, ",")
}

func processToastStructCol(colName string, dialect sql.Dialect) string {
switch dialect.(type) {
case sql.BigQueryDialect:
return fmt.Sprintf(`%s= CASE WHEN COALESCE(TO_JSON_STRING(cc.%s) != '{"key":"%s"}', true) THEN cc.%s ELSE c.%s END`,
colName, colName, constants.ToastUnavailableValuePlaceholder,
colName, colName)
case sql.RedshiftDialect:
return fmt.Sprintf(`%s= CASE WHEN COALESCE(cc.%s != JSON_PARSE('{"key":"%s"}'), true) THEN cc.%s ELSE c.%s END`,
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
case sql.MSSQLDialect:
// Microsoft SQL Server doesn't allow boolean expressions to be in the COALESCE statement.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s, {}) != {'key': '%s'} THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
default:
// TODO: Change this to Snowflake and error out if the destKind isn't supported so we're explicit.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s != {'key': '%s'}, true) THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
}
}

func processToastCol(colName string, dialect sql.Dialect) string {
if _, ok := dialect.(sql.MSSQLDialect); ok {
// Microsoft SQL Server doesn't allow boolean expressions to be in the COALESCE statement.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s, '') != '%s' THEN cc.%s ELSE c.%s END", colName, colName,
constants.ToastUnavailableValuePlaceholder, colName, colName)
} else {
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s != '%s', true) THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
}
}
134 changes: 134 additions & 0 deletions lib/destination/dml/columns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package dml

import (
"fmt"
"testing"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/stretchr/testify/assert"
)

func TestBuildColumnsUpdateFragment(t *testing.T) {
type testCase struct {
name string
columns columns.Columns
expectedString string
dialect sql.Dialect
skipDeleteCol bool
}

fooBarCols := []string{"foo", "bar"}

var (
happyPathCols columns.Columns
stringAndToastCols columns.Columns
lastCaseColTypes columns.Columns
lastCaseEscapeTypes columns.Columns
)
for _, col := range fooBarCols {
column := columns.NewColumn(col, typing.String)
column.ToastColumn = false
happyPathCols.AddColumn(column)
}
for _, col := range fooBarCols {
var toastCol bool
if col == "foo" {
toastCol = true
}

column := columns.NewColumn(col, typing.String)
column.ToastColumn = toastCol
stringAndToastCols.AddColumn(column)
}

lastCaseCols := []string{"a1", "b2", "c3"}
for _, lastCaseCol := range lastCaseCols {
kd := typing.String
var toast bool
// a1 - struct + toast, b2 - string + toast, c3 = regular string.
if lastCaseCol == "a1" {
kd = typing.Struct
toast = true
} else if lastCaseCol == "b2" {
toast = true
}

column := columns.NewColumn(lastCaseCol, kd)
column.ToastColumn = toast
lastCaseColTypes.AddColumn(column)
}

lastCaseColsEsc := []string{"a1", "b2", "c3", "start", "select"}
for _, lastCaseColEsc := range lastCaseColsEsc {
kd := typing.String
var toast bool
// a1 - struct + toast, b2 - string + toast, c3 = regular string.
if lastCaseColEsc == "a1" {
kd = typing.Struct
toast = true
} else if lastCaseColEsc == "b2" {
toast = true
} else if lastCaseColEsc == "start" {
kd = typing.Struct
toast = true
}

column := columns.NewColumn(lastCaseColEsc, kd)
column.ToastColumn = toast
lastCaseEscapeTypes.AddColumn(column)
}

lastCaseEscapeTypes.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

key := `{"key":"__debezium_unavailable_value"}`
testCases := []testCase{
{
name: "happy path",
columns: happyPathCols,
dialect: sql.RedshiftDialect{},
expectedString: `"foo"=cc."foo","bar"=cc."bar"`,
},
{
name: "string and toast",
columns: stringAndToastCols,
dialect: sql.SnowflakeDialect{},
expectedString: `"FOO"= CASE WHEN COALESCE(cc."FOO" != '__debezium_unavailable_value', true) THEN cc."FOO" ELSE c."FOO" END,"BAR"=cc."BAR"`,
},
{
name: "struct, string and toast string",
columns: lastCaseColTypes,
dialect: sql.RedshiftDialect{},
expectedString: `"a1"= CASE WHEN COALESCE(cc."a1" != JSON_PARSE('{"key":"__debezium_unavailable_value"}'), true) THEN cc."a1" ELSE c."a1" END,"b2"= CASE WHEN COALESCE(cc."b2" != '__debezium_unavailable_value', true) THEN cc."b2" ELSE c."b2" END,"c3"=cc."c3"`,
},
{
name: "struct, string and toast string (bigquery)",
columns: lastCaseColTypes,
dialect: sql.BigQueryDialect{},
expectedString: "`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '{\"key\":\"__debezium_unavailable_value\"}', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`",
},
{
name: "struct, string and toast string (bigquery) w/ reserved keywords",
columns: lastCaseEscapeTypes,
dialect: sql.BigQueryDialect{},
expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s",
key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`"),
skipDeleteCol: true,
},
{
name: "struct, string and toast string (bigquery) w/ reserved keywords",
columns: lastCaseEscapeTypes,
dialect: sql.BigQueryDialect{},
expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s",
key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`,`__artie_delete`=cc.`__artie_delete`"),
skipDeleteCol: false,
},
}

for _, _testCase := range testCases {
actualQuery := buildColumnsUpdateFragment(&_testCase.columns, _testCase.dialect, _testCase.skipDeleteCol)
assert.Equal(t, _testCase.expectedString, actualQuery, _testCase.name)
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package columns
package dml

import (
"testing"
Expand Down
12 changes: 6 additions & 6 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// UPDATE
fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s;`,
// UPDATE table set col1 = cc. col1
m.TableID.FullyQualifiedName(), m.Columns.UpdateQuery(m.Dialect, false),
m.TableID.FullyQualifiedName(), buildColumnsUpdateFragment(m.Columns, m.Dialect, false),
// FROM table (temp) WHERE join on PK(s)
m.SubQuery, strings.Join(equalitySQLParts, " and "), idempotentClause,
),
Expand Down Expand Up @@ -166,7 +166,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// UPDATE
fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s AND COALESCE(cc.%s, false) = false;`,
// UPDATE table set col1 = cc. col1
m.TableID.FullyQualifiedName(), m.Columns.UpdateQuery(m.Dialect, true),
m.TableID.FullyQualifiedName(), buildColumnsUpdateFragment(m.Columns, m.Dialect, true),
// FROM staging WHERE join on PK(s)
m.SubQuery, strings.Join(equalitySQLParts, " and "), idempotentClause, m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
),
Expand Down Expand Up @@ -244,7 +244,7 @@ WHEN MATCHED %sTHEN UPDATE SET %s
WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`,
m.TableID.FullyQualifiedName(), subQuery, strings.Join(equalitySQLParts, " and "),
// Update + Soft Deletion
idempotentClause, m.Columns.UpdateQuery(m.Dialect, false),
idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect, false),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Expand All @@ -270,7 +270,7 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
// Delete
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
// Update
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, m.Columns.UpdateQuery(m.Dialect, true),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect, true),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Expand Down Expand Up @@ -308,7 +308,7 @@ WHEN MATCHED %sTHEN UPDATE SET %s
WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`,
m.TableID.FullyQualifiedName(), m.SubQuery, strings.Join(equalitySQLParts, " and "),
// Update + Soft Deletion
idempotentClause, m.Columns.UpdateQuery(m.Dialect, false),
idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect, false),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Expand All @@ -335,7 +335,7 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`,
// Delete
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
// Update
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, m.Columns.UpdateQuery(m.Dialect, true),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect, true),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Expand Down
64 changes: 0 additions & 64 deletions lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package columns

import (
"fmt"
"strings"
"sync"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
)
Expand Down Expand Up @@ -226,64 +223,3 @@ func (c *Columns) DeleteColumn(name string) {
}
}
}

// UpdateQuery will parse the columns and then returns a list of strings like: cc.first_name=c.first_name,cc.last_name=c.last_name,cc.email=c.email
func (c *Columns) UpdateQuery(dialect sql.Dialect, skipDeleteCol bool) string {
var cols []string
for _, column := range c.GetColumns() {
if column.ShouldSkip() {
continue
}

// skipDeleteCol is useful because we don't want to copy the deleted column over to the source table if we're doing a hard row delete.
if skipDeleteCol && column.Name() == constants.DeleteColumnMarker {
continue
}

colName := dialect.QuoteIdentifier(column.Name())
if column.ToastColumn {
if column.KindDetails == typing.Struct {
cols = append(cols, processToastStructCol(colName, dialect))
} else {
cols = append(cols, processToastCol(colName, dialect))
}

} else {
// This is to make it look like: objCol = cc.objCol
cols = append(cols, fmt.Sprintf("%s=cc.%s", colName, colName))
}
}

return strings.Join(cols, ",")
}

func processToastStructCol(colName string, dialect sql.Dialect) string {
switch dialect.(type) {
case sql.BigQueryDialect:
return fmt.Sprintf(`%s= CASE WHEN COALESCE(TO_JSON_STRING(cc.%s) != '{"key":"%s"}', true) THEN cc.%s ELSE c.%s END`,
colName, colName, constants.ToastUnavailableValuePlaceholder,
colName, colName)
case sql.RedshiftDialect:
return fmt.Sprintf(`%s= CASE WHEN COALESCE(cc.%s != JSON_PARSE('{"key":"%s"}'), true) THEN cc.%s ELSE c.%s END`,
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
case sql.MSSQLDialect:
// Microsoft SQL Server doesn't allow boolean expressions to be in the COALESCE statement.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s, {}) != {'key': '%s'} THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
default:
// TODO: Change this to Snowflake and error out if the destKind isn't supported so we're explicit.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s != {'key': '%s'}, true) THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
}
}

func processToastCol(colName string, dialect sql.Dialect) string {
if _, ok := dialect.(sql.MSSQLDialect); ok {
// Microsoft SQL Server doesn't allow boolean expressions to be in the COALESCE statement.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s, '') != '%s' THEN cc.%s ELSE c.%s END", colName, colName,
constants.ToastUnavailableValuePlaceholder, colName, colName)
} else {
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s != '%s', true) THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
}
}
Loading

0 comments on commit 6f690e5

Please sign in to comment.