Skip to content

Commit

Permalink
[mssql] Move merge query building to MSSQLDialect (#633)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored May 13, 2024
1 parent 5fdca45 commit cd9e195
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 153 deletions.
70 changes: 70 additions & 0 deletions clients/mssql/dialect/dialect.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package dialect

import (
"errors"
"fmt"
"strconv"
"strings"

"github.com/artie-labs/transfer/lib/array"
"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/artie-labs/transfer/lib/typing/ext"
)

Expand Down Expand Up @@ -173,3 +176,70 @@ func (MSSQLDialect) BuildProcessToastStructColExpression(colName string) string
func (MSSQLDialect) BuildDedupeQueries(tableID, stagingTableID sql.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string {
panic("not implemented") // We don't currently support deduping for MS SQL.
}

func (md MSSQLDialect) BuildMergeQueries(
tableID sql.TableIdentifier,
subQuery string,
idempotentKey string,
primaryKeys []columns.Column,
_ []string,
cols []columns.Column,
softDelete bool,
_ *bool,
) ([]string, error) {
var idempotentClause string
if idempotentKey != "" {
idempotentClause = fmt.Sprintf("AND cc.%s >= c.%s ", idempotentKey, idempotentKey)
}

var equalitySQLParts []string
for _, primaryKey := range primaryKeys {
// We'll need to escape the primary key as well.
quotedPrimaryKey := md.QuoteIdentifier(primaryKey.Name())
equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey)
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

if softDelete {
return []string{fmt.Sprintf(`
MERGE INTO %s c
USING %s AS cc ON %s
WHEN MATCHED %sTHEN UPDATE SET %s
WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`,
tableID.FullyQualifiedName(), subQuery, strings.Join(equalitySQLParts, " and "),
// Update + Soft Deletion
idempotentClause, columns.BuildColumnsUpdateFragment(cols, md),
// Insert
md.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(cols, md), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: columns.QuoteColumns(cols, md),
Separator: ",",
Prefix: "cc.",
}))}, nil
}

// We also need to remove __artie flags since it does not exist in the destination table
cols, removed := columns.RemoveDeleteColumnMarker(cols)
if !removed {
return nil, errors.New("artie delete flag doesn't exist")
}

return []string{fmt.Sprintf(`
MERGE INTO %s c
USING %s AS cc ON %s
WHEN MATCHED AND cc.%s = 1 THEN DELETE
WHEN MATCHED AND COALESCE(cc.%s, 0) = 0 %sTHEN UPDATE SET %s
WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`,
tableID.FullyQualifiedName(), subQuery, strings.Join(equalitySQLParts, " and "),
// Delete
md.QuoteIdentifier(constants.DeleteColumnMarker),
// Update
md.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, columns.BuildColumnsUpdateFragment(cols, md),
// Insert
md.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(cols, md), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: columns.QuoteColumns(cols, md),
Separator: ",",
Prefix: "cc.",
}))}, nil
}
57 changes: 55 additions & 2 deletions clients/mssql/dialect/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package dialect

import (
"fmt"
"strings"
"testing"
"time"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/mocks"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/artie-labs/transfer/lib/typing/ext"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -146,10 +149,60 @@ func TestMSSQLDialect_BuildAlterColumnQuery(t *testing.T) {
)
}

func TestBuildProcessToastColExpression(t *testing.T) {
func TestMSSQLDialect_BuildProcessToastColExpression(t *testing.T) {
assert.Equal(t, `CASE WHEN COALESCE(cc.bar, '') != '__debezium_unavailable_value' THEN cc.bar ELSE c.bar END`, MSSQLDialect{}.BuildProcessToastColExpression("bar"))
}

func TestBuildProcessToastStructColExpression(t *testing.T) {
func TestMSSQLDialect_BuildProcessToastStructColExpression(t *testing.T) {
assert.Equal(t, `CASE WHEN COALESCE(cc.foo, {}) != {'key': '__debezium_unavailable_value'} THEN cc.foo ELSE c.foo END`, MSSQLDialect{}.BuildProcessToastStructColExpression("foo"))
}

func TestMSSQLDialect_BuildMergeQueries(t *testing.T) {
var _cols = []columns.Column{
columns.NewColumn("id", typing.String),
columns.NewColumn("bar", typing.String),
columns.NewColumn("updated_at", typing.String),
columns.NewColumn("start", typing.String),
columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean),
}
cols := make([]string, len(_cols))
for i, col := range _cols {
cols[i] = col.Name()
}

tableValues := []string{
fmt.Sprintf("('%s', '%s', '%v', '%v', false)", "1", "456", "foo", time.Now().Round(0).UTC()),
fmt.Sprintf("('%s', '%s', '%v', '%v', false)", "2", "bb", "bar", time.Now().Round(0).UTC()),
fmt.Sprintf("('%s', '%s', '%v', '%v', false)", "3", "dd", "world", time.Now().Round(0).UTC()),
}

// select cc.foo, cc.bar from (values (12, 34), (44, 55)) as cc(foo, bar);
subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)",
strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ","))

fqTable := "database.schema.table"
fakeID := &mocks.FakeTableIdentifier{}
fakeID.FullyQualifiedNameReturns(fqTable)

queries, err := MSSQLDialect{}.BuildMergeQueries(
fakeID,
subQuery,
"",
[]columns.Column{columns.NewColumn("id", typing.Invalid)},
[]string{},
_cols,
false,
nil,
)
assert.Len(t, queries, 1)
mergeSQL := queries[0]
assert.NoError(t, err)
assert.Contains(t, mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable), mergeSQL)
assert.NotContains(t, mergeSQL, fmt.Sprintf(`cc."%s" >= c."%s"`, "updated_at", "updated_at"), fmt.Sprintf("Idempotency key: %s", mergeSQL))
// Check primary keys clause
assert.Contains(t, mergeSQL, `AS cc ON c."id" = cc."id"`, mergeSQL)

assert.Contains(t, mergeSQL, `SET "id"=cc."id","bar"=cc."bar","updated_at"=cc."updated_at","start"=cc."start"`, mergeSQL)
assert.Contains(t, mergeSQL, `id,bar,updated_at,start`, mergeSQL)
assert.Contains(t, mergeSQL, `cc."id",cc."bar",cc."updated_at",cc."start"`, mergeSQL)
}
75 changes: 11 additions & 64 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,78 +244,25 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
})), nil
}

func (m *MergeArgument) buildMSSQLStatement() (string, error) {
var idempotentClause string
if m.IdempotentKey != "" {
idempotentClause = fmt.Sprintf("AND cc.%s >= c.%s ", m.IdempotentKey, m.IdempotentKey)
}

var equalitySQLParts []string
for _, primaryKey := range m.PrimaryKeys {
// We'll need to escape the primary key as well.
quotedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.Name())
equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey)
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

if m.SoftDelete {
return fmt.Sprintf(`
MERGE INTO %s c
USING %s AS cc ON %s
WHEN MATCHED %sTHEN UPDATE SET %s
WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`,
m.TableID.FullyQualifiedName(), m.SubQuery, strings.Join(equalitySQLParts, " and "),
// Update + Soft Deletion
idempotentClause, columns.BuildColumnsUpdateFragment(m.Columns, m.Dialect),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(m.Columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: columns.QuoteColumns(m.Columns, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
}

// We also need to remove __artie flags since it does not exist in the destination table
cols, removed := columns.RemoveDeleteColumnMarker(m.Columns)
if !removed {
return "", errors.New("artie delete flag doesn't exist")
}

return fmt.Sprintf(`
MERGE INTO %s c
USING %s AS cc ON %s
WHEN MATCHED AND cc.%s = 1 THEN DELETE
WHEN MATCHED AND COALESCE(cc.%s, 0) = 0 %sTHEN UPDATE SET %s
WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`,
m.TableID.FullyQualifiedName(), m.SubQuery, strings.Join(equalitySQLParts, " and "),
// Delete
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker),
// Update
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, columns.BuildColumnsUpdateFragment(cols, m.Dialect),
// Insert
m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(cols, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: columns.QuoteColumns(cols, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
}

func (m *MergeArgument) BuildStatements() ([]string, error) {
if err := m.Valid(); err != nil {
return nil, err
}

switch m.Dialect.(type) {
switch specificDialect := m.Dialect.(type) {
case redshiftDialect.RedshiftDialect:
return m.buildRedshiftStatements()
case mssqlDialect.MSSQLDialect:
mergeQuery, err := m.buildMSSQLStatement()
if err != nil {
return nil, err
}
return []string{mergeQuery}, nil
return specificDialect.BuildMergeQueries(
m.TableID,
m.SubQuery,
m.IdempotentKey,
m.PrimaryKeys,
m.AdditionalEqualityStrings,
m.Columns,
m.SoftDelete,
m.ContainsHardDeletes,
)
default:
mergeQuery, err := m.buildDefaultStatement()
if err != nil {
Expand Down
87 changes: 0 additions & 87 deletions lib/destination/dml/merge_mssql_test.go

This file was deleted.

0 comments on commit cd9e195

Please sign in to comment.