Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move typing.UpdateQuery to dml #550

Merged
merged 3 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions lib/destination/dml/columns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package dml

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)

// buildColumnsUpdateFragment will parse the columns and then returns a list of strings like: cc.first_name=c.first_name,cc.last_name=c.last_name,cc.email=c.email
func buildColumnsUpdateFragment(c *columns.Columns, dialect sql.Dialect, skipDeleteCol bool) string {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything below this line was copied directly. The only change was to the function definition.

var cols []string
for _, column := range c.GetColumns() {
if column.ShouldSkip() {
continue
}

// skipDeleteCol is useful because we don't want to copy the deleted column over to the source table if we're doing a hard row delete.
if skipDeleteCol && column.Name() == constants.DeleteColumnMarker {
continue
}

colName := dialect.QuoteIdentifier(column.Name())
if column.ToastColumn {
if column.KindDetails == typing.Struct {
cols = append(cols, processToastStructCol(colName, dialect))
} else {
cols = append(cols, processToastCol(colName, dialect))
}

} else {
// This is to make it look like: objCol = cc.objCol
cols = append(cols, fmt.Sprintf("%s=cc.%s", colName, colName))
}
}

return strings.Join(cols, ",")
}

func processToastStructCol(colName string, dialect sql.Dialect) string {
switch dialect.(type) {
case sql.BigQueryDialect:
return fmt.Sprintf(`%s= CASE WHEN COALESCE(TO_JSON_STRING(cc.%s) != '{"key":"%s"}', true) THEN cc.%s ELSE c.%s END`,
colName, colName, constants.ToastUnavailableValuePlaceholder,
colName, colName)
case sql.RedshiftDialect:
return fmt.Sprintf(`%s= CASE WHEN COALESCE(cc.%s != JSON_PARSE('{"key":"%s"}'), true) THEN cc.%s ELSE c.%s END`,
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
case sql.MSSQLDialect:
// Microsoft SQL Server doesn't allow boolean expressions to be in the COALESCE statement.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s, {}) != {'key': '%s'} THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
default:
// TODO: Change this to Snowflake and error out if the destKind isn't supported so we're explicit.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s != {'key': '%s'}, true) THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
}
}

func processToastCol(colName string, dialect sql.Dialect) string {
if _, ok := dialect.(sql.MSSQLDialect); ok {
// Microsoft SQL Server doesn't allow boolean expressions to be in the COALESCE statement.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s, '') != '%s' THEN cc.%s ELSE c.%s END", colName, colName,
constants.ToastUnavailableValuePlaceholder, colName, colName)
} else {
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s != '%s', true) THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
}
}
134 changes: 134 additions & 0 deletions lib/destination/dml/columns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package dml

import (
"fmt"
"testing"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/stretchr/testify/assert"
)

