Improve Snowflake Dedupe. #500

Merged 15 commits on Apr 24, 2024
2 changes: 1 addition & 1 deletion clients/bigquery/bigquery.go
@@ -143,7 +143,7 @@ func (s *Store) putTable(ctx context.Context, tableID types.TableIdentifier, row
return nil
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
fqTableName := tableID.FullyQualifiedName()
_, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName))
return err
2 changes: 1 addition & 1 deletion clients/mssql/store.go
@@ -70,7 +70,7 @@ func (s *Store) Sweep() error {
return shared.Sweep(s, tcs, queryFunc)
}

func (s *Store) Dedupe(_ types.TableIdentifier) error {
func (s *Store) Dedupe(_ types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
return nil // dedupe is not necessary for MS SQL
}

2 changes: 1 addition & 1 deletion clients/redshift/redshift.go
@@ -97,7 +97,7 @@ WHERE
return shared.Sweep(s, tcs, queryFunc)
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
fqTableName := tableID.FullyQualifiedName()
stagingTableName := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))).FullyQualifiedName()

10 changes: 6 additions & 4 deletions clients/shared/table_config_test.go
@@ -57,10 +57,12 @@ func TestGetTableCfgArgs_ShouldParseComment(t *testing.T) {

type MockDWH struct{}

func (MockDWH) Label() constants.DestinationKind { panic("not implemented") }
func (MockDWH) Merge(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Append(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Dedupe(tableID types.TableIdentifier) error { panic("not implemented") }
func (MockDWH) Label() constants.DestinationKind { panic("not implemented") }
func (MockDWH) Merge(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Append(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error {
panic("not implemented")
}
func (MockDWH) Exec(query string, args ...any) (sql.Result, error) { panic("not implemented") }
func (MockDWH) Query(query string, args ...any) (*sql.Rows, error) { panic("not implemented") }
func (MockDWH) Begin() (*sql.Tx, error) { panic("not implemented") }
81 changes: 77 additions & 4 deletions clients/snowflake/snowflake.go
@@ -2,6 +2,8 @@ package snowflake

import (
"fmt"
"log/slog"
"strings"

"github.com/snowflakedb/gosnowflake"

@@ -13,6 +15,10 @@ import (
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)

const maxRetries = 10
@@ -119,10 +125,77 @@ func (s *Store) reestablishConnection() error {
return nil
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
fqTableName := tableID.FullyQualifiedName()
_, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName))
return err
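// generateDedupeQueries builds the three statements used to dedupe tableID in place:
// 1. Stage one surviving copy of every duplicated primary key into a transient table
//    (ROW_NUMBER() = 2 picks the second copy, so keys that appear only once are left untouched).
// 2. Delete every row in the target table that shares a primary key with a staged row.
// 3. Insert the staged rows back into the target table.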
func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string {
var primaryKeysEscaped []string
for _, pk := range primaryKeys {
pkCol := columns.NewColumn(pk, typing.Invalid)
primaryKeysEscaped = append(primaryKeysEscaped, pkCol.Name(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: s.Label()}))
}

orderColsToIterate := primaryKeysEscaped
if topicConfig.IncludeArtieUpdatedAt {
orderColsToIterate = append(orderColsToIterate, constants.UpdateColumnMarker)
}

var orderByCols []string
for _, pk := range orderColsToIterate {
orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", pk))
}

temporaryTableName := sql.EscapeName(stagingTableID.Table(), s.ShouldUppercaseEscapedNames(), s.Label())
var parts []string
parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) = 2)",
temporaryTableName,
tableID.FullyQualifiedName(),
strings.Join(primaryKeysEscaped, ", "),
strings.Join(orderByCols, ", "),
))

var whereClauses []string
for _, primaryKeyEscaped := range primaryKeysEscaped {
whereClauses = append(whereClauses, fmt.Sprintf("t1.%s = t2.%s", primaryKeyEscaped, primaryKeyEscaped))
}

parts = append(parts, fmt.Sprintf("DELETE FROM %s t1 USING %s t2 WHERE %s",
tableID.FullyQualifiedName(),
temporaryTableName,
strings.Join(whereClauses, " AND "),
))

parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), temporaryTableName))
return parts
}

// Dedupe takes a table and will remove duplicates based on the primary key(s).
// These queries are inspired and modified from: https://stackoverflow.com/a/71515946
func (s *Store) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error {
var txCommitted bool
tx, err := s.Begin()
if err != nil {
return fmt.Errorf("failed to start a tx: %w", err)
}

defer func() {
if !txCommitted {
if err = tx.Rollback(); err != nil {
slog.Warn("Failed to rollback tx", slog.Any("err", err))
}
}
}()

stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
for _, part := range s.generateDedupeQueries(tableID, stagingTableID, primaryKeys, topicConfig) {
if _, err = tx.Exec(part); err != nil {
return fmt.Errorf("failed to execute tx, query: %q, err: %w", part, err)
}
}

if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit tx: %w", err)
}

txCommitted = true
return nil
}

func LoadSnowflake(cfg config.Config, _store *db.Store) (*Store, error) {
74 changes: 74 additions & 0 deletions clients/snowflake/snowflake_dedupe_test.go
@@ -0,0 +1,74 @@
package snowflake

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/clients/shared"
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/stretchr/testify/assert"
)

func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
{
// Dedupe with one primary key + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
Review comment (Contributor): Might be nice to have two methods TempTableID and TempTableIDWithSuffix.

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2)`, stagingTableID.Table()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
}
{
// Dedupe with one primary key + `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
}
{
// Dedupe with composite keys + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.Table()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
}
{
// Dedupe with composite keys + `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
}
}
2 changes: 1 addition & 1 deletion lib/destination/dwh.go
@@ -13,7 +13,7 @@ type DataWarehouse interface {
Label() constants.DestinationKind
Merge(tableData *optimization.TableData) error
Append(tableData *optimization.TableData) error
Dedupe(tableID types.TableIdentifier) error
Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error
Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
Begin() (*sql.Tx, error)
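
For reference, a minimal sketch of how a caller might drive the updated interface. The helper name, the hard-coded table coordinates, and the single "id" primary key are illustrative assumptions, not code from this PR; it also assumes the DataWarehouse interface lives in lib/destination and that dest was constructed elsewhere (for example via LoadSnowflake).

package example

import (
	"fmt"

	"github.com/artie-labs/transfer/clients/snowflake"
	"github.com/artie-labs/transfer/lib/destination"
	"github.com/artie-labs/transfer/lib/kafkalib"
)

// dedupeCustomers is a hypothetical helper that removes duplicate rows from
// db.public.customers, breaking ties on __artie_updated_at because the topic
// config says that column is present.
func dedupeCustomers(dest destination.DataWarehouse) error {
	tableID := snowflake.NewTableIdentifier("db", "public", "customers")
	topicConfig := kafkalib.TopicConfig{IncludeArtieUpdatedAt: true}
	if err := dest.Dedupe(tableID, []string{"id"}, topicConfig); err != nil {
		return fmt.Errorf("failed to dedupe: %w", err)
	}
	return nil
}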