From cf3e96014cae14cf68afd59421af4508a37dc7e9 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Mon, 29 Apr 2024 11:54:23 -0700 Subject: [PATCH] Snowflake - Fix Dedupe Query (#513) --- clients/snowflake/snowflake.go | 10 ++++----- clients/snowflake/snowflake_dedupe_test.go | 24 +++++++++++----------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go index fff92d3c3..1f3f2c6fe 100644 --- a/clients/snowflake/snowflake.go +++ b/clients/snowflake/snowflake.go @@ -15,7 +15,6 @@ import ( "github.com/artie-labs/transfer/lib/kafkalib" "github.com/artie-labs/transfer/lib/optimization" "github.com/artie-labs/transfer/lib/ptr" - "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/stringutil" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" @@ -146,10 +145,9 @@ func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentif orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", pk)) } - temporaryTableName := sql.EscapeName(stagingTableID.Table(), s.ShouldUppercaseEscapedNames(), s.Label()) var parts []string - parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY by %s ORDER BY %s) = 2)", - temporaryTableName, + parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) = 2)", + stagingTableID.FullyQualifiedName(), tableID.FullyQualifiedName(), strings.Join(primaryKeysEscaped, ", "), strings.Join(orderByCols, ", "), @@ -162,11 +160,11 @@ func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentif parts = append(parts, fmt.Sprintf("DELETE FROM %s t1 USING %s t2 WHERE %s", tableID.FullyQualifiedName(), - temporaryTableName, + stagingTableID.FullyQualifiedName(), strings.Join(whereClauses, " AND "), )) - parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), temporaryTableName)) + parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), stagingTableID.FullyQualifiedName())) return parts } diff --git a/clients/snowflake/snowflake_dedupe_test.go b/clients/snowflake/snowflake_dedupe_test.go index d600f072a..b9b5aed83 100644 --- a/clients/snowflake/snowflake_dedupe_test.go +++ b/clients/snowflake/snowflake_dedupe_test.go @@ -20,11 +20,11 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() { assert.Len(s.T(), parts, 3) assert.Equal( s.T(), - fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC) = 2)`, stagingTableID.Table()), + fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2)`, stagingTableID.FullyQualifiedName()), parts[0], ) - assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1]) - assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2]) + assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2]) } { // Dedupe with one primary key + `__artie_updated_at` flag. @@ -35,11 +35,11 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() { assert.Len(s.T(), parts, 3) assert.Equal( s.T(), - fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()), + fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()), parts[0], ) - assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1]) - assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2]) + assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2]) } { // Dedupe with composite keys + no `__artie_updated_at` flag. @@ -50,11 +50,11 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() { assert.Len(s.T(), parts, 3) assert.Equal( s.T(), - fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.Table()), + fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.FullyQualifiedName()), parts[0], ) - assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1]) - assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2]) + assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2]) } { // Dedupe with composite keys + `__artie_updated_at` flag. @@ -65,10 +65,10 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() { assert.Len(s.T(), parts, 3) assert.Equal( s.T(), - fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()), + fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()), parts[0], ) - assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1]) - assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2]) + assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2]) } }