diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go
index f2ba0182d..2f983a3a9 100644
--- a/clients/bigquery/bigquery.go
+++ b/clients/bigquery/bigquery.go
@@ -143,7 +143,7 @@ func (s *Store) putTable(ctx context.Context, tableID types.TableIdentifier, row
 	return nil
 }
 
-func (s *Store) Dedupe(tableID types.TableIdentifier) error {
+func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
 	fqTableName := tableID.FullyQualifiedName()
 	_, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName))
 	return err
diff --git a/clients/mssql/store.go b/clients/mssql/store.go
index e3a66e704..ffd5725c5 100644
--- a/clients/mssql/store.go
+++ b/clients/mssql/store.go
@@ -70,7 +70,7 @@ func (s *Store) Sweep() error {
 	return shared.Sweep(s, tcs, queryFunc)
 }
 
-func (s *Store) Dedupe(_ types.TableIdentifier) error {
+func (s *Store) Dedupe(_ types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
 	return nil // dedupe is not necessary for MS SQL
 }
 
diff --git a/clients/redshift/redshift.go b/clients/redshift/redshift.go
index 0914cb23c..34acb73f4 100644
--- a/clients/redshift/redshift.go
+++ b/clients/redshift/redshift.go
@@ -97,7 +97,7 @@ WHERE
 	return shared.Sweep(s, tcs, queryFunc)
 }
 
-func (s *Store) Dedupe(tableID types.TableIdentifier) error {
+func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
 	fqTableName := tableID.FullyQualifiedName()
 	stagingTableName := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))).FullyQualifiedName()
 
diff --git a/clients/shared/table_config_test.go b/clients/shared/table_config_test.go
index b909e6743..1d8b26d81 100644
--- a/clients/shared/table_config_test.go
+++ b/clients/shared/table_config_test.go
@@ -57,10 +57,12 @@ func TestGetTableCfgArgs_ShouldParseComment(t *testing.T) {
 
 type MockDWH struct{}
 
-func (MockDWH) Label() constants.DestinationKind { panic("not implemented") }
-func (MockDWH) Merge(tableData *optimization.TableData) error { panic("not implemented") }
-func (MockDWH) Append(tableData *optimization.TableData) error { panic("not implemented") }
-func (MockDWH) Dedupe(tableID types.TableIdentifier) error { panic("not implemented") }
+func (MockDWH) Label() constants.DestinationKind               { panic("not implemented") }
+func (MockDWH) Merge(tableData *optimization.TableData) error  { panic("not implemented") }
+func (MockDWH) Append(tableData *optimization.TableData) error { panic("not implemented") }
+func (MockDWH) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error {
+	panic("not implemented")
+}
 func (MockDWH) Exec(query string, args ...any) (sql.Result, error)  { panic("not implemented") }
 func (MockDWH) Query(query string, args ...any) (*sql.Rows, error)  { panic("not implemented") }
 func (MockDWH) Begin() (*sql.Tx, error)                             { panic("not implemented") }
diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go
index 9b2820be3..6c4105f84 100644
--- a/clients/snowflake/snowflake.go
+++ b/clients/snowflake/snowflake.go
@@ -2,6 +2,8 @@ package snowflake
 
 import (
 	"fmt"
+	"log/slog"
+	"strings"
 
 	"github.com/snowflakedb/gosnowflake"
 
@@ -13,6 +15,10 @@ import (
 	"github.com/artie-labs/transfer/lib/kafkalib"
 	"github.com/artie-labs/transfer/lib/optimization"
 	"github.com/artie-labs/transfer/lib/ptr"
+	"github.com/artie-labs/transfer/lib/sql"
+	"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/columns" ) const maxRetries = 10 @@ -119,10 +125,77 @@ func (s *Store) reestablishConnection() error { return nil } -func (s *Store) Dedupe(tableID types.TableIdentifier) error { - fqTableName := tableID.FullyQualifiedName() - _, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName)) - return err +func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string { + var primaryKeysEscaped []string + for _, pk := range primaryKeys { + pkCol := columns.NewColumn(pk, typing.Invalid) + primaryKeysEscaped = append(primaryKeysEscaped, pkCol.Name(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: s.Label()})) + } + + orderColsToIterate := primaryKeysEscaped + if topicConfig.IncludeArtieUpdatedAt { + orderColsToIterate = append(orderColsToIterate, constants.UpdateColumnMarker) + } + + var orderByCols []string + for _, pk := range orderColsToIterate { + orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", pk)) + } + + temporaryTableName := sql.EscapeName(stagingTableID.Table(), s.ShouldUppercaseEscapedNames(), s.Label()) + var parts []string + parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY by %s ORDER BY %s) = 2)", + temporaryTableName, + tableID.FullyQualifiedName(), + strings.Join(primaryKeysEscaped, ", "), + strings.Join(orderByCols, ", "), + )) + + var whereClauses []string + for _, primaryKeyEscaped := range primaryKeysEscaped { + whereClauses = append(whereClauses, fmt.Sprintf("t1.%s = t2.%s", primaryKeyEscaped, primaryKeyEscaped)) + } + + parts = append(parts, fmt.Sprintf("DELETE FROM %s t1 USING %s t2 WHERE %s", + tableID.FullyQualifiedName(), + temporaryTableName, + strings.Join(whereClauses, " AND "), + )) + + parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), temporaryTableName)) + return parts +} + +// Dedupe takes a table and will remove duplicates based on the primary key(s). 
+// These queries are inspired and modified from: https://stackoverflow.com/a/71515946
+func (s *Store) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error {
+	var txCommitted bool
+	tx, err := s.Begin()
+	if err != nil {
+		return fmt.Errorf("failed to start a tx: %w", err)
+	}
+
+	defer func() {
+		if !txCommitted {
+			if err = tx.Rollback(); err != nil {
+				slog.Warn("Failed to rollback tx", slog.Any("err", err))
+			}
+		}
+	}()
+
+	stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
+	for _, part := range s.generateDedupeQueries(tableID, stagingTableID, primaryKeys, topicConfig) {
+		if _, err = tx.Exec(part); err != nil {
+			return fmt.Errorf("failed to execute tx, query: %q, err: %w", part, err)
+		}
+	}
+
+	if err = tx.Commit(); err != nil {
+		return fmt.Errorf("failed to commit tx: %w", err)
+	}
+
+	txCommitted = true
+	return nil
 }
 
 func LoadSnowflake(cfg config.Config, _store *db.Store) (*Store, error) {
diff --git a/clients/snowflake/snowflake_dedupe_test.go b/clients/snowflake/snowflake_dedupe_test.go
new file mode 100644
index 000000000..d600f072a
--- /dev/null
+++ b/clients/snowflake/snowflake_dedupe_test.go
@@ -0,0 +1,74 @@
+package snowflake
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/artie-labs/transfer/clients/shared"
+	"github.com/artie-labs/transfer/lib/kafkalib"
+	"github.com/artie-labs/transfer/lib/stringutil"
+	"github.com/stretchr/testify/assert"
+)
+
+func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
+	{
+		// Dedupe with one primary key + no `__artie_updated_at` flag.
+		tableID := NewTableIdentifier("db", "public", "customers")
+		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
+
+		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{})
+		assert.Len(s.T(), parts, 3)
+		assert.Equal(
+			s.T(),
+			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2)`, stagingTableID.Table()),
+			parts[0],
+		)
+		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1])
+		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
+	}
+	{
+		// Dedupe with one primary key + `__artie_updated_at` flag.
+		tableID := NewTableIdentifier("db", "public", "customers")
+		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
+
+		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
+		assert.Len(s.T(), parts, 3)
+		assert.Equal(
+			s.T(),
+			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()),
+			parts[0],
+		)
+		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1])
+		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
+	}
+	{
+		// Dedupe with composite keys + no `__artie_updated_at` flag.
+		tableID := NewTableIdentifier("db", "public", "user_settings")
+		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
+
+		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{})
+		assert.Len(s.T(), parts, 3)
+		assert.Equal(
+			s.T(),
+			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.Table()),
+			parts[0],
+		)
+		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1])
+		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
+	}
+	{
+		// Dedupe with composite keys + `__artie_updated_at` flag.
+		tableID := NewTableIdentifier("db", "public", "user_settings")
+		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
+
+		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
+		assert.Len(s.T(), parts, 3)
+		assert.Equal(
+			s.T(),
+			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()),
+			parts[0],
+		)
+		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1])
+		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
+	}
+}
diff --git a/lib/destination/dwh.go b/lib/destination/dwh.go
index 9a01bf50a..3a0acfa40 100644
--- a/lib/destination/dwh.go
+++ b/lib/destination/dwh.go
@@ -13,7 +13,7 @@ type DataWarehouse interface {
 	Label() constants.DestinationKind
 	Merge(tableData *optimization.TableData) error
 	Append(tableData *optimization.TableData) error
-	Dedupe(tableID types.TableIdentifier) error
+	Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error
 	Exec(query string, args ...any) (sql.Result, error)
 	Query(query string, args ...any) (*sql.Rows, error)
 	Begin() (*sql.Tx, error)
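
For reference, a minimal sketch of the statement sequence the new Snowflake Dedupe runs, taken from the first test case above and reformatted for readability; "CUSTOMERS__ARTIE_STAGING" is a hypothetical stand-in for the randomly suffixed staging table name generated at runtime:

    -- Stage one representative row per duplicated key (ROW_NUMBER() = 2 exists only for keys with duplicates).
    CREATE OR REPLACE TRANSIENT TABLE "CUSTOMERS__ARTIE_STAGING" AS (
        SELECT * FROM db.public."CUSTOMERS"
        QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2
    );
    -- Remove every row for the flagged keys; keys without duplicates are untouched.
    DELETE FROM db.public."CUSTOMERS" t1 USING "CUSTOMERS__ARTIE_STAGING" t2 WHERE t1.id = t2.id;
    -- Re-insert the single staged copy, leaving exactly one row per formerly duplicated key.
    INSERT INTO db.public."CUSTOMERS" SELECT * FROM "CUSTOMERS__ARTIE_STAGING";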