Improve Snowflake Dedupe. #500

Merged: 15 commits, Apr 24, 2024
2 changes: 1 addition & 1 deletion clients/bigquery/bigquery.go
@@ -143,7 +143,7 @@ func (s *Store) putTable(ctx context.Context, tableID types.TableIdentifier, row
return nil
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
fqTableName := tableID.FullyQualifiedName()
_, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName))
return err
2 changes: 1 addition & 1 deletion clients/mssql/store.go
@@ -70,7 +70,7 @@ func (s *Store) Sweep() error {
return shared.Sweep(s, tcs, queryFunc)
}

func (s *Store) Dedupe(_ types.TableIdentifier) error {
func (s *Store) Dedupe(_ types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
return nil // dedupe is not necessary for MS SQL
}

2 changes: 1 addition & 1 deletion clients/redshift/redshift.go
@@ -97,7 +97,7 @@ WHERE
return shared.Sweep(s, tcs, queryFunc)
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
fqTableName := tableID.FullyQualifiedName()
stagingTableName := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))).FullyQualifiedName()

10 changes: 6 additions & 4 deletions clients/shared/table_config_test.go
@@ -57,10 +57,12 @@ func TestGetTableCfgArgs_ShouldParseComment(t *testing.T) {

type MockDWH struct{}

func (MockDWH) Label() constants.DestinationKind { panic("not implemented") }
func (MockDWH) Merge(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Append(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Dedupe(tableID types.TableIdentifier) error { panic("not implemented") }
func (MockDWH) Label() constants.DestinationKind { panic("not implemented") }
func (MockDWH) Merge(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Append(tableData *optimization.TableData) error { panic("not implemented") }
func (MockDWH) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error {
panic("not implemented")
}
func (MockDWH) Exec(query string, args ...any) (sql.Result, error) { panic("not implemented") }
func (MockDWH) Query(query string, args ...any) (*sql.Rows, error) { panic("not implemented") }
func (MockDWH) Begin() (*sql.Tx, error) { panic("not implemented") }
76 changes: 73 additions & 3 deletions clients/snowflake/snowflake.go
@@ -2,6 +2,7 @@ package snowflake

import (
"fmt"
"strings"

"github.com/snowflakedb/gosnowflake"

@@ -13,6 +14,9 @@ import (
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)

const maxRetries = 10
@@ -119,10 +123,76 @@ func (s *Store) reestablishConnection() error {
return nil
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string {
var primaryKeysEscaped []string
for _, pk := range primaryKeys {
pkCol := columns.NewColumn(pk, typing.Invalid)

primaryKeysEscaped = append(primaryKeysEscaped, pkCol.Name(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{DestKind: s.Label()}))
}

orderColsToIterate := primaryKeysEscaped
if topicConfig.IncludeArtieUpdatedAt {
orderColsToIterate = append(orderColsToIterate, constants.UpdateColumnMarker)
}

var orderByCols []string
for _, pk := range orderColsToIterate {
orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", pk))
}

fqTableName := tableID.FullyQualifiedName()
Contributor: tableID.FullyQualifiedName() is so short now that it doesn't take arguments; we could just inline it instead of assigning to fqTableName.

_, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName))
return err
var parts []string
parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY by %s ORDER BY %s) = 2)",
stagingTableID.Table(),
Contributor: Use stagingTableID.FullyQualifiedName() or escape stagingTableID.Table() with sql.EscapeName.

Contributor (author): This is a transient table; it doesn't belong to a schema or database.

fqTableName,
strings.Join(primaryKeysEscaped, ", "),
strings.Join(orderByCols, ", "),
))

var whereClauses []string
for _, primaryKeyEscaped := range primaryKeysEscaped {
whereClauses = append(whereClauses, fmt.Sprintf("t1.%s = t2.%s", primaryKeyEscaped, primaryKeyEscaped))
}

parts = append(parts, fmt.Sprintf("DELETE FROM %s t1 USING %s t2 WHERE %s",
fqTableName,
stagingTableID.Table(),
Contributor: Ditto, escape or use FullyQualifiedName().

strings.Join(whereClauses, " AND "),
))

parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", fqTableName, stagingTableID.Table()))
return parts
}

// Dedupe takes a table and will remove duplicates based on the primary key(s).
// These queries are inspired by and modified from: https://stackoverflow.com/a/71515946
func (s *Store) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error {
var txCommitted bool
tx, err := s.Begin()
if err != nil {
return fmt.Errorf("failed to start a tx: %w", err)
}

defer func() {
if !txCommitted {
tx.Rollback()
}
}()

stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
for _, part := range s.generateDedupeQueries(tableID, stagingTableID, primaryKeys, topicConfig) {
if _, err = tx.Exec(part); err != nil {
return fmt.Errorf("failed to execute tx, query: %q, err: %w", part, err)
}
}

if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit tx: %w", err)
}

txCommitted = true
return nil
}

func LoadSnowflake(cfg config.Config, _store *db.Store) (*Store, error) {
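For context on the statements generated above: the transient staging table holds exactly one extra copy of each duplicated key (the ROW_NUMBER() = 2 row), the DELETE ... USING statement then drops every row in the target whose key appears in staging, and the final INSERT puts a single copy of each of those keys back. The sketch below spells that out as Go string constants for a hypothetical db.public.customers table keyed on id; identifier escaping, the randomly suffixed staging-table name, and the optional __artie_updated_at ordering column are left out, so treat it as an illustration rather than the exact output of generateDedupeQueries.

package example

// Illustration only: the three statement shapes produced by the Snowflake dedupe,
// written against a hypothetical db.public.customers table with primary key id.
// "<staging>" stands in for the randomly suffixed transient table name.
const (
	// 1. Stage exactly one extra copy of every duplicated key (the ROW_NUMBER() = 2 row).
	stageDuplicates = `CREATE OR REPLACE TRANSIENT TABLE <staging> AS (
SELECT * FROM db.public.customers
QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2)`

	// 2. Remove every row in the target whose key appears in the staging table.
	deleteMatchingKeys = `DELETE FROM db.public.customers t1 USING <staging> t2 WHERE t1.id = t2.id`

	// 3. Re-insert a single copy of each of those keys.
	reinsertSurvivors = `INSERT INTO db.public.customers SELECT * FROM <staging>`
)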
74 changes: 74 additions & 0 deletions clients/snowflake/snowflake_dedupe_test.go
@@ -0,0 +1,74 @@
package snowflake

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/clients/shared"
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/stretchr/testify/assert"
)

func (s *SnowflakeTestSuite) TestDedupe() {
{
// Dedupe with one primary key + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
Contributor: Might be nice to have two methods TempTableID and TempTableIDWithSuffix.


parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public.customers QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC) = 2)", stagingTableID.Table()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf("DELETE FROM db.public.customers t1 USING %s t2 WHERE t1.id = t2.id", stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf("INSERT INTO db.public.customers SELECT * FROM %s", stagingTableID.Table()), parts[2])
}
{
// Dedupe with one primary key + `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public.customers QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC, __artie_updated_at ASC) = 2)", stagingTableID.Table()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf("DELETE FROM db.public.customers t1 USING %s t2 WHERE t1.id = t2.id", stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf("INSERT INTO db.public.customers SELECT * FROM %s", stagingTableID.Table()), parts[2])
}
{
// Dedupe with composite keys + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public.user_settings QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC) = 2)", stagingTableID.Table()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf("DELETE FROM db.public.user_settings t1 USING %s t2 WHERE t1.user_id = t2.user_id, t1.settings = t2.settings", stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf("INSERT INTO db.public.user_settings SELECT * FROM %s", stagingTableID.Table()), parts[2])
}
{
// Dedupe with composite keys + `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public.user_settings QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)", stagingTableID.Table()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf("DELETE FROM db.public.user_settings t1 USING %s t2 WHERE t1.user_id = t2.user_id, t1.settings = t2.settings", stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf("INSERT INTO db.public.user_settings SELECT * FROM %s", stagingTableID.Table()), parts[2])
}
}
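Picking up the review note above about TempTableID: a rough sketch of the suggested split into TempTableID and TempTableIDWithSuffix, assuming the current two-argument helper simply moves under the new name. The WithTable call and the naming scheme in the comment are assumptions for illustration, not code from this PR.

package shared

import (
	"strings"

	"github.com/artie-labs/transfer/lib/destination/types"
	"github.com/artie-labs/transfer/lib/stringutil"
)

// TempTableIDWithSuffix would take an explicit suffix, which keeps test output deterministic.
// The body here is a placeholder; the real naming logic is whatever today's TempTableID does.
func TempTableIDWithSuffix(tableID types.TableIdentifier, suffix string) types.TableIdentifier {
	return tableID.WithTable(tableID.Table() + "_" + suffix) // assumed helper and naming scheme
}

// TempTableID would keep the current behavior of generating a random lowercase suffix.
func TempTableID(tableID types.TableIdentifier) types.TableIdentifier {
	return TempTableIDWithSuffix(tableID, strings.ToLower(stringutil.Random(5)))
}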
2 changes: 1 addition & 1 deletion lib/destination/dwh.go
@@ -13,7 +13,7 @@ type DataWarehouse interface {
Label() constants.DestinationKind
Merge(tableData *optimization.TableData) error
Append(tableData *optimization.TableData) error
Dedupe(tableID types.TableIdentifier) error
Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error
Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
Begin() (*sql.Tx, error)
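With the interface change above, callers now pass the primary keys and topic config alongside the table identifier. A minimal sketch of a call site follows, assuming the destination and kafkalib import paths shown; dedupeTable, its arguments, and the hard-coded primary key are hypothetical and only illustrate the new signature.

package example

import (
	"fmt"

	"github.com/artie-labs/transfer/lib/destination"
	"github.com/artie-labs/transfer/lib/destination/types"
	"github.com/artie-labs/transfer/lib/kafkalib"
)

// dedupeTable is a hypothetical helper showing how the updated Dedupe signature is invoked.
func dedupeTable(dwh destination.DataWarehouse, tableID types.TableIdentifier, tc kafkalib.TopicConfig) error {
	// In practice the primary keys come from the table's configuration; "id" is a stand-in.
	primaryKeys := []string{"id"}
	if err := dwh.Dedupe(tableID, primaryKeys, tc); err != nil {
		return fmt.Errorf("failed to dedupe %s: %w", tableID.FullyQualifiedName(), err)
	}
	return nil
}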