diff --git a/clients/snowflake/staging.go b/clients/snowflake/staging.go index 2528dfae4..bac39ef83 100644 --- a/clients/snowflake/staging.go +++ b/clients/snowflake/staging.go @@ -12,6 +12,7 @@ import ( "github.com/artie-labs/transfer/lib/destination/ddl" "github.com/artie-labs/transfer/lib/destination/types" "github.com/artie-labs/transfer/lib/optimization" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/typing/values" @@ -83,7 +84,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo // COPY the CSV file (in Snowflake) into a table copyCommand := fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)", tempTableID.FullyQualifiedName(), - strings.Join(tableData.ReadOnlyInMemoryCols().GetEscapedColumnsToUpdate(s.Dialect()), ","), + strings.Join(sql.QuoteIdentifiers(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(), s.Dialect()), ","), escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableID, "%")) if additionalSettings.AdditionalCopyClause != "" { diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index 3b715ac6c..0bf610402 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -3,6 +3,7 @@ package dml import ( "errors" "fmt" + "slices" "strings" "github.com/artie-labs/transfer/lib/array" @@ -65,6 +66,12 @@ func (m *MergeArgument) Valid() error { return nil } +func removeDeleteColumnMarker(columns []string) ([]string, bool) { + origLength := len(columns) + columns = slices.DeleteFunc(columns, func(col string) bool { return col == constants.DeleteColumnMarker }) + return columns, len(columns) != origLength +} + func (m *MergeArgument) GetParts() ([]string, error) { if err := m.Valid(); err != nil { return nil, err @@ -97,17 +104,17 @@ func (m *MergeArgument) GetParts() ([]string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - cols := m.Columns.GetEscapedColumnsToUpdate(m.Dialect) + columns := m.Columns.GetColumnsToUpdate() if m.SoftDelete { return []string{ // INSERT fmt.Sprintf(`INSERT INTO %s (%s) SELECT %s FROM %s as cc LEFT JOIN %s as c on %s WHERE c.%s IS NULL;`, // insert into target (col1, col2, col3) - m.TableID.FullyQualifiedName(), strings.Join(cols, ","), + m.TableID.FullyQualifiedName(), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","), // SELECT cc.col1, cc.col2, ... FROM staging as CC array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: cols, + Vals: sql.QuoteIdentifiers(columns, m.Dialect), Separator: ",", Prefix: "cc.", }), m.SubQuery, @@ -127,14 +134,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { // We also need to remove __artie flags since it does not exist in the destination table var removed bool - for idx, col := range cols { - if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) { - cols = append(cols[:idx], cols[idx+1:]...) - removed = true - break - } - } - + columns, removed = removeDeleteColumnMarker(columns) if !removed { return nil, errors.New("artie delete flag doesn't exist") } @@ -143,10 +143,10 @@ func (m *MergeArgument) GetParts() ([]string, error) { // INSERT fmt.Sprintf(`INSERT INTO %s (%s) SELECT %s FROM %s as cc LEFT JOIN %s as c on %s WHERE c.%s IS NULL;`, // insert into target (col1, col2, col3) - m.TableID.FullyQualifiedName(), strings.Join(cols, ","), + m.TableID.FullyQualifiedName(), strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","), // SELECT cc.col1, cc.col2, ... FROM staging as CC array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: cols, + Vals: sql.QuoteIdentifiers(columns, m.Dialect), Separator: ",", Prefix: "cc.", }), m.SubQuery, @@ -227,7 +227,7 @@ func (m *MergeArgument) GetStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, m.AdditionalEqualityStrings...) } - cols := m.Columns.GetEscapedColumnsToUpdate(m.Dialect) + columns := m.Columns.GetColumnsToUpdate() if m.SoftDelete { return fmt.Sprintf(` @@ -238,9 +238,9 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);` // Update + Soft Deletion idempotentClause, m.Columns.UpdateQuery(m.Dialect, false), // Insert - constants.DeleteColumnMarker, strings.Join(cols, ","), + constants.DeleteColumnMarker, strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: cols, + Vals: sql.QuoteIdentifiers(columns, m.Dialect), Separator: ",", Prefix: "cc.", })), nil @@ -248,14 +248,7 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);` // We also need to remove __artie flags since it does not exist in the destination table var removed bool - for idx, col := range cols { - if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) { - cols = append(cols[:idx], cols[idx+1:]...) - removed = true - break - } - } - + columns, removed = removeDeleteColumnMarker(columns) if !removed { return "", errors.New("artie delete flag doesn't exist") } @@ -271,9 +264,9 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);` // Update constants.DeleteColumnMarker, idempotentClause, m.Columns.UpdateQuery(m.Dialect, true), // Insert - constants.DeleteColumnMarker, strings.Join(cols, ","), + constants.DeleteColumnMarker, strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: cols, + Vals: sql.QuoteIdentifiers(columns, m.Dialect), Separator: ",", Prefix: "cc.", })), nil @@ -295,7 +288,7 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - cols := m.Columns.GetEscapedColumnsToUpdate(m.Dialect) + columns := m.Columns.GetColumnsToUpdate() if m.SoftDelete { return fmt.Sprintf(` @@ -307,9 +300,9 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`, // Update + Soft Deletion idempotentClause, m.Columns.UpdateQuery(m.Dialect, false), // Insert - constants.DeleteColumnMarker, strings.Join(cols, ","), + constants.DeleteColumnMarker, strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: cols, + Vals: sql.QuoteIdentifiers(columns, m.Dialect), Separator: ",", Prefix: "cc.", })), nil @@ -317,14 +310,7 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`, // We also need to remove __artie flags since it does not exist in the destination table var removed bool - for idx, col := range cols { - if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) { - cols = append(cols[:idx], cols[idx+1:]...) - removed = true - break - } - } - + columns, removed = removeDeleteColumnMarker(columns) if !removed { return "", errors.New("artie delete flag doesn't exist") } @@ -341,9 +327,9 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`, // Update constants.DeleteColumnMarker, idempotentClause, m.Columns.UpdateQuery(m.Dialect, true), // Insert - constants.DeleteColumnMarker, strings.Join(cols, ","), + constants.DeleteColumnMarker, strings.Join(sql.QuoteIdentifiers(columns, m.Dialect), ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: cols, + Vals: sql.QuoteIdentifiers(columns, m.Dialect), Separator: ",", Prefix: "cc.", })), nil diff --git a/lib/destination/dml/merge_test.go b/lib/destination/dml/merge_test.go index 312952ea3..0f0081bf3 100644 --- a/lib/destination/dml/merge_test.go +++ b/lib/destination/dml/merge_test.go @@ -32,6 +32,39 @@ func (m MockTableIdentifier) FullyQualifiedName() string { return m.fqName } +func TestRemoveDeleteColumnMarker(t *testing.T) { + { + columns, removed := removeDeleteColumnMarker([]string{}) + assert.Empty(t, columns) + assert.False(t, removed) + } + { + columns, removed := removeDeleteColumnMarker([]string{"a"}) + assert.Equal(t, []string{"a"}, columns) + assert.False(t, removed) + } + { + columns, removed := removeDeleteColumnMarker([]string{"a", "b"}) + assert.Equal(t, []string{"a", "b"}, columns) + assert.False(t, removed) + } + { + columns, removed := removeDeleteColumnMarker([]string{constants.DeleteColumnMarker}) + assert.True(t, removed) + assert.Empty(t, columns) + } + { + columns, removed := removeDeleteColumnMarker([]string{"a", constants.DeleteColumnMarker, "b"}) + assert.True(t, removed) + assert.Equal(t, []string{"a", "b"}, columns) + } + { + columns, removed := removeDeleteColumnMarker([]string{"a", constants.DeleteColumnMarker, "b", constants.DeleteColumnMarker, "c"}) + assert.True(t, removed) + assert.Equal(t, []string{"a", "b", "c"}, columns) + } +} + func TestMergeStatementSoftDelete(t *testing.T) { // No idempotent key fqTable := "database.schema.table" diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index 402818a93..b43e6bdc5 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -189,28 +189,6 @@ func (c *Columns) GetColumnsToUpdate() []string { return cols } -// GetEscapedColumnsToUpdate will filter all the `Invalid` columns so that we do not update it. -// It will escape the returned columns. -func (c *Columns) GetEscapedColumnsToUpdate(dialect sql.Dialect) []string { - if c == nil { - return []string{} - } - - c.RLock() - defer c.RUnlock() - - var cols []string - for _, col := range c.columns { - if col.KindDetails == typing.Invalid { - continue - } - - cols = append(cols, dialect.QuoteIdentifier(col.RawName())) - } - - return cols -} - func (c *Columns) GetColumns() []Column { if c == nil { return []Column{} diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index c2e32518c..2ecb3c586 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -229,64 +229,6 @@ func TestColumns_GetColumnsToUpdate(t *testing.T) { } } -func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) { - type _testCase struct { - name string - cols []Column - expectedColsEsc []string - expectedColsEscBq []string - } - - var ( - happyPathCols = []Column{ - { - name: "hi", - KindDetails: typing.String, - }, - { - name: "bye", - KindDetails: typing.String, - }, - { - name: "start", - KindDetails: typing.String, - }, - } - ) - - extraCols := happyPathCols - for i := 0; i < 100; i++ { - extraCols = append(extraCols, Column{ - name: fmt.Sprintf("hello_%v", i), - KindDetails: typing.Invalid, - }) - } - - testCases := []_testCase{ - { - name: "happy path", - cols: happyPathCols, - expectedColsEsc: []string{`"HI"`, `"BYE"`, `"START"`}, - expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"}, - }, - { - name: "happy path + extra col", - cols: extraCols, - expectedColsEsc: []string{`"HI"`, `"BYE"`, `"START"`}, - expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"}, - }, - } - - for _, testCase := range testCases { - cols := &Columns{ - columns: testCase.cols, - } - - assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{}), testCase.name) - assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(sql.BigQueryDialect{}), testCase.name) - } -} - func TestColumns_UpsertColumns(t *testing.T) { keys := []string{"a", "b", "c", "d", "e"} var cols Columns