Skip to content

Commit

Permalink
Merge branch 'master' into nv/always-uppercase-escaped-snowflake-names
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed Apr 30, 2024
2 parents e01e7c5 + c9e9350 commit 282335e
Show file tree
Hide file tree
Showing 24 changed files with 301 additions and 234 deletions.
12 changes: 6 additions & 6 deletions clients/bigquery/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,20 @@ func (b *BigQueryTestSuite) TestBackfillColumn() {
{
name: "col that has default value that needs to be backfilled (boolean)",
col: needsBackfillCol,
backfillSQL: "UPDATE `db`.`public`.`tableName` SET foo = true WHERE foo IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN foo SET OPTIONS (description=`{\"backfilled\": true}`);",
backfillSQL: "UPDATE `db`.`public`.`tableName` SET `foo` = true WHERE `foo` IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN `foo` SET OPTIONS (description=`{\"backfilled\": true}`);",
},
{
name: "col that has default value that needs to be backfilled (string)",
col: needsBackfillColStr,
backfillSQL: "UPDATE `db`.`public`.`tableName` SET foo2 = 'hello there' WHERE foo2 IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN foo2 SET OPTIONS (description=`{\"backfilled\": true}`);",
backfillSQL: "UPDATE `db`.`public`.`tableName` SET `foo2` = 'hello there' WHERE `foo2` IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN `foo2` SET OPTIONS (description=`{\"backfilled\": true}`);",
},
{
name: "col that has default value that needs to be backfilled (number)",
col: needsBackfillColNum,
backfillSQL: "UPDATE `db`.`public`.`tableName` SET foo3 = 3.5 WHERE foo3 IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN foo3 SET OPTIONS (description=`{\"backfilled\": true}`);",
backfillSQL: "UPDATE `db`.`public`.`tableName` SET `foo3` = 3.5 WHERE `foo3` IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN `foo3` SET OPTIONS (description=`{\"backfilled\": true}`);",
},
}

Expand Down
14 changes: 7 additions & 7 deletions clients/shared/append.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package shared

import (
"log/slog"
"fmt"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination"
Expand All @@ -20,7 +20,7 @@ func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, op
tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name())
tableConfig, err := dwh.GetTableConfig(tableData)
if err != nil {
return err
return fmt.Errorf("failed to get table config: %w", err)
}

// We don't care about srcKeysMissing because we don't drop columns when we append.
Expand All @@ -40,13 +40,13 @@ func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, op
}

// Keys that exist in CDC stream, but not in DWH
err = createAlterTableArgs.AlterTable(targetKeysMissing...)
if err != nil {
slog.Warn("Failed to apply alter table", slog.Any("err", err))
return err
if err = createAlterTableArgs.AlterTable(targetKeysMissing...); err != nil {
return fmt.Errorf("failed to alter table: %w", err)
}

tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...)
if err = tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...); err != nil {
return fmt.Errorf("failed to merge columns from destination: %w", err)
}

additionalSettings := types.AdditionalSettings{
AdditionalCopyClause: opts.AdditionalCopyClause,
Expand Down
5 changes: 4 additions & 1 deletion clients/shared/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg
}

tableConfig.AuditColumnsToDelete(srcKeysMissing)
tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...)
if err = tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...); err != nil {
return fmt.Errorf("failed to merge columns from destination: %w", err)
}

temporaryTableID := TempTableID(dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name()), tableData.TempTableSuffix())
temporaryTableName := temporaryTableID.FullyQualifiedName()
if err = dwh.PrepareTemporaryTable(tableData, tableConfig, temporaryTableID, types.AdditionalSettings{}, true); err != nil {
Expand Down
10 changes: 4 additions & 6 deletions clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
Expand Down Expand Up @@ -146,10 +145,9 @@ func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentif
orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", pk))
}

temporaryTableName := sql.EscapeName(stagingTableID.Table(), s.ShouldUppercaseEscapedNames(), s.Label())
var parts []string
parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY by %s ORDER BY %s) = 2)",
temporaryTableName,
parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) = 2)",
stagingTableID.FullyQualifiedName(),
tableID.FullyQualifiedName(),
strings.Join(primaryKeysEscaped, ", "),
strings.Join(orderByCols, ", "),
Expand All @@ -162,11 +160,11 @@ func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentif

parts = append(parts, fmt.Sprintf("DELETE FROM %s t1 USING %s t2 WHERE %s",
tableID.FullyQualifiedName(),
temporaryTableName,
stagingTableID.FullyQualifiedName(),
strings.Join(whereClauses, " AND "),
))

parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), temporaryTableName))
parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), stagingTableID.FullyQualifiedName()))
return parts
}

Expand Down
28 changes: 12 additions & 16 deletions clients/snowflake/snowflake_dedupe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,64 +15,60 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
// Dedupe with one primary key + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
stagingTableName := strings.ToUpper(stagingTableID.Table())

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC) = 2)`, stagingTableName),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2)`, stagingTableID.FullyQualifiedName()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableName), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableName), parts[2])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
}
{
// Dedupe with one primary key + `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
stagingTableName := strings.ToUpper(stagingTableID.Table())

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableName),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableName), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableName), parts[2])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
}
{
// Dedupe with composite keys + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
stagingTableName := strings.ToUpper(stagingTableID.Table())

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableName),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.FullyQualifiedName()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableName), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableName), parts[2])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
}
{
// Dedupe with composite keys + `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
stagingTableName := strings.ToUpper(stagingTableID.Table())

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableName),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableName), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableName), parts[2])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
}
}
2 changes: 1 addition & 1 deletion lib/cdc/postgres/debezium_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (p *PostgresTestSuite) TestPostgresEvent() {
})
assert.NoError(p.T(), err)
assert.Equal(p.T(), float64(59), evtData["id"])
assert.Equal(p.T(), "2022-11-16T04:01:53+00:00", evtData[constants.DatabaseUpdatedColumnMarker])
assert.Equal(p.T(), "2022-11-16T04:01:53.308+00:00", evtData[constants.DatabaseUpdatedColumnMarker])

assert.Equal(p.T(), "Barings Participation Investors", evtData["item"])
assert.Equal(p.T(), map[string]any{"object": "foo"}, evtData["nested"])
Expand Down
2 changes: 1 addition & 1 deletion lib/destination/ddl/ddl_alter_delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func (d *DDLTestSuite) TestAlterDelete_Complete() {
execQuery, _ := d.fakeBigQueryStore.ExecArgsForCall(0)
var found bool
for key := range allColsMap {
if execQuery == fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", bqName, key) {
if execQuery == fmt.Sprintf("ALTER TABLE %s drop COLUMN `%s`", bqName, key) {
found = true
}
}
Expand Down
31 changes: 17 additions & 14 deletions lib/destination/ddl/ddl_create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,30 @@ func (d *DDLTestSuite) Test_CreateTable() {
d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(snowflakeTableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))

type dwhToTableConfig struct {
_tableID types.TableIdentifier
_dwh destination.DataWarehouse
_tableConfig *types.DwhTableConfig
_fakeStore *mocks.FakeStore
_tableID types.TableIdentifier
_dwh destination.DataWarehouse
_tableConfig *types.DwhTableConfig
_fakeStore *mocks.FakeStore
_expectedQuery string
}

bigQueryTc := d.bigQueryStore.GetConfigMap().TableConfig(bqTableID)
snowflakeStagesTc := d.snowflakeStagesStore.GetConfigMap().TableConfig(snowflakeTableID)

for _, dwhTc := range []dwhToTableConfig{
{
_tableID: bqTableID,
_dwh: d.bigQueryStore,
_tableConfig: bigQueryTc,
_fakeStore: d.fakeBigQueryStore,
_tableID: bqTableID,
_dwh: d.bigQueryStore,
_tableConfig: bigQueryTc,
_fakeStore: d.fakeBigQueryStore,
_expectedQuery: fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (`name` string)", bqTableID.FullyQualifiedName()),
},
{
_tableID: snowflakeTableID,
_dwh: d.snowflakeStagesStore,
_tableConfig: snowflakeStagesTc,
_fakeStore: d.fakeSnowflakeStagesStore,
_tableID: snowflakeTableID,
_dwh: d.snowflakeStagesStore,
_tableConfig: snowflakeStagesTc,
_fakeStore: d.fakeSnowflakeStagesStore,
_expectedQuery: fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (name string)", snowflakeTableID.FullyQualifiedName()),
},
} {
alterTableArgs := ddl.AlterTableArgs{
Expand All @@ -66,8 +69,8 @@ func (d *DDLTestSuite) Test_CreateTable() {
assert.Equal(d.T(), 1, dwhTc._fakeStore.ExecCallCount())

query, _ := dwhTc._fakeStore.ExecArgsForCall(0)
assert.Equal(d.T(), query, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (name string)", dwhTc._tableID.FullyQualifiedName()), query)
assert.Equal(d.T(), false, dwhTc._tableConfig.CreateTable())
assert.Equal(d.T(), dwhTc._expectedQuery, query)
assert.False(d.T(), dwhTc._tableConfig.CreateTable())
}
}

Expand Down
8 changes: 4 additions & 4 deletions lib/destination/ddl/ddl_temp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable_Errors() {
TemporaryTable: true,
ColumnOp: constants.Add,
CdcTime: time.Time{},
UppercaseEscNames: ptr.ToBool(false),
UppercaseEscNames: ptr.ToBool(true),
Mode: config.Replication,
}

Expand Down Expand Up @@ -81,7 +81,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() {
TemporaryTable: true,
ColumnOp: constants.Add,
CdcTime: time.Time{},
UppercaseEscNames: ptr.ToBool(false),
UppercaseEscNames: ptr.ToBool(true),
Mode: config.Replication,
}

Expand All @@ -91,7 +91,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() {

assert.Contains(d.T(),
query,
`CREATE TABLE IF NOT EXISTS db.schema."TEMPTABLENAME" (foo string,bar float,"start" string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`,
`CREATE TABLE IF NOT EXISTS db.schema."TEMPTABLENAME" (foo string,bar float,"START" string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`,
query)
}
{
Expand All @@ -115,6 +115,6 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() {
assert.Equal(d.T(), 1, d.fakeBigQueryStore.ExecCallCount())
bqQuery, _ := d.fakeBigQueryStore.ExecArgsForCall(0)
// Cutting off the expiration_timestamp since it's time based.
assert.Contains(d.T(), bqQuery, "CREATE TABLE IF NOT EXISTS `db`.`schema`.`tempTableName` (foo string,bar float64,`select` string) OPTIONS (expiration_timestamp =")
assert.Contains(d.T(), bqQuery, "CREATE TABLE IF NOT EXISTS `db`.`schema`.`tempTableName` (`foo` string,`bar` float64,`select` string) OPTIONS (expiration_timestamp =")
}
}
4 changes: 2 additions & 2 deletions lib/destination/dml/merge_bigquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestMergeStatement_TempTable(t *testing.T) {
mergeSQL, err := mergeArg.GetStatement()
assert.NoError(t, err)

assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON c.order_id = cc.order_id", mergeSQL)
assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON c.`order_id` = cc.`order_id`", mergeSQL)
}

func TestMergeStatement_JSONKey(t *testing.T) {
Expand All @@ -50,5 +50,5 @@ func TestMergeStatement_JSONKey(t *testing.T) {

mergeSQL, err := mergeArg.GetStatement()
assert.NoError(t, err)
assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON TO_JSON_STRING(c.order_oid) = TO_JSON_STRING(cc.order_oid)", mergeSQL)
assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON TO_JSON_STRING(c.`order_oid`) = TO_JSON_STRING(cc.`order_oid`)", mergeSQL)
}
Loading

0 comments on commit 282335e

Please sign in to comment.