diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index 3c649f56d..97cdc36e2 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -5,6 +5,8 @@ import ( "fmt" "log/slog" "os" + "strings" + "time" "cloud.google.com/go/bigquery" _ "github.com/viant/bigquery" @@ -19,6 +21,9 @@ import ( "github.com/artie-labs/transfer/lib/logger" "github.com/artie-labs/transfer/lib/optimization" "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/stringutil" + "github.com/artie-labs/transfer/lib/typing" ) const ( @@ -143,10 +148,82 @@ func (s *Store) putTable(ctx context.Context, tableID types.TableIdentifier, row return nil } -func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error { - fqTableName := tableID.FullyQualifiedName() - _, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName)) - return err +func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string { + var primaryKeysEscaped []string + for _, pk := range primaryKeys { + primaryKeysEscaped = append(primaryKeysEscaped, sql.EscapeNameIfNecessary(pk, s.ShouldUppercaseEscapedNames(), s.Label())) + } + + orderColsToIterate := primaryKeysEscaped + if topicConfig.IncludeArtieUpdatedAt { + orderColsToIterate = append(orderColsToIterate, sql.EscapeNameIfNecessary(constants.UpdateColumnMarker, s.ShouldUppercaseEscapedNames(), s.Label())) + } + + var orderByCols []string + for _, orderByCol := range orderColsToIterate { + orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", orderByCol)) + } + + var parts []string + parts = append(parts, + fmt.Sprintf(`CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP("%s")) AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) = 2)`, + stagingTableID.FullyQualifiedName(), + typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL)), + tableID.FullyQualifiedName(), + strings.Join(primaryKeysEscaped, ", "), + strings.Join(orderByCols, ", "), + ), + ) + + var whereClauses []string + for _, primaryKeyEscaped := range primaryKeysEscaped { + whereClauses = append(whereClauses, fmt.Sprintf("t1.%s = t2.%s", primaryKeyEscaped, primaryKeyEscaped)) + } + + // https://cloud.google.com/bigquery/docs/reference/standard-sql/dml-syntax#delete_with_subquery + parts = append(parts, + fmt.Sprintf("DELETE FROM %s t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE %s)", + tableID.FullyQualifiedName(), + stagingTableID.FullyQualifiedName(), + strings.Join(whereClauses, " AND "), + ), + ) + + parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), stagingTableID.FullyQualifiedName())) + return parts +} + +func (s *Store) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error { + stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))) + + var txCommitted bool + tx, err := s.Begin() + if err != nil { + return fmt.Errorf("failed to start a tx: %w", err) + } + + defer func() { + if !txCommitted { + if err = tx.Rollback(); err != nil { + slog.Warn("Failed to rollback tx", slog.Any("err", err)) + } + } + + _ = ddl.DropTemporaryTable(s, stagingTableID.FullyQualifiedName(), false) + }() + + for _, part := range s.generateDedupeQueries(tableID, stagingTableID, primaryKeys, topicConfig) { + if _, err = tx.Exec(part); err != nil { + return fmt.Errorf("failed to execute tx, query: %q, err: %w", part, err) + } + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("failed to commit tx: %w", err) + } + + txCommitted = true + return nil } func LoadBigQuery(cfg config.Config, _store *db.Store) (*Store, error) { diff --git a/clients/bigquery/bigquery_dedupe_test.go b/clients/bigquery/bigquery_dedupe_test.go new file mode 100644 index 000000000..ca7d898bd --- /dev/null +++ b/clients/bigquery/bigquery_dedupe_test.go @@ -0,0 +1,90 @@ +package bigquery + +import ( + "fmt" + "strings" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/artie-labs/transfer/clients/shared" + "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/kafkalib" + "github.com/artie-labs/transfer/lib/stringutil" + "github.com/artie-labs/transfer/lib/typing" +) + +func (b *BigQueryTestSuite) TestGenerateDedupeQueries() { + { + // Dedupe with one primary key + no `__artie_updated_at` flag. + tableID := NewTableIdentifier("project12", "public", "customers") + stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))) + + parts := b.store.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{}) + assert.Len(b.T(), parts, 3) + assert.Equal( + b.T(), + fmt.Sprintf("CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP(%s)) AS (SELECT * FROM `project12`.`public`.`customers` QUALIFY ROW_NUMBER() OVER (PARTITION BY `id` ORDER BY `id` ASC) = 2)", + stagingTableID.FullyQualifiedName(), + fmt.Sprintf(`"%s"`, typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL))), + ), + parts[0], + ) + assert.Equal(b.T(), fmt.Sprintf("DELETE FROM `project12`.`public`.`customers` t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE t1.`id` = t2.`id`)", stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(b.T(), fmt.Sprintf("INSERT INTO `project12`.`public`.`customers` SELECT * FROM %s", stagingTableID.FullyQualifiedName()), parts[2]) + } + { + // Dedupe with one primary key + `__artie_updated_at` flag. + tableID := NewTableIdentifier("project12", "public", "customers") + stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))) + + parts := b.store.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true}) + assert.Len(b.T(), parts, 3) + assert.Equal( + b.T(), + fmt.Sprintf("CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP(%s)) AS (SELECT * FROM `project12`.`public`.`customers` QUALIFY ROW_NUMBER() OVER (PARTITION BY `id` ORDER BY `id` ASC, __artie_updated_at ASC) = 2)", + stagingTableID.FullyQualifiedName(), + fmt.Sprintf(`"%s"`, typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL))), + ), + parts[0], + ) + assert.Equal(b.T(), fmt.Sprintf("DELETE FROM `project12`.`public`.`customers` t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE t1.`id` = t2.`id`)", stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(b.T(), fmt.Sprintf("INSERT INTO `project12`.`public`.`customers` SELECT * FROM %s", stagingTableID.FullyQualifiedName()), parts[2]) + } + { + // Dedupe with composite keys + no `__artie_updated_at` flag. + tableID := NewTableIdentifier("project123", "public", "user_settings") + stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))) + + parts := b.store.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{}) + assert.Len(b.T(), parts, 3) + assert.Equal( + b.T(), + fmt.Sprintf("CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP(%s)) AS (SELECT * FROM `project123`.`public`.`user_settings` QUALIFY ROW_NUMBER() OVER (PARTITION BY `user_id`, `settings` ORDER BY `user_id` ASC, `settings` ASC) = 2)", + stagingTableID.FullyQualifiedName(), + fmt.Sprintf(`"%s"`, typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL))), + ), + parts[0], + ) + assert.Equal(b.T(), fmt.Sprintf("DELETE FROM `project123`.`public`.`user_settings` t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE t1.`user_id` = t2.`user_id` AND t1.`settings` = t2.`settings`)", stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(b.T(), fmt.Sprintf("INSERT INTO `project123`.`public`.`user_settings` SELECT * FROM %s", stagingTableID.FullyQualifiedName()), parts[2]) + } + { + // Dedupe with composite keys + `__artie_updated_at` flag. + tableID := NewTableIdentifier("project123", "public", "user_settings") + stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))) + + parts := b.store.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true}) + assert.Len(b.T(), parts, 3) + assert.Equal( + b.T(), + fmt.Sprintf("CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP(%s)) AS (SELECT * FROM `project123`.`public`.`user_settings` QUALIFY ROW_NUMBER() OVER (PARTITION BY `user_id`, `settings` ORDER BY `user_id` ASC, `settings` ASC, __artie_updated_at ASC) = 2)", + stagingTableID.FullyQualifiedName(), + fmt.Sprintf(`"%s"`, typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL))), + ), + parts[0], + ) + assert.Equal(b.T(), fmt.Sprintf("DELETE FROM `project123`.`public`.`user_settings` t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE t1.`user_id` = t2.`user_id` AND t1.`settings` = t2.`settings`)", stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(b.T(), fmt.Sprintf("INSERT INTO `project123`.`public`.`user_settings` SELECT * FROM %s", stagingTableID.FullyQualifiedName()), parts[2]) + } +} diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go index 1f3f2c6fe..047da9dc0 100644 --- a/clients/snowflake/snowflake.go +++ b/clients/snowflake/snowflake.go @@ -15,9 +15,8 @@ import ( "github.com/artie-labs/transfer/lib/kafkalib" "github.com/artie-labs/transfer/lib/optimization" "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/stringutil" - "github.com/artie-labs/transfer/lib/typing" - "github.com/artie-labs/transfer/lib/typing/columns" ) const maxRetries = 10 @@ -131,13 +130,12 @@ func (s *Store) reestablishConnection() error { func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string { var primaryKeysEscaped []string for _, pk := range primaryKeys { - pkCol := columns.NewColumn(pk, typing.Invalid) - primaryKeysEscaped = append(primaryKeysEscaped, pkCol.Name(s.ShouldUppercaseEscapedNames(), s.Label())) + primaryKeysEscaped = append(primaryKeysEscaped, sql.EscapeNameIfNecessary(pk, s.ShouldUppercaseEscapedNames(), s.Label())) } orderColsToIterate := primaryKeysEscaped if topicConfig.IncludeArtieUpdatedAt { - orderColsToIterate = append(orderColsToIterate, constants.UpdateColumnMarker) + orderColsToIterate = append(orderColsToIterate, sql.EscapeNameIfNecessary(constants.UpdateColumnMarker, s.ShouldUppercaseEscapedNames(), s.Label())) } var orderByCols []string