func TestBuildColumnsUpdateFragment(t *testing.T) {
type testCase struct {
name string
columns columns.Columns
expectedString string
dialect sql.Dialect
skipDeleteCol bool
}

fooBarCols := []string{"foo", "bar"}

var (
happyPathCols columns.Columns
stringAndToastCols columns.Columns
lastCaseColTypes columns.Columns
lastCaseEscapeTypes columns.Columns
)
for _, col := range fooBarCols {
column := columns.NewColumn(col, typing.String)
column.ToastColumn = false
happyPathCols.AddColumn(column)
}
for _, col := range fooBarCols {
var toastCol bool
if col == "foo" {
toastCol = true
}

column := columns.NewColumn(col, typing.String)
column.ToastColumn = toastCol
stringAndToastCols.AddColumn(column)
}

lastCaseCols := []string{"a1", "b2", "c3"}
for _, lastCaseCol := range lastCaseCols {
kd := typing.String
var toast bool
// a1 - struct + toast, b2 - string + toast, c3 = regular string.
if lastCaseCol == "a1" {
kd = typing.Struct
toast = true
} else if lastCaseCol == "b2" {
toast = true
}

column := columns.NewColumn(lastCaseCol, kd)
column.ToastColumn = toast
lastCaseColTypes.AddColumn(column)
}

lastCaseColsEsc := []string{"a1", "b2", "c3", "start", "select"}
for _, lastCaseColEsc := range lastCaseColsEsc {
kd := typing.String
var toast bool
// a1 - struct + toast, b2 - string + toast, c3 = regular string.
if lastCaseColEsc == "a1" {
kd = typing.Struct
toast = true
} else if lastCaseColEsc == "b2" {
toast = true
} else if lastCaseColEsc == "start" {
kd = typing.Struct
toast = true
}

column := columns.NewColumn(lastCaseColEsc, kd)
column.ToastColumn = toast
lastCaseEscapeTypes.AddColumn(column)
}

lastCaseEscapeTypes.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

key := `{"key":"__debezium_unavailable_value"}`
testCases := []testCase{
{
name: "happy path",
columns: happyPathCols,
dialect: sql.RedshiftDialect{},
expectedString: `"foo"=cc."foo","bar"=cc."bar"`,
},
{
name: "string and toast",
columns: stringAndToastCols,
dialect: sql.SnowflakeDialect{},
expectedString: `"FOO"= CASE WHEN COALESCE(cc."FOO" != '__debezium_unavailable_value', true) THEN cc."FOO" ELSE c."FOO" END,"BAR"=cc."BAR"`,
},
{
name: "struct, string and toast string",
columns: lastCaseColTypes,
dialect: sql.RedshiftDialect{},
expectedString: `"a1"= CASE WHEN COALESCE(cc."a1" != JSON_PARSE('{"key":"__debezium_unavailable_value"}'), true) THEN cc."a1" ELSE c."a1" END,"b2"= CASE WHEN COALESCE(cc."b2" != '__debezium_unavailable_value', true) THEN cc."b2" ELSE c."b2" END,"c3"=cc."c3"`,
},
{
name: "struct, string and toast string (bigquery)",
columns: lastCaseColTypes,
dialect: sql.BigQueryDialect{},
expectedString: "`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '{\"key\":\"__debezium_unavailable_value\"}', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`",
},
{
name: "struct, string and toast string (bigquery) w/ reserved keywords",
columns: lastCaseEscapeTypes,
dialect: sql.BigQueryDialect{},
expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s",
key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`"),
skipDeleteCol: true,
},
{
name: "struct, string and toast string (bigquery) w/ reserved keywords",
columns: lastCaseEscapeTypes,
dialect: sql.BigQueryDialect{},
expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s",
key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`,`__artie_delete`=cc.`__artie_delete`"),
skipDeleteCol: false,
},
}

for _, _testCase := range testCases {
actualQuery := buildColumnsUpdateFragment(&_testCase.columns, _testCase.dialect, _testCase.skipDeleteCol)
assert.Equal(t, _testCase.expectedString, actualQuery, _testCase.name)
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package columns
package dml

import (
"testing"
Expand Down
12 changes: 6 additions & 6 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// UPDATE
fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s;`,
// UPDATE table set col1 = cc. col1
m.TableID.FullyQualifiedName(), m.Columns.UpdateQuery(m.Dialect, false),
m.TableID.FullyQualifiedName(), buildColumnsUpdateFragment(m.Columns, m.Dialect, false),
// FROM table (temp) WHERE join on PK(s)
m.SubQuery, strings.Join(equalitySQLParts, " and "), idempotentClause,
),
Expand Down Expand Up @@ -166,7 +166,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// UPDATE
fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s AND COALESCE(cc.%s, false) = false;`,
// UPDATE table set col1 = cc. col1
m.TableID.FullyQualifiedName(), m.Columns.UpdateQuery(m.Dialect, true),
m.TableID.FullyQualifiedName(), buildColumnsUpdateFragment(m.Columns, m.Dialect, true),
// FROM staging WHERE join on PK(s)
m.SubQuery, strings.Join(equalitySQLParts, " and "), idempotentClause, m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
),
Expand Down Expand Up @@ -244,7 +244,7 @@ WHEN MATCHED %sTHEN UPDATE SET %s
WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`,
m.TableID.FullyQualifiedName(), subQuery, strings.Join(equalitySQLParts, " and "),
// Update + Soft Deletion
idempotentClause, m.Columns.UpdateQuery(m.Dialect, false),
idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect, false),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Expand All @@ -270,7 +270,7 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
// Delete
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
// Update
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, m.Columns.UpdateQuery(m.Dialect, true),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect, true),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Expand Down Expand Up @@ -308,7 +308,7 @@ WHEN MATCHED %sTHEN UPDATE SET %s
WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`,
m.TableID.FullyQualifiedName(), m.SubQuery, strings.Join(equalitySQLParts, " and "),
// Update + Soft Deletion
idempotentClause, m.Columns.UpdateQuery(m.Dialect, false),
idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect, false),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Expand All @@ -335,7 +335,7 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`,
// Delete
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
// Update
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, m.Columns.UpdateQuery(m.Dialect, true),
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect, true),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Expand Down
64 changes: 0 additions & 64 deletions lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package columns

import (
"fmt"
"strings"
"sync"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
)
Expand Down Expand Up @@ -226,64 +223,3 @@ func (c *Columns) DeleteColumn(name string) {
}
}
}

// UpdateQuery will parse the columns and then returns a list of strings like: cc.first_name=c.first_name,cc.last_name=c.last_name,cc.email=c.email
func (c *Columns) UpdateQuery(dialect sql.Dialect, skipDeleteCol bool) string {
var cols []string
for _, column := range c.GetColumns() {
if column.ShouldSkip() {
continue
}

// skipDeleteCol is useful because we don't want to copy the deleted column over to the source table if we're doing a hard row delete.
if skipDeleteCol && column.Name() == constants.DeleteColumnMarker {
continue
}

colName := dialect.QuoteIdentifier(column.Name())
if column.ToastColumn {
if column.KindDetails == typing.Struct {
cols = append(cols, processToastStructCol(colName, dialect))
} else {
cols = append(cols, processToastCol(colName, dialect))
}

} else {
// This is to make it look like: objCol = cc.objCol
cols = append(cols, fmt.Sprintf("%s=cc.%s", colName, colName))
}
}

return strings.Join(cols, ",")
}

func processToastStructCol(colName string, dialect sql.Dialect) string {
switch dialect.(type) {
case sql.BigQueryDialect:
return fmt.Sprintf(`%s= CASE WHEN COALESCE(TO_JSON_STRING(cc.%s) != '{"key":"%s"}', true) THEN cc.%s ELSE c.%s END`,
colName, colName, constants.ToastUnavailableValuePlaceholder,
colName, colName)
case sql.RedshiftDialect:
return fmt.Sprintf(`%s= CASE WHEN COALESCE(cc.%s != JSON_PARSE('{"key":"%s"}'), true) THEN cc.%s ELSE c.%s END`,
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
case sql.MSSQLDialect:
// Microsoft SQL Server doesn't allow boolean expressions to be in the COALESCE statement.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s, {}) != {'key': '%s'} THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
default:
// TODO: Change this to Snowflake and error out if the destKind isn't supported so we're explicit.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s != {'key': '%s'}, true) THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
}
}

func processToastCol(colName string, dialect sql.Dialect) string {
if _, ok := dialect.(sql.MSSQLDialect); ok {
// Microsoft SQL Server doesn't allow boolean expressions to be in the COALESCE statement.
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s, '') != '%s' THEN cc.%s ELSE c.%s END", colName, colName,
constants.ToastUnavailableValuePlaceholder, colName, colName)
} else {
return fmt.Sprintf("%s= CASE WHEN COALESCE(cc.%s != '%s', true) THEN cc.%s ELSE c.%s END",
colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName)
}
}
Loading