
Commit

Merge branch 'master' into nv/always-uppercase-escaped-snowflake-names
Signed-off-by: Nathan <[email protected]>
nathan-artie authored May 2, 2024
2 parents 282335e + 4153191 commit 95dd0da
Showing 50 changed files with 1,048 additions and 1,175 deletions.
104 changes: 90 additions & 14 deletions clients/bigquery/bigquery.go
@@ -5,6 +5,8 @@ import (
"fmt"
"log/slog"
"os"
"strings"
"time"

"cloud.google.com/go/bigquery"
_ "github.com/viant/bigquery"
@@ -19,6 +21,9 @@ import (
"github.com/artie-labs/transfer/lib/logger"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
)

const (
@@ -39,14 +44,13 @@ type Store struct {
func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dwh: s,
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
UppercaseEscNames: ptr.ToBool(s.ShouldUppercaseEscapedNames()),
Mode: tableData.Mode(),
Dwh: s,
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
Mode: tableData.Mode(),
}

if err := tempAlterTableArgs.AlterTable(tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
@@ -110,8 +114,8 @@ func (s *Store) Label() constants.DestinationKind {
return constants.BigQuery
}

func (s *Store) ShouldUppercaseEscapedNames() bool {
return false
func (s *Store) Dialect() sql.Dialect {
return sql.BigQueryDialect{}
}

func (s *Store) GetClient(ctx context.Context) *bigquery.Client {
@@ -143,10 +147,82 @@ func (s *Store) putTable(ctx context.Context, tableID types.TableIdentifier, row
return nil
}

func (s *Store) Dedupe(tableID types.TableIdentifier, _ []string, _ kafkalib.TopicConfig) error {
fqTableName := tableID.FullyQualifiedName()
_, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName))
return err
func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string {
var primaryKeysEscaped []string
for _, pk := range primaryKeys {
primaryKeysEscaped = append(primaryKeysEscaped, s.Dialect().QuoteIdentifier(pk))
}

orderColsToIterate := primaryKeysEscaped
if topicConfig.IncludeArtieUpdatedAt {
orderColsToIterate = append(orderColsToIterate, s.Dialect().QuoteIdentifier(constants.UpdateColumnMarker))
}

var orderByCols []string
for _, orderByCol := range orderColsToIterate {
orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", orderByCol))
}

var parts []string
parts = append(parts,
fmt.Sprintf(`CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP("%s")) AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) = 2)`,
stagingTableID.FullyQualifiedName(),
typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL)),
tableID.FullyQualifiedName(),
strings.Join(primaryKeysEscaped, ", "),
strings.Join(orderByCols, ", "),
),
)

var whereClauses []string
for _, primaryKeyEscaped := range primaryKeysEscaped {
whereClauses = append(whereClauses, fmt.Sprintf("t1.%s = t2.%s", primaryKeyEscaped, primaryKeyEscaped))
}

// https://cloud.google.com/bigquery/docs/reference/standard-sql/dml-syntax#delete_with_subquery
parts = append(parts,
fmt.Sprintf("DELETE FROM %s t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE %s)",
tableID.FullyQualifiedName(),
stagingTableID.FullyQualifiedName(),
strings.Join(whereClauses, " AND "),
),
)

parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), stagingTableID.FullyQualifiedName()))
return parts
}

func (s *Store) Dedupe(tableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) error {
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

var txCommitted bool
tx, err := s.Begin()
if err != nil {
return fmt.Errorf("failed to start a tx: %w", err)
}

defer func() {
if !txCommitted {
if err = tx.Rollback(); err != nil {
slog.Warn("Failed to rollback tx", slog.Any("err", err))
}
}

_ = ddl.DropTemporaryTable(s, stagingTableID.FullyQualifiedName(), false)
}()

for _, part := range s.generateDedupeQueries(tableID, stagingTableID, primaryKeys, topicConfig) {
if _, err = tx.Exec(part); err != nil {
return fmt.Errorf("failed to execute tx, query: %q, err: %w", part, err)
}
}

if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit tx: %w", err)
}

txCommitted = true
return nil
}

func LoadBigQuery(cfg config.Config, _store *db.Store) (*Store, error) {
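For context, the new Dedupe flow replaces the old single-statement "CREATE OR REPLACE TABLE ... AS SELECT DISTINCT *" rewrite with three statements run inside one transaction: snapshot one copy of each duplicated primary key into an expiring staging table, delete those keys from the target table, then re-insert the surviving copies. A minimal sketch of the generated statements for a table customers with primary key id and IncludeArtieUpdatedAt enabled, mirroring the test expectations below — the project and dataset names, staging-table name, and expiration timestamp are placeholders, not values produced by this commit:

// Sketch only: the three statements Dedupe executes in order. The staging
// name ("customers___artie_ab1cd") and the expiration timestamp are
// illustrative; the real values are generated at run time.
package example

var exampleDedupeQueries = []string{
	// 1. Snapshot the duplicate rows (row number 2 per primary key) into an
	//    expiring staging table.
	"CREATE OR REPLACE TABLE `proj`.`dataset`.`customers___artie_ab1cd` " +
		`OPTIONS (expiration_timestamp = TIMESTAMP("2024-05-03 00:00:00 UTC")) AS ` +
		"(SELECT * FROM `proj`.`dataset`.`customers` " +
		"QUALIFY ROW_NUMBER() OVER (PARTITION BY `id` ORDER BY `id` ASC, `__artie_updated_at` ASC) = 2)",

	// 2. Delete every row in the target table whose primary key appears in the
	//    staging table, i.e. every key that had duplicates.
	"DELETE FROM `proj`.`dataset`.`customers` t1 WHERE EXISTS " +
		"(SELECT * FROM `proj`.`dataset`.`customers___artie_ab1cd` t2 WHERE t1.`id` = t2.`id`)",

	// 3. Re-insert exactly one surviving copy of each deduplicated key.
	"INSERT INTO `proj`.`dataset`.`customers` SELECT * FROM `proj`.`dataset`.`customers___artie_ab1cd`",
}

If any statement fails, the deferred block in Dedupe rolls back the transaction and drops the staging table.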
90 changes: 90 additions & 0 deletions clients/bigquery/bigquery_dedupe_test.go
@@ -0,0 +1,90 @@
package bigquery

import (
"fmt"
"strings"
"time"

"github.com/stretchr/testify/assert"

"github.com/artie-labs/transfer/clients/shared"
"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
)

func (b *BigQueryTestSuite) TestGenerateDedupeQueries() {
{
// Dedupe with one primary key + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("project12", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := b.store.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{})
assert.Len(b.T(), parts, 3)
assert.Equal(
b.T(),
fmt.Sprintf("CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP(%s)) AS (SELECT * FROM `project12`.`public`.`customers` QUALIFY ROW_NUMBER() OVER (PARTITION BY `id` ORDER BY `id` ASC) = 2)",
stagingTableID.FullyQualifiedName(),
fmt.Sprintf(`"%s"`, typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL))),
),
parts[0],
)
assert.Equal(b.T(), fmt.Sprintf("DELETE FROM `project12`.`public`.`customers` t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE t1.`id` = t2.`id`)", stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(b.T(), fmt.Sprintf("INSERT INTO `project12`.`public`.`customers` SELECT * FROM %s", stagingTableID.FullyQualifiedName()), parts[2])
}
{
// Dedupe with one primary key + `__artie_updated_at` flag.
tableID := NewTableIdentifier("project12", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := b.store.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(b.T(), parts, 3)
assert.Equal(
b.T(),
fmt.Sprintf("CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP(%s)) AS (SELECT * FROM `project12`.`public`.`customers` QUALIFY ROW_NUMBER() OVER (PARTITION BY `id` ORDER BY `id` ASC, `__artie_updated_at` ASC) = 2)",
stagingTableID.FullyQualifiedName(),
fmt.Sprintf(`"%s"`, typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL))),
),
parts[0],
)
assert.Equal(b.T(), fmt.Sprintf("DELETE FROM `project12`.`public`.`customers` t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE t1.`id` = t2.`id`)", stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(b.T(), fmt.Sprintf("INSERT INTO `project12`.`public`.`customers` SELECT * FROM %s", stagingTableID.FullyQualifiedName()), parts[2])
}
{
// Dedupe with composite keys + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("project123", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := b.store.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{})
assert.Len(b.T(), parts, 3)
assert.Equal(
b.T(),
fmt.Sprintf("CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP(%s)) AS (SELECT * FROM `project123`.`public`.`user_settings` QUALIFY ROW_NUMBER() OVER (PARTITION BY `user_id`, `settings` ORDER BY `user_id` ASC, `settings` ASC) = 2)",
stagingTableID.FullyQualifiedName(),
fmt.Sprintf(`"%s"`, typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL))),
),
parts[0],
)
assert.Equal(b.T(), fmt.Sprintf("DELETE FROM `project123`.`public`.`user_settings` t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE t1.`user_id` = t2.`user_id` AND t1.`settings` = t2.`settings`)", stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(b.T(), fmt.Sprintf("INSERT INTO `project123`.`public`.`user_settings` SELECT * FROM %s", stagingTableID.FullyQualifiedName()), parts[2])
}
{
// Dedupe with composite keys + `__artie_updated_at` flag.
tableID := NewTableIdentifier("project123", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))

parts := b.store.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(b.T(), parts, 3)
assert.Equal(
b.T(),
fmt.Sprintf("CREATE OR REPLACE TABLE %s OPTIONS (expiration_timestamp = TIMESTAMP(%s)) AS (SELECT * FROM `project123`.`public`.`user_settings` QUALIFY ROW_NUMBER() OVER (PARTITION BY `user_id`, `settings` ORDER BY `user_id` ASC, `settings` ASC, `__artie_updated_at` ASC) = 2)",
stagingTableID.FullyQualifiedName(),
fmt.Sprintf(`"%s"`, typing.ExpiresDate(time.Now().UTC().Add(constants.TemporaryTableTTL))),
),
parts[0],
)
assert.Equal(b.T(), fmt.Sprintf("DELETE FROM `project123`.`public`.`user_settings` t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE t1.`user_id` = t2.`user_id` AND t1.`settings` = t2.`settings`)", stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(b.T(), fmt.Sprintf("INSERT INTO `project123`.`public`.`user_settings` SELECT * FROM %s", stagingTableID.FullyQualifiedName()), parts[2])
}
}
9 changes: 8 additions & 1 deletion clients/bigquery/tableid.go
@@ -4,8 +4,11 @@ import (
"fmt"

"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/sql"
)

var dialect = sql.BigQueryDialect{}

type TableIdentifier struct {
projectID string
dataset string
@@ -39,5 +42,9 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
func (ti TableIdentifier) FullyQualifiedName() string {
// The fully qualified name for BigQuery is: project_id.dataset.tableName.
// We are escaping the project_id, dataset, and table because there could be special characters.
return fmt.Sprintf("`%s`.`%s`.`%s`", ti.projectID, ti.dataset, ti.table)
return fmt.Sprintf("%s.%s.%s",
dialect.QuoteIdentifier(ti.projectID),
dialect.QuoteIdentifier(ti.dataset),
dialect.QuoteIdentifier(ti.table),
)
}
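BigQuery identifier quoting now goes through sql.BigQueryDialect instead of a hard-coded backtick format string; the rendered name is unchanged. A small sketch of the behaviour, modelled on the literals asserted in the dedupe test above and written as an example-style test assumed to live in the clients/bigquery package:

package bigquery

import "fmt"

func ExampleTableIdentifier_FullyQualifiedName() {
	// NewTableIdentifier takes project, dataset, and table, as in the tests above.
	tableID := NewTableIdentifier("project12", "public", "customers")
	fmt.Println(tableID.FullyQualifiedName())
	// Output: `project12`.`public`.`customers`
}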
16 changes: 7 additions & 9 deletions clients/mssql/staging.go
@@ -9,20 +9,18 @@ import (
"github.com/artie-labs/transfer/lib/destination/ddl"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
)

func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dwh: s,
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
UppercaseEscNames: ptr.ToBool(s.ShouldUppercaseEscapedNames()),
Mode: tableData.Mode(),
Dwh: s,
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
Mode: tableData.Mode(),
}

if err := tempAlterTableArgs.AlterTable(tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
5 changes: 3 additions & 2 deletions clients/mssql/store.go
@@ -13,6 +13,7 @@ import (
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"
)

type Store struct {
@@ -34,8 +35,8 @@ func (s *Store) Label() constants.DestinationKind {
return constants.MSSQL
}

func (s *Store) ShouldUppercaseEscapedNames() bool {
return false
func (s *Store) Dialect() sql.Dialect {
return sql.MSSQLDialect{}
}

func (s *Store) Merge(tableData *optimization.TableData) error {
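Across BigQuery, MSSQL, and Redshift the per-store ShouldUppercaseEscapedNames() boolean is replaced by a Dialect() accessor. Only QuoteIdentifier is exercised in this excerpt, so the sketch below assumes a minimal interface; the real lib/sql.Dialect may declare more methods. It mirrors how generateDedupeQueries escapes primary keys above:

package example

// Dialect is the assumed minimal shape of the abstraction returned by each
// store's Dialect() method in this commit.
type Dialect interface {
	QuoteIdentifier(identifier string) string
}

// escapeColumns quotes every column name for the destination's dialect, the
// same pattern generateDedupeQueries uses for primary keys.
func escapeColumns(dialect Dialect, columns []string) []string {
	escaped := make([]string, 0, len(columns))
	for _, col := range columns {
		escaped = append(escaped, dialect.QuoteIdentifier(col))
	}
	return escaped
}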
9 changes: 3 additions & 6 deletions clients/mssql/tableid.go
@@ -3,11 +3,12 @@ package mssql
import (
"fmt"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/sql"
)

var dialect = sql.MSSQLDialect{}

type TableIdentifier struct {
schema string
table string
@@ -30,9 +31,5 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
}

func (ti TableIdentifier) FullyQualifiedName() string {
return fmt.Sprintf(
"%s.%s",
ti.schema,
sql.EscapeName(ti.table, false, constants.MSSQL),
)
return fmt.Sprintf("%s.%s", ti.schema, dialect.QuoteIdentifier(ti.table))
}
5 changes: 3 additions & 2 deletions clients/redshift/redshift.go
@@ -15,6 +15,7 @@ import (
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/stringutil"
)

@@ -45,8 +46,8 @@ func (s *Store) Label() constants.DestinationKind {
return constants.Redshift
}

func (s *Store) ShouldUppercaseEscapedNames() bool {
return false
func (s *Store) Dialect() sql.Dialect {
return sql.RedshiftDialect{}
}

func (s *Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTableConfig, error) {
2 changes: 1 addition & 1 deletion clients/redshift/redshift_test.go
@@ -26,5 +26,5 @@ func TestTempTableName(t *testing.T) {
tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table")
tableID := (&Store{}).IdentifierFor(tableData.TopicConfig(), tableData.Name())
tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName()
assert.Equal(t, `schema."table___artie_sUfFiX"`, trimTTL(tempTableName))
assert.Equal(t, `schema."table___artie_suffix"`, trimTTL(tempTableName))
}
16 changes: 7 additions & 9 deletions clients/redshift/staging.go
@@ -12,21 +12,19 @@ import (
"github.com/artie-labs/transfer/lib/destination/ddl"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/s3lib"
)

func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, _ bool) error {
// Redshift always creates a temporary table.
tempAlterTableArgs := ddl.AlterTableArgs{
Dwh: s,
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
UppercaseEscNames: ptr.ToBool(s.ShouldUppercaseEscapedNames()),
Mode: tableData.Mode(),
Dwh: s,
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
Mode: tableData.Mode(),
}

if err := tempAlterTableArgs.AlterTable(tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
