From 7b8b7ea8720f066005912b50d71ca06278c1150b Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 21:48:30 -0700 Subject: [PATCH] func --- lib/destination/dml/merge.go | 14 +++----------- lib/typing/columns/columns.go | 8 ++++++++ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index 5167fc932..3b715ac6c 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -92,9 +92,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { } var equalitySQLParts []string - for _, primaryKey := range m.PrimaryKeys { - // We'll need to escape the primary key as well. - escapedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.RawName()) + for _, escapedPrimaryKey := range columns.QuoteColumns(m.PrimaryKeys, m.Dialect) { equalitySQL := fmt.Sprintf("c.%s = cc.%s", escapedPrimaryKey, escapedPrimaryKey) equalitySQLParts = append(equalitySQLParts, equalitySQL) } @@ -141,11 +139,6 @@ func (m *MergeArgument) GetParts() ([]string, error) { return nil, errors.New("artie delete flag doesn't exist") } - var pks []string - for _, pk := range m.PrimaryKeys { - pks = append(pks, m.Dialect.QuoteIdentifier(pk.RawName())) - } - parts := []string{ // INSERT fmt.Sprintf(`INSERT INTO %s (%s) SELECT %s FROM %s as cc LEFT JOIN %s as c on %s WHERE c.%s IS NULL;`, @@ -171,6 +164,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { } if *m.ContainsHardDeletes { + var pks []string = columns.QuoteColumns(m.PrimaryKeys, m.Dialect) parts = append(parts, // DELETE fmt.Sprintf(`DELETE FROM %s WHERE (%s) IN (SELECT %s FROM %s as cc WHERE cc.%s = true);`, @@ -296,9 +290,7 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) { } var equalitySQLParts []string - for _, primaryKey := range m.PrimaryKeys { - // We'll need to escape the primary key as well. - escapedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.RawName()) + for _, escapedPrimaryKey := range columns.QuoteColumns(m.PrimaryKeys, m.Dialect) { equalitySQL := fmt.Sprintf("c.%s = cc.%s", escapedPrimaryKey, escapedPrimaryKey) equalitySQLParts = append(equalitySQLParts, equalitySQL) } diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index d41ab3bf8..402818a93 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -309,3 +309,11 @@ func processToastCol(colName string, dialect sql.Dialect) string { colName, colName, constants.ToastUnavailableValuePlaceholder, colName, colName) } } + +func QuoteColumns(columns []Column, dialect sql.Dialect) []string { + result := make([]string, len(columns)) + for i, columns := range columns { + result[i] = dialect.QuoteIdentifier(columns.RawName()) + } + return result +}