Skip to content

Commit

Permalink
[typing] Kill Columns.GetEscapedColumnsToUpdate (#544)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored May 2, 2024
1 parent 2291477 commit 523f5bd
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 120 deletions.
3 changes: 2 additions & 1 deletion clients/snowflake/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/artie-labs/transfer/lib/destination/ddl"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/artie-labs/transfer/lib/typing/values"
Expand Down Expand Up @@ -83,7 +84,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
// COPY the CSV file (in Snowflake) into a table
copyCommand := fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)",
tempTableID.FullyQualifiedName(),
strings.Join(tableData.ReadOnlyInMemoryCols().GetEscapedColumnsToUpdate(s.Dialect()), ","),
strings.Join(sql.QuoteIdentifiers(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(), s.Dialect()), ","),
escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableID, "%"))

if additionalSettings.AdditionalCopyClause != "" {
Expand Down
64 changes: 25 additions & 39 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dml
import (
"errors"
"fmt"
"slices"
"strings"

"github.com/artie-labs/transfer/lib/array"
Expand Down Expand Up @@ -65,6 +66,12 @@ func (m *MergeArgument) Valid() error {
return nil
}

func removeDeleteColumnMarker(columns []string) ([]string, bool) {
origLength := len(columns)
columns = slices.DeleteFunc(columns, func(col string) bool { return col == constants.DeleteColumnMarker })
return columns, len(columns) != origLength
}

func (m *MergeArgument) GetParts() ([]string, error) {
if err := m.Valid(); err != nil {
return nil, err
Expand Down Expand Up @@ -98,17 +105,17 @@ func (m *MergeArgument) GetParts() ([]string, error) {
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

cols := m.Columns.GetEscapedColumnsToUpdate(m.Dialect)
columns := m.Columns.GetColumnsToUpdate()

if m.SoftDelete {
return []string{
// INSERT
fmt.Sprintf(`INSERT INTO %s (%s) SELECT %s FROM %s as cc LEFT JOIN %s as c on %s WHERE c.%s IS NULL;`,
// insert into target (col1, col2, col3)
m.TableID.FullyQualifiedName(), strings.Join(cols, ","),
m.TableID.FullyQualifiedName(), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
// SELECT cc.col1, cc.col2, ... FROM staging as CC
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: cols,
Vals: sql.QuoteIdentifiers(columns, m.Dialect),
Separator: ",",
Prefix: "cc.",
}), m.SubQuery,
Expand All @@ -128,14 +135,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {

// We also need to remove __artie flags since it does not exist in the destination table
var removed bool
for idx, col := range cols {
if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
}
}

columns, removed = removeDeleteColumnMarker(columns)
if !removed {
return nil, errors.New("artie delete flag doesn't exist")
}
Expand All @@ -149,10 +149,10 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// INSERT
fmt.Sprintf(`INSERT INTO %s (%s) SELECT %s FROM %s as cc LEFT JOIN %s as c on %s WHERE c.%s IS NULL;`,
// insert into target (col1, col2, col3)
m.TableID.FullyQualifiedName(), strings.Join(cols, ","),
m.TableID.FullyQualifiedName(), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
// SELECT cc.col1, cc.col2, ... FROM staging as CC
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: cols,
Vals: sql.QuoteIdentifiers(columns, m.Dialect),
Separator: ",",
Prefix: "cc.",
}), m.SubQuery,
Expand Down Expand Up @@ -230,7 +230,7 @@ func (m *MergeArgument) GetStatement() (string, error) {
equalitySQLParts = append(equalitySQLParts, m.AdditionalEqualityStrings...)
}

cols := m.Columns.GetEscapedColumnsToUpdate(m.Dialect)
columns := m.Columns.GetColumnsToUpdate()

if m.SoftDelete {
return fmt.Sprintf(`
Expand All @@ -241,24 +241,17 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
// Update + Soft Deletion
idempotentClause, m.Columns.UpdateQuery(m.Dialect, false),
// Insert
constants.DeleteColumnMarker, strings.Join(cols, ","),
constants.DeleteColumnMarker, strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: cols,
Vals: sql.QuoteIdentifiers(columns, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
}

// We also need to remove __artie flags since it does not exist in the destination table
var removed bool
for idx, col := range cols {
if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
}
}

columns, removed = removeDeleteColumnMarker(columns)
if !removed {
return "", errors.New("artie delete flag doesn't exist")
}
Expand All @@ -274,9 +267,9 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
// Update
constants.DeleteColumnMarker, idempotentClause, m.Columns.UpdateQuery(m.Dialect, true),
// Insert
constants.DeleteColumnMarker, strings.Join(cols, ","),
constants.DeleteColumnMarker, strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: cols,
Vals: sql.QuoteIdentifiers(columns, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
Expand All @@ -299,7 +292,7 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) {
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

cols := m.Columns.GetEscapedColumnsToUpdate(m.Dialect)
columns := m.Columns.GetColumnsToUpdate()

if m.SoftDelete {
return fmt.Sprintf(`
Expand All @@ -311,24 +304,17 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`,
// Update + Soft Deletion
idempotentClause, m.Columns.UpdateQuery(m.Dialect, false),
// Insert
constants.DeleteColumnMarker, strings.Join(cols, ","),
constants.DeleteColumnMarker, strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: cols,
Vals: sql.QuoteIdentifiers(columns, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
}

// We also need to remove __artie flags since it does not exist in the destination table
var removed bool
for idx, col := range cols {
if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
}
}

columns, removed = removeDeleteColumnMarker(columns)
if !removed {
return "", errors.New("artie delete flag doesn't exist")
}
Expand All @@ -345,9 +331,9 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`,
// Update
constants.DeleteColumnMarker, idempotentClause, m.Columns.UpdateQuery(m.Dialect, true),
// Insert
constants.DeleteColumnMarker, strings.Join(cols, ","),
constants.DeleteColumnMarker, strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","),
array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{
Vals: cols,
Vals: sql.QuoteIdentifiers(columns, m.Dialect),
Separator: ",",
Prefix: "cc.",
})), nil
Expand Down
33 changes: 33 additions & 0 deletions lib/destination/dml/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,39 @@ func (m MockTableIdentifier) FullyQualifiedName() string {
return m.fqName
}

func TestRemoveDeleteColumnMarker(t *testing.T) {
{
columns, removed := removeDeleteColumnMarker([]string{})
assert.Empty(t, columns)
assert.False(t, removed)
}
{
columns, removed := removeDeleteColumnMarker([]string{"a"})
assert.Equal(t, []string{"a"}, columns)
assert.False(t, removed)
}
{
columns, removed := removeDeleteColumnMarker([]string{"a", "b"})
assert.Equal(t, []string{"a", "b"}, columns)
assert.False(t, removed)
}
{
columns, removed := removeDeleteColumnMarker([]string{constants.DeleteColumnMarker})
assert.True(t, removed)
assert.Empty(t, columns)
}
{
columns, removed := removeDeleteColumnMarker([]string{"a", constants.DeleteColumnMarker, "b"})
assert.True(t, removed)
assert.Equal(t, []string{"a", "b"}, columns)
}
{
columns, removed := removeDeleteColumnMarker([]string{"a", constants.DeleteColumnMarker, "b", constants.DeleteColumnMarker, "c"})
assert.True(t, removed)
assert.Equal(t, []string{"a", "b", "c"}, columns)
}
}

func TestMergeStatementSoftDelete(t *testing.T) {
// No idempotent key
fqTable := "database.schema.table"
Expand Down
22 changes: 0 additions & 22 deletions lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,28 +194,6 @@ func (c *Columns) GetColumnsToUpdate() []string {
return cols
}

// GetEscapedColumnsToUpdate will filter all the `Invalid` columns so that we do not update it.
// It will escape the returned columns.
func (c *Columns) GetEscapedColumnsToUpdate(dialect sql.Dialect) []string {
if c == nil {
return []string{}
}

c.RLock()
defer c.RUnlock()

var cols []string
for _, col := range c.columns {
if col.KindDetails == typing.Invalid {
continue
}

cols = append(cols, col.Name(dialect))
}

return cols
}

func (c *Columns) GetColumns() []Column {
if c == nil {
return []Column{}
Expand Down
58 changes: 0 additions & 58 deletions lib/typing/columns/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,64 +229,6 @@ func TestColumns_GetColumnsToUpdate(t *testing.T) {
}
}

func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) {
type _testCase struct {
name string
cols []Column
expectedColsEsc []string
expectedColsEscBq []string
}

var (
happyPathCols = []Column{
{
name: "hi",
KindDetails: typing.String,
},
{
name: "bye",
KindDetails: typing.String,
},
{
name: "start",
KindDetails: typing.String,
},
}
)

extraCols := happyPathCols
for i := 0; i < 100; i++ {
extraCols = append(extraCols, Column{
name: fmt.Sprintf("hello_%v", i),
KindDetails: typing.Invalid,
})
}

testCases := []_testCase{
{
name: "happy path",
cols: happyPathCols,
expectedColsEsc: []string{`"HI"`, `"BYE"`, `"START"`},
expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"},
},
{
name: "happy path + extra col",
cols: extraCols,
expectedColsEsc: []string{`"HI"`, `"BYE"`, `"START"`},
expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"},
},
}

for _, testCase := range testCases {
cols := &Columns{
columns: testCase.cols,
}

assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{}), testCase.name)
assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(sql.BigQueryDialect{}), testCase.name)
}
}

func TestColumns_UpsertColumns(t *testing.T) {
keys := []string{"a", "b", "c", "d", "e"}
var cols Columns
Expand Down

0 comments on commit 523f5bd

Please sign in to comment.