Skip to content

Commit

Permalink
Improve Snowflake Dedupe. (#500)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Apr 24, 2024
1 parent 97f1e00 commit 799e4c0
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 12 deletions.
2 changes: 1 addition & 1 deletion clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (s *Store) putTable(ctx context.Context, tableID types.TableIdentifier, row
return nil
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
fqTableName := tableID.FullyQualifiedName()
_, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName))
return err
Expand Down
2 changes: 1 addition & 1 deletion clients/mssql/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (s *Store) Sweep() error {
return shared.Sweep(s, tcs, queryFunc)
}

func (s *Store) Dedupe(_ types.TableIdentifier) error {
// Dedupe is a deliberate no-op for Microsoft SQL Server: the merge path keeps
// the destination free of duplicates, so all arguments are intentionally ignored.
func (s *Store) Dedupe(_ types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
	// dedupe is not necessary for MS SQL
	return nil
}

Expand Down
2 changes: 1 addition & 1 deletion clients/redshift/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ WHERE
return shared.Sweep(s, tcs, queryFunc)
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
fqTableName := tableID.FullyQualifiedName()
stagingTableName := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))).FullyQualifiedName()

Expand Down
10 changes: 6 additions & 4 deletions clients/shared/table_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ func TestGetTableCfgArgs_ShouldParseComment(t *testing.T) {

type MockDWH struct{}

func (MockDWH) Label() constants.DestinationKind { panic("not implemented") }
func (MockDWH) Merge(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Append(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Dedupe(tableID types.TableIdentifier) error { panic("not implemented") }
// MockDWH stubs the data-warehouse interface for table-config tests. The code
// under test never reaches a live warehouse, so every method panics if called —
// a panic here means a test exercised an unexpected path.
func (MockDWH) Label() constants.DestinationKind              { panic("not implemented") }
func (MockDWH) Merge(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Append(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error {
	panic("not implemented")
}
func (MockDWH) Exec(query string, args ...any) (sql.Result, error) { panic("not implemented") }
func (MockDWH) Query(query string, args ...any) (*sql.Rows, error) { panic("not implemented") }
func (MockDWH) Begin() (*sql.Tx, error)                            { panic("not implemented") }
Expand Down
81 changes: 77 additions & 4 deletions clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package snowflake

import (
"fmt"
"log/slog"
"strings"

"github.com/snowflakedb/gosnowflake"

Expand All @@ -13,6 +15,10 @@ import (
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)

const maxRetries = 10
Expand Down Expand Up @@ -119,10 +125,77 @@ func (s *Store) reestablishConnection() error {
return nil
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
fqTableName := tableID.FullyQualifiedName()
_, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName))
return err
// generateDedupeQueries builds the ordered statements that dedupe tableID in place:
//
//  1. Snapshot one surviving copy of each duplicated key group into a transient
//     staging table. ROW_NUMBER() = 2 matches exactly one row per key group, and
//     only when the key occurs more than once, so unique rows are left untouched.
//  2. Delete every row in the target whose primary key appears in staging
//     (this removes all copies of each duplicated key).
//  3. Re-insert the single surviving copy per key from staging.
//
// When topicConfig.IncludeArtieUpdatedAt is set, __artie_updated_at is appended
// to the ORDER BY so the retained row is chosen deterministically.
func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string {
	var primaryKeysEscaped []string
	for _, pk := range primaryKeys {
		pkCol := columns.NewColumn(pk, typing.Invalid)
		primaryKeysEscaped = append(primaryKeysEscaped, pkCol.Name(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: s.Label()}))
	}

	// Copy before appending so growing the ORDER BY list cannot alias the
	// backing array that primaryKeysEscaped is still read from below.
	orderColsToIterate := append([]string{}, primaryKeysEscaped...)
	if topicConfig.IncludeArtieUpdatedAt {
		orderColsToIterate = append(orderColsToIterate, constants.UpdateColumnMarker)
	}

	var orderByCols []string
	for _, col := range orderColsToIterate {
		orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", col))
	}

	temporaryTableName := sql.EscapeName(stagingTableID.Table(), s.ShouldUppercaseEscapedNames(), s.Label())

	var parts []string
	// BUGFIX: previously emitted "PARTITION BY by %s" — the duplicated keyword
	// is a Snowflake syntax error that made every dedupe statement fail.
	parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) = 2)",
		temporaryTableName,
		tableID.FullyQualifiedName(),
		strings.Join(primaryKeysEscaped, ", "),
		strings.Join(orderByCols, ", "),
	))

	var whereClauses []string
	for _, primaryKeyEscaped := range primaryKeysEscaped {
		whereClauses = append(whereClauses, fmt.Sprintf("t1.%s = t2.%s", primaryKeyEscaped, primaryKeyEscaped))
	}

	parts = append(parts, fmt.Sprintf("DELETE FROM %s t1 USING %s t2 WHERE %s",
		tableID.FullyQualifiedName(),
		temporaryTableName,
		strings.Join(whereClauses, " AND "),
	))

	parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), temporaryTableName))
	return parts
}

// Dedupe takes a table and will remove duplicates based on the primary key(s).
// These queries are inspired and modified from: https://stackoverflow.com/a/71515946
// All generated statements run inside a single transaction; on any failure the
// transaction is rolled back so the target table is never left half-deduped.
func (s *Store) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error {
	tx, err := s.Begin()
	if err != nil {
		return fmt.Errorf("failed to start a tx: %w", err)
	}

	committed := false
	defer func() {
		if committed {
			return
		}
		// Best-effort rollback; the original error (if any) has already been returned.
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			slog.Warn("Failed to rollback tx", slog.Any("err", rollbackErr))
		}
	}()

	stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
	for _, query := range s.generateDedupeQueries(tableID, stagingTableID, primaryKeys, topicConfig) {
		if _, execErr := tx.Exec(query); execErr != nil {
			return fmt.Errorf("failed to execute tx, query: %q, err: %w", query, execErr)
		}
	}

	if commitErr := tx.Commit(); commitErr != nil {
		return fmt.Errorf("failed to commit tx: %w", commitErr)
	}

	committed = true
	return nil
}

func LoadSnowflake(cfg config.Config, _store *db.Store) (*Store, error) {
Expand Down
74 changes: 74 additions & 0 deletions clients/snowflake/snowflake_dedupe_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package snowflake

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/clients/shared"
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/stretchr/testify/assert"
)

// TestGenerateDedupeQueries covers single and composite primary keys, with and
// without the __artie_updated_at ordering column.
// BUGFIX: the expected strings previously pinned "PARTITION BY by ..." — a
// duplicated keyword that is a Snowflake syntax error — which let the broken
// query builder pass; they now assert the valid "PARTITION BY ..." form.
func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
	{
		// Dedupe with one primary key + no `__artie_updated_at` flag.
		tableID := NewTableIdentifier("db", "public", "customers")
		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{})
		assert.Len(s.T(), parts, 3)
		assert.Equal(
			s.T(),
			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2)`, stagingTableID.Table()),
			parts[0],
		)
		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1])
		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
	}
	{
		// Dedupe with one primary key + `__artie_updated_at` flag.
		tableID := NewTableIdentifier("db", "public", "customers")
		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
		assert.Len(s.T(), parts, 3)
		assert.Equal(
			s.T(),
			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()),
			parts[0],
		)
		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1])
		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
	}
	{
		// Dedupe with composite keys + no `__artie_updated_at` flag.
		tableID := NewTableIdentifier("db", "public", "user_settings")
		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{})
		assert.Len(s.T(), parts, 3)
		assert.Equal(
			s.T(),
			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.Table()),
			parts[0],
		)
		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1])
		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
	}
	{
		// Dedupe with composite keys + `__artie_updated_at` flag.
		tableID := NewTableIdentifier("db", "public", "user_settings")
		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
		assert.Len(s.T(), parts, 3)
		assert.Equal(
			s.T(),
			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()),
			parts[0],
		)
		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1])
		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
	}
}
2 changes: 1 addition & 1 deletion lib/destination/dwh.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type DataWarehouse interface {
Label() constants.DestinationKind
Merge(tableData *optimization.TableData) error
Append(tableData *optimization.TableData) error
Dedupe(tableID types.TableIdentifier) error
Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error
Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
Begin() (*sql.Tx, error)
Expand Down

0 comments on commit 799e4c0

Please sign in to comment.