diff --git a/clients/bigquery/merge_test.go b/clients/bigquery/merge_test.go
index f0b3b1fb6..a8a9b88c9 100644
--- a/clients/bigquery/merge_test.go
+++ b/clients/bigquery/merge_test.go
@@ -41,20 +41,20 @@ func (b *BigQueryTestSuite) TestBackfillColumn() {
 		{
 			name:        "col that has default value that needs to be backfilled (boolean)",
 			col:         needsBackfillCol,
-			backfillSQL: "UPDATE `db`.`public`.`tableName` SET foo = true WHERE foo IS NULL;",
-			commentSQL:  "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN foo SET OPTIONS (description=`{\"backfilled\": true}`);",
+			backfillSQL: "UPDATE `db`.`public`.`tableName` SET `foo` = true WHERE `foo` IS NULL;",
+			commentSQL:  "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN `foo` SET OPTIONS (description=`{\"backfilled\": true}`);",
 		},
 		{
 			name:        "col that has default value that needs to be backfilled (string)",
 			col:         needsBackfillColStr,
-			backfillSQL: "UPDATE `db`.`public`.`tableName` SET foo2 = 'hello there' WHERE foo2 IS NULL;",
-			commentSQL:  "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN foo2 SET OPTIONS (description=`{\"backfilled\": true}`);",
+			backfillSQL: "UPDATE `db`.`public`.`tableName` SET `foo2` = 'hello there' WHERE `foo2` IS NULL;",
+			commentSQL:  "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN `foo2` SET OPTIONS (description=`{\"backfilled\": true}`);",
 		},
 		{
 			name:        "col that has default value that needs to be backfilled (number)",
 			col:         needsBackfillColNum,
-			backfillSQL: "UPDATE `db`.`public`.`tableName` SET foo3 = 3.5 WHERE foo3 IS NULL;",
-			commentSQL:  "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN foo3 SET OPTIONS (description=`{\"backfilled\": true}`);",
+			backfillSQL: "UPDATE `db`.`public`.`tableName` SET `foo3` = 3.5 WHERE `foo3` IS NULL;",
+			commentSQL:  "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN `foo3` SET OPTIONS (description=`{\"backfilled\": true}`);",
 		},
 	}
diff --git a/clients/shared/append.go b/clients/shared/append.go
index 21c755194..06f1e2172 100644
--- a/clients/shared/append.go
+++ b/clients/shared/append.go
@@ -1,7 +1,7 @@
 package shared
 
 import (
-	"log/slog"
+	"fmt"
 
 	"github.com/artie-labs/transfer/lib/config/constants"
 	"github.com/artie-labs/transfer/lib/destination"
@@ -20,7 +20,7 @@ func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, op
 	tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name())
 	tableConfig, err := dwh.GetTableConfig(tableData)
 	if err != nil {
-		return err
+		return fmt.Errorf("failed to get table config: %w", err)
 	}
 
 	// We don't care about srcKeysMissing because we don't drop columns when we append.
@@ -40,13 +40,13 @@
 	}
 
 	// Keys that exist in CDC stream, but not in DWH
-	err = createAlterTableArgs.AlterTable(targetKeysMissing...)
-	if err != nil {
-		slog.Warn("Failed to apply alter table", slog.Any("err", err))
-		return err
+	if err = createAlterTableArgs.AlterTable(targetKeysMissing...); err != nil {
+		return fmt.Errorf("failed to alter table: %w", err)
 	}
 
-	tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...)
+	if err = tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...); err != nil {
+		return fmt.Errorf("failed to merge columns from destination: %w", err)
+	}
 
 	additionalSettings := types.AdditionalSettings{
 		AdditionalCopyClause: opts.AdditionalCopyClause,
diff --git a/clients/shared/merge.go b/clients/shared/merge.go
index b9247faa6..47ce3c6d3 100644
--- a/clients/shared/merge.go
+++ b/clients/shared/merge.go
@@ -69,7 +69,10 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg
 	}
 
 	tableConfig.AuditColumnsToDelete(srcKeysMissing)
-	tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...)
+	if err = tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...); err != nil {
+		return fmt.Errorf("failed to merge columns from destination: %w", err)
+	}
+
 	temporaryTableID := TempTableID(dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name()), tableData.TempTableSuffix())
 	temporaryTableName := temporaryTableID.FullyQualifiedName()
 	if err = dwh.PrepareTemporaryTable(tableData, tableConfig, temporaryTableID, types.AdditionalSettings{}, true); err != nil {
diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go
index ff85f71b4..e300be95a 100644
--- a/clients/snowflake/snowflake.go
+++ b/clients/snowflake/snowflake.go
@@ -15,7 +15,6 @@ import (
 	"github.com/artie-labs/transfer/lib/kafkalib"
 	"github.com/artie-labs/transfer/lib/optimization"
 	"github.com/artie-labs/transfer/lib/ptr"
-	"github.com/artie-labs/transfer/lib/sql"
 	"github.com/artie-labs/transfer/lib/stringutil"
 	"github.com/artie-labs/transfer/lib/typing"
 	"github.com/artie-labs/transfer/lib/typing/columns"
@@ -146,10 +145,9 @@ func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentif
 		orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", pk))
 	}
 
-	temporaryTableName := sql.EscapeName(stagingTableID.Table(), s.ShouldUppercaseEscapedNames(), s.Label())
 	var parts []string
-	parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY by %s ORDER BY %s) = 2)",
-		temporaryTableName,
+	parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) = 2)",
+		stagingTableID.FullyQualifiedName(),
 		tableID.FullyQualifiedName(),
 		strings.Join(primaryKeysEscaped, ", "),
 		strings.Join(orderByCols, ", "),
@@ -162,11 +160,11 @@ func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentif
 
 	parts = append(parts, fmt.Sprintf("DELETE FROM %s t1 USING %s t2 WHERE %s",
 		tableID.FullyQualifiedName(),
-		temporaryTableName,
+		stagingTableID.FullyQualifiedName(),
 		strings.Join(whereClauses, " AND "),
 	))
 
-	parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), temporaryTableName))
+	parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), stagingTableID.FullyQualifiedName()))
 
 	return parts
 }
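The append.go and merge.go changes above swap a log-and-return for wrapped errors. Wrapping with %w keeps the original error in the chain, so callers can still match it with errors.Is or errors.As while each layer adds context. A minimal, self-contained sketch of that pattern; the function names here are stand-ins for illustration, not this package's API:

package main

import (
	"errors"
	"fmt"
)

var errTableMissing = errors.New("table missing")

// getTableConfig is a hypothetical stand-in for a call like dwh.GetTableConfig.
func getTableConfig() error {
	return errTableMissing
}

// appendRows mirrors the shape of the new error handling: wrap with %w and
// add call-site context instead of logging and returning the bare error.
func appendRows() error {
	if err := getTableConfig(); err != nil {
		return fmt.Errorf("failed to get table config: %w", err)
	}
	return nil
}

func main() {
	err := appendRows()
	fmt.Println(err)                             // failed to get table config: table missing
	fmt.Println(errors.Is(err, errTableMissing)) // true: %w preserves the chain
}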
diff --git a/clients/snowflake/snowflake_dedupe_test.go b/clients/snowflake/snowflake_dedupe_test.go
index 97130f928..b9b5aed83 100644
--- a/clients/snowflake/snowflake_dedupe_test.go
+++ b/clients/snowflake/snowflake_dedupe_test.go
@@ -15,64 +15,60 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
 		// Dedupe with one primary key + no `__artie_updated_at` flag.
 		tableID := NewTableIdentifier("db", "public", "customers")
 		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
-		stagingTableName := strings.ToUpper(stagingTableID.Table())
 
 		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{})
 		assert.Len(s.T(), parts, 3)
 		assert.Equal(
 			s.T(),
-			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC) = 2)`, stagingTableName),
+			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2)`, stagingTableID.FullyQualifiedName()),
 			parts[0],
 		)
-		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableName), parts[1])
-		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableName), parts[2])
+		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1])
+		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
 	}
 	{
 		// Dedupe with one primary key + `__artie_updated_at` flag.
 		tableID := NewTableIdentifier("db", "public", "customers")
 		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
-		stagingTableName := strings.ToUpper(stagingTableID.Table())
 
 		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
 		assert.Len(s.T(), parts, 3)
 		assert.Equal(
 			s.T(),
-			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableName),
+			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()),
 			parts[0],
 		)
-		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableName), parts[1])
-		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableName), parts[2])
+		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1])
+		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
 	}
 	{
 		// Dedupe with composite keys + no `__artie_updated_at` flag.
 		tableID := NewTableIdentifier("db", "public", "user_settings")
 		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
-		stagingTableName := strings.ToUpper(stagingTableID.Table())
 
 		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{})
 		assert.Len(s.T(), parts, 3)
 		assert.Equal(
 			s.T(),
-			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableName),
+			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.FullyQualifiedName()),
 			parts[0],
 		)
-		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableName), parts[1])
-		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableName), parts[2])
+		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1])
+		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
 	}
 	{
 		// Dedupe with composite keys + `__artie_updated_at` flag.
 		tableID := NewTableIdentifier("db", "public", "user_settings")
 		stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
-		stagingTableName := strings.ToUpper(stagingTableID.Table())
 
 		parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
 		assert.Len(s.T(), parts, 3)
 		assert.Equal(
 			s.T(),
-			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableName),
+			fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()),
 			parts[0],
 		)
-		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableName), parts[1])
-		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableName), parts[2])
+		assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1])
+		assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
 	}
 }
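The dedupe strategy these tests lock in is a three-statement dance: snapshot the second copy of each duplicated key into a staging table via QUALIFY ROW_NUMBER(), delete every copy of those keys from the target, then re-insert exactly one surviving row per key. A rough illustration with a hypothetical helper (buildDedupeQueries and the staging name are invented for the example, not the package's API):

package main

import (
	"fmt"
	"strings"
)

// buildDedupeQueries sketches the three-step dedupe above. ROW_NUMBER() = 2
// captures the second copy of each duplicated key, assuming at most two copies.
func buildDedupeQueries(target, staging string, pks []string) []string {
	var orderBy, where []string
	for _, pk := range pks {
		orderBy = append(orderBy, pk+" ASC")
		where = append(where, fmt.Sprintf("t1.%s = t2.%s", pk, pk))
	}
	return []string{
		fmt.Sprintf("CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM %s QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) = 2)",
			staging, target, strings.Join(pks, ", "), strings.Join(orderBy, ", ")),
		fmt.Sprintf("DELETE FROM %s t1 USING %s t2 WHERE %s", target, staging, strings.Join(where, " AND ")),
		fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", target, staging),
	}
}

func main() {
	for _, q := range buildDedupeQueries(`db.public."CUSTOMERS"`, "db.public.customers_staging", []string{"id"}) {
		fmt.Println(q)
	}
}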
diff --git a/lib/cdc/postgres/debezium_test.go b/lib/cdc/postgres/debezium_test.go
index 1169ed456..d81a55cd1 100644
--- a/lib/cdc/postgres/debezium_test.go
+++ b/lib/cdc/postgres/debezium_test.go
@@ -90,7 +90,7 @@ func (p *PostgresTestSuite) TestPostgresEvent() {
 	})
 	assert.NoError(p.T(), err)
 	assert.Equal(p.T(), float64(59), evtData["id"])
-	assert.Equal(p.T(), "2022-11-16T04:01:53+00:00", evtData[constants.DatabaseUpdatedColumnMarker])
+	assert.Equal(p.T(), "2022-11-16T04:01:53.308+00:00", evtData[constants.DatabaseUpdatedColumnMarker])
 	assert.Equal(p.T(), "Barings Participation Investors", evtData["item"])
 	assert.Equal(p.T(), map[string]any{"object": "foo"}, evtData["nested"])
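The expected timestamp now carries milliseconds because the ISO8601 layout later in this patch gains a ".999" fractional-second component. In Go's reference-time layouts, ".999" keeps up to three fractional digits but trims trailing zeros, so values without sub-second data stay compact. A small runnable demonstration:

package main

import (
	"fmt"
	"time"
)

func main() {
	// Matches the new ISO8601 constant in lib/typing/ext/variables.go.
	const iso8601 = "2006-01-02T15:04:05.999-07:00"

	withMillis := time.Date(2022, 11, 16, 4, 1, 53, 308_000_000, time.UTC)
	fmt.Println(withMillis.Format(iso8601)) // 2022-11-16T04:01:53.308+00:00

	// ".999" (unlike ".000") drops the fraction entirely when it is zero.
	noMillis := time.Date(2022, 11, 16, 4, 1, 53, 0, time.UTC)
	fmt.Println(noMillis.Format(iso8601)) // 2022-11-16T04:01:53+00:00
}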
diff --git a/lib/destination/ddl/ddl_alter_delete_test.go b/lib/destination/ddl/ddl_alter_delete_test.go
index e3b47ff12..a50d3277a 100644
--- a/lib/destination/ddl/ddl_alter_delete_test.go
+++ b/lib/destination/ddl/ddl_alter_delete_test.go
@@ -351,7 +351,7 @@ func (d *DDLTestSuite) TestAlterDelete_Complete() {
 	execQuery, _ := d.fakeBigQueryStore.ExecArgsForCall(0)
 	var found bool
 	for key := range allColsMap {
-		if execQuery == fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", bqName, key) {
+		if execQuery == fmt.Sprintf("ALTER TABLE %s drop COLUMN `%s`", bqName, key) {
 			found = true
 		}
 	}
diff --git a/lib/destination/ddl/ddl_create_table_test.go b/lib/destination/ddl/ddl_create_table_test.go
index a8e07dfea..44209ecfc 100644
--- a/lib/destination/ddl/ddl_create_table_test.go
+++ b/lib/destination/ddl/ddl_create_table_test.go
@@ -29,10 +29,11 @@ func (d *DDLTestSuite) Test_CreateTable() {
 	d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(snowflakeTableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))
 
 	type dwhToTableConfig struct {
-		_tableID     types.TableIdentifier
-		_dwh         destination.DataWarehouse
-		_tableConfig *types.DwhTableConfig
-		_fakeStore   *mocks.FakeStore
+		_tableID       types.TableIdentifier
+		_dwh           destination.DataWarehouse
+		_tableConfig   *types.DwhTableConfig
+		_fakeStore     *mocks.FakeStore
+		_expectedQuery string
 	}
 
 	bigQueryTc := d.bigQueryStore.GetConfigMap().TableConfig(bqTableID)
@@ -40,16 +41,18 @@
 
 	for _, dwhTc := range []dwhToTableConfig{
 		{
-			_tableID:     bqTableID,
-			_dwh:         d.bigQueryStore,
-			_tableConfig: bigQueryTc,
-			_fakeStore:   d.fakeBigQueryStore,
+			_tableID:       bqTableID,
+			_dwh:           d.bigQueryStore,
+			_tableConfig:   bigQueryTc,
+			_fakeStore:     d.fakeBigQueryStore,
+			_expectedQuery: fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (`name` string)", bqTableID.FullyQualifiedName()),
 		},
 		{
-			_tableID:     snowflakeTableID,
-			_dwh:         d.snowflakeStagesStore,
-			_tableConfig: snowflakeStagesTc,
-			_fakeStore:   d.fakeSnowflakeStagesStore,
+			_tableID:       snowflakeTableID,
+			_dwh:           d.snowflakeStagesStore,
+			_tableConfig:   snowflakeStagesTc,
+			_fakeStore:     d.fakeSnowflakeStagesStore,
+			_expectedQuery: fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (name string)", snowflakeTableID.FullyQualifiedName()),
 		},
 	} {
 		alterTableArgs := ddl.AlterTableArgs{
@@ -66,8 +69,8 @@ func (d *DDLTestSuite) Test_CreateTable() {
 		assert.Equal(d.T(), 1, dwhTc._fakeStore.ExecCallCount())
 
 		query, _ := dwhTc._fakeStore.ExecArgsForCall(0)
-		assert.Equal(d.T(), query, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (name string)", dwhTc._tableID.FullyQualifiedName()), query)
-		assert.Equal(d.T(), false, dwhTc._tableConfig.CreateTable())
+		assert.Equal(d.T(), dwhTc._expectedQuery, query)
+		assert.False(d.T(), dwhTc._tableConfig.CreateTable())
 	}
 }
diff --git a/lib/destination/ddl/ddl_temp_test.go b/lib/destination/ddl/ddl_temp_test.go
index e04e30f08..35bde3e2e 100644
--- a/lib/destination/ddl/ddl_temp_test.go
+++ b/lib/destination/ddl/ddl_temp_test.go
@@ -46,7 +46,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable_Errors() {
 		TemporaryTable:    true,
 		ColumnOp:          constants.Add,
 		CdcTime:           time.Time{},
-		UppercaseEscNames: ptr.ToBool(false),
+		UppercaseEscNames: ptr.ToBool(true),
 		Mode:              config.Replication,
 	}
 
@@ -81,7 +81,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() {
 		TemporaryTable:    true,
 		ColumnOp:          constants.Add,
 		CdcTime:           time.Time{},
-		UppercaseEscNames: ptr.ToBool(false),
+		UppercaseEscNames: ptr.ToBool(true),
 		Mode:              config.Replication,
 	}
 
@@ -91,7 +91,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() {
 		assert.Contains(d.T(),
 			query,
-			`CREATE TABLE IF NOT EXISTS db.schema."TEMPTABLENAME" (foo string,bar float,"start" string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`,
+			`CREATE TABLE IF NOT EXISTS db.schema."TEMPTABLENAME" (foo string,bar float,"START" string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`,
 			query)
 	}
 	{
@@ -115,6 +115,6 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() {
 		assert.Equal(d.T(), 1, d.fakeBigQueryStore.ExecCallCount())
 		bqQuery, _ := d.fakeBigQueryStore.ExecArgsForCall(0)
 		// Cutting off the expiration_timestamp since it's time based.
-		assert.Contains(d.T(), bqQuery, "CREATE TABLE IF NOT EXISTS `db`.`schema`.`tempTableName` (foo string,bar float64,`select` string) OPTIONS (expiration_timestamp =")
+		assert.Contains(d.T(), bqQuery, "CREATE TABLE IF NOT EXISTS `db`.`schema`.`tempTableName` (`foo` string,`bar` float64,`select` string) OPTIONS (expiration_timestamp =")
 	}
 }
diff --git a/lib/destination/dml/merge_bigquery_test.go b/lib/destination/dml/merge_bigquery_test.go
index 6fb44d488..24cd40398 100644
--- a/lib/destination/dml/merge_bigquery_test.go
+++ b/lib/destination/dml/merge_bigquery_test.go
@@ -29,7 +29,7 @@ func TestMergeStatement_TempTable(t *testing.T) {
 
 	mergeSQL, err := mergeArg.GetStatement()
 	assert.NoError(t, err)
-	assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON c.order_id = cc.order_id", mergeSQL)
+	assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON c.`order_id` = cc.`order_id`", mergeSQL)
 }
 
 func TestMergeStatement_JSONKey(t *testing.T) {
@@ -50,5 +50,5 @@ func TestMergeStatement_JSONKey(t *testing.T) {
 
 	mergeSQL, err := mergeArg.GetStatement()
 	assert.NoError(t, err)
-	assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON TO_JSON_STRING(c.order_oid) = TO_JSON_STRING(cc.order_oid)", mergeSQL)
+	assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON TO_JSON_STRING(c.`order_oid`) = TO_JSON_STRING(cc.`order_oid`)", mergeSQL)
 }
diff --git a/lib/destination/dml/merge_parts_test.go b/lib/destination/dml/merge_parts_test.go
index 23f07bb94..f412b6400 100644
--- a/lib/destination/dml/merge_parts_test.go
+++ b/lib/destination/dml/merge_parts_test.go
@@ -31,7 +31,7 @@ type result struct {
 // getBasicColumnsForTest - will return you all the columns within `result` that are needed for tests.
 // * In here, we'll return if compositeKey=false - id (pk), email, first_name, last_name, created_at, toast_text (TOAST-able)
 // * Else if compositeKey=true - id(pk), email (pk), first_name, last_name, created_at, toast_text (TOAST-able)
-func getBasicColumnsForTest(compositeKey bool, uppercaseEscNames bool) result {
+func getBasicColumnsForTest(compositeKey bool) result {
 	idCol := columns.NewColumn("id", typing.Float)
 	emailCol := columns.NewColumn("email", typing.String)
 	textToastCol := columns.NewColumn("toast_text", typing.String)
@@ -47,10 +47,10 @@ func getBasicColumnsForTest(compositeKey bool, uppercaseEscNames bool) result {
 	cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))
 
 	var pks []columns.Wrapper
-	pks = append(pks, columns.NewWrapper(idCol, uppercaseEscNames, constants.Redshift))
+	pks = append(pks, columns.NewWrapper(idCol, false, constants.Redshift))
 
 	if compositeKey {
-		pks = append(pks, columns.NewWrapper(emailCol, uppercaseEscNames, constants.Redshift))
+		pks = append(pks, columns.NewWrapper(emailCol, false, constants.Redshift))
 	}
 
 	return result{
@@ -65,7 +65,7 @@ func TestMergeStatementParts_SkipDelete(t *testing.T) {
 	// 2. There are 3 SQL queries (INSERT, UPDATE and DELETE)
 	fqTableName := "public.tableName"
 	tempTableName := "public.tableName__temp"
-	res := getBasicColumnsForTest(false, false)
+	res := getBasicColumnsForTest(false)
 	mergeArg := &MergeArgument{
 		TableID:  MockTableIdentifier{fqTableName},
 		SubQuery: tempTableName,
@@ -92,7 +92,7 @@
 func TestMergeStatementPartsSoftDelete(t *testing.T) {
 	fqTableName := "public.tableName"
 	tempTableName := "public.tableName__temp"
-	res := getBasicColumnsForTest(false, false)
+	res := getBasicColumnsForTest(false)
 	mergeArg := &MergeArgument{
 		TableID:  MockTableIdentifier{fqTableName},
 		SubQuery: tempTableName,
@@ -132,7 +132,7 @@
 func TestMergeStatementPartsSoftDeleteComposite(t *testing.T) {
 	fqTableName := "public.tableName"
 	tempTableName := "public.tableName__temp"
-	res := getBasicColumnsForTest(true, false)
+	res := getBasicColumnsForTest(true)
 	mergeArg := &MergeArgument{
 		TableID:  MockTableIdentifier{fqTableName},
 		SubQuery: tempTableName,
@@ -175,7 +175,7 @@ func TestMergeStatementParts(t *testing.T) {
 	// 2. There are 3 SQL queries (INSERT, UPDATE and DELETE)
 	fqTableName := "public.tableName"
 	tempTableName := "public.tableName__temp"
-	res := getBasicColumnsForTest(false, false)
+	res := getBasicColumnsForTest(false)
 	mergeArg := &MergeArgument{
 		TableID:  MockTableIdentifier{fqTableName},
 		SubQuery: tempTableName,
@@ -233,7 +233,7 @@
 func TestMergeStatementPartsCompositeKey(t *testing.T) {
 	fqTableName := "public.tableName"
 	tempTableName := "public.tableName__temp"
-	res := getBasicColumnsForTest(true, false)
+	res := getBasicColumnsForTest(true)
 	mergeArg := &MergeArgument{
 		TableID:  MockTableIdentifier{fqTableName},
 		SubQuery: tempTableName,
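The Snowflake merge tests that follow now pass uppercaseEscNames as true, so reserved identifiers like group get quoted and folded to upper case ("GROUP") rather than quoted as-is. A minimal sketch of that quoting rule and how it flows into a MERGE ON clause; escapeName, buildJoinClause, and the keyword map are hypothetical stand-ins for illustration:

package main

import (
	"fmt"
	"strings"
)

// reserved is a tiny placeholder for the real keyword lists in the codebase.
var reserved = map[string]bool{"group": true, "start": true, "select": true}

// escapeName quotes reserved identifiers; with uppercase=true it also folds
// them to upper case, matching how Snowflake stores unquoted identifiers.
func escapeName(name string, uppercase bool) string {
	if !reserved[name] {
		return name
	}
	if uppercase {
		name = strings.ToUpper(name)
	}
	return fmt.Sprintf("%q", name)
}

// buildJoinClause renders the ON predicate between target alias c and
// staging alias cc, using escaped primary-key names on both sides.
func buildJoinClause(pks []string, uppercase bool) string {
	var conds []string
	for _, pk := range pks {
		esc := escapeName(pk, uppercase)
		conds = append(conds, fmt.Sprintf("c.%s = cc.%s", esc, esc))
	}
	return strings.Join(conds, " and ")
}

func main() {
	fmt.Println(buildJoinClause([]string{"id", "group"}, true))
	// c.id = cc.id and c."GROUP" = cc."GROUP"
}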
id=cc.id,"group"=cc."group",updated_at=cc.updated_at,"start"=cc."start"`, mergeSQL) + assert.Contains(t, mergeSQL, `SET id=cc.id,"GROUP"=cc."GROUP",updated_at=cc.updated_at,"START"=cc."START"`, mergeSQL) // Check for INSERT - assert.Contains(t, mergeSQL, `id,"group",updated_at,"start"`, mergeSQL) - assert.Contains(t, mergeSQL, `cc.id,cc."group",cc.updated_at,cc."start"`, mergeSQL) + assert.Contains(t, mergeSQL, `id,"GROUP",updated_at,"START"`, mergeSQL) + assert.Contains(t, mergeSQL, `cc.id,cc."GROUP",cc.updated_at,cc."START"`, mergeSQL) } diff --git a/lib/optimization/event_update_test.go b/lib/optimization/event_update_test.go index af4a2e1eb..2d5fa5e78 100644 --- a/lib/optimization/event_update_test.go +++ b/lib/optimization/event_update_test.go @@ -11,139 +11,160 @@ import ( "github.com/stretchr/testify/assert" ) +const strCol = "string" + func TestTableData_UpdateInMemoryColumnsFromDestination(t *testing.T) { - const strCol = "string" + { + tableDataCols := &columns.Columns{} + tableData := &TableData{ + inMemoryColumns: tableDataCols, + } - tableDataCols := &columns.Columns{} - tableDataCols.AddColumn(columns.NewColumn("name", typing.String)) - tableDataCols.AddColumn(columns.NewColumn("bool_backfill", typing.Boolean)) - tableDataCols.AddColumn(columns.NewColumn("prev_invalid", typing.Invalid)) - tableDataCols.AddColumn(columns.NewColumn("numeric_test", typing.EDecimal)) + tableData.AddInMemoryCol(columns.NewColumn("foo", typing.String)) + invalidCol := columns.NewColumn("foo", typing.Invalid) + assert.ErrorContains(t, tableData.MergeColumnsFromDestination(invalidCol), `column "foo" is invalid`) + } + { + tableDataCols := &columns.Columns{} + tableData := &TableData{ + inMemoryColumns: tableDataCols, + } - // Casting these as STRING so tableColumn via this f(x) will set it correctly. - tableDataCols.AddColumn(columns.NewColumn("ext_date", typing.String)) - tableDataCols.AddColumn(columns.NewColumn("ext_time", typing.String)) - tableDataCols.AddColumn(columns.NewColumn("ext_datetime", typing.String)) - tableDataCols.AddColumn(columns.NewColumn("ext_dec", typing.String)) + tableDataCols.AddColumn(columns.NewColumn("name", typing.String)) + tableDataCols.AddColumn(columns.NewColumn("bool_backfill", typing.Boolean)) + tableDataCols.AddColumn(columns.NewColumn("prev_invalid", typing.Invalid)) + tableDataCols.AddColumn(columns.NewColumn("numeric_test", typing.EDecimal)) - extDecimalType := typing.EDecimal - extDecimalType.ExtendedDecimalDetails = decimal.NewDecimal(ptr.ToInt(22), 2, nil) - tableDataCols.AddColumn(columns.NewColumn("ext_dec_filled", extDecimalType)) + // Casting these as STRING so tableColumn via this f(x) will set it correctly. 
+ tableDataCols.AddColumn(columns.NewColumn("ext_date", typing.String)) + tableDataCols.AddColumn(columns.NewColumn("ext_time", typing.String)) + tableDataCols.AddColumn(columns.NewColumn("ext_datetime", typing.String)) + tableDataCols.AddColumn(columns.NewColumn("ext_dec", typing.String)) - tableDataCols.AddColumn(columns.NewColumn(strCol, typing.String)) + extDecimalType := typing.EDecimal + extDecimalType.ExtendedDecimalDetails = decimal.NewDecimal(ptr.ToInt(22), 2, nil) + tableDataCols.AddColumn(columns.NewColumn("ext_dec_filled", extDecimalType)) - tableData := &TableData{ - inMemoryColumns: tableDataCols, - } + tableDataCols.AddColumn(columns.NewColumn(strCol, typing.String)) - nonExistentTableCols := []string{"dusty", "the", "mini", "aussie"} - var nonExistentCols []columns.Column - for _, nonExistentTableCol := range nonExistentTableCols { - nonExistentCols = append(nonExistentCols, columns.NewColumn(nonExistentTableCol, typing.String)) - } + nonExistentTableCols := []string{"dusty", "the", "mini", "aussie"} + var nonExistentCols []columns.Column + for _, nonExistentTableCol := range nonExistentTableCols { + nonExistentCols = append(nonExistentCols, columns.NewColumn(nonExistentTableCol, typing.String)) + } - // Testing to make sure we don't copy over non-existent columns - tableData.MergeColumnsFromDestination(nonExistentCols...) - for _, nonExistentTableCol := range nonExistentTableCols { - _, isOk := tableData.inMemoryColumns.GetColumn(nonExistentTableCol) - assert.False(t, isOk, nonExistentTableCol) - } + // Testing to make sure we don't copy over non-existent columns + assert.NoError(t, tableData.MergeColumnsFromDestination(nonExistentCols...)) + for _, nonExistentTableCol := range nonExistentTableCols { + _, isOk := tableData.inMemoryColumns.GetColumn(nonExistentTableCol) + assert.False(t, isOk, nonExistentTableCol) + } - // Making sure it's still numeric - tableData.MergeColumnsFromDestination(columns.NewColumn("numeric_test", typing.Integer)) - numericCol, isOk := tableData.inMemoryColumns.GetColumn("numeric_test") - assert.True(t, isOk) - assert.Equal(t, typing.EDecimal.Kind, numericCol.KindDetails.Kind, "numeric_test") - - // Testing to make sure we're copying the kindDetails over. - tableData.MergeColumnsFromDestination(columns.NewColumn("prev_invalid", typing.String)) - prevInvalidCol, isOk := tableData.inMemoryColumns.GetColumn("prev_invalid") - assert.True(t, isOk) - assert.Equal(t, typing.String, prevInvalidCol.KindDetails) - - // Testing backfill - for _, inMemoryCol := range tableData.inMemoryColumns.GetColumns() { - assert.False(t, inMemoryCol.Backfilled(), inMemoryCol.RawName()) - } - backfilledCol := columns.NewColumn("bool_backfill", typing.Boolean) - backfilledCol.SetBackfilled(true) - tableData.MergeColumnsFromDestination(backfilledCol) - for _, inMemoryCol := range tableData.inMemoryColumns.GetColumns() { - if inMemoryCol.RawName() == backfilledCol.RawName() { - assert.True(t, inMemoryCol.Backfilled(), inMemoryCol.RawName()) - } else { + // Making sure it's still numeric + assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn("numeric_test", typing.Integer))) + numericCol, isOk := tableData.inMemoryColumns.GetColumn("numeric_test") + assert.True(t, isOk) + assert.Equal(t, typing.EDecimal.Kind, numericCol.KindDetails.Kind, "numeric_test") + + // Testing to make sure we're copying the kindDetails over. 
+ assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn("prev_invalid", typing.String))) + prevInvalidCol, isOk := tableData.inMemoryColumns.GetColumn("prev_invalid") + assert.True(t, isOk) + assert.Equal(t, typing.String, prevInvalidCol.KindDetails) + + // Testing backfill + for _, inMemoryCol := range tableData.inMemoryColumns.GetColumns() { assert.False(t, inMemoryCol.Backfilled(), inMemoryCol.RawName()) } - } + backfilledCol := columns.NewColumn("bool_backfill", typing.Boolean) + backfilledCol.SetBackfilled(true) + assert.NoError(t, tableData.MergeColumnsFromDestination(backfilledCol)) + for _, inMemoryCol := range tableData.inMemoryColumns.GetColumns() { + if inMemoryCol.RawName() == backfilledCol.RawName() { + assert.True(t, inMemoryCol.Backfilled(), inMemoryCol.RawName()) + } else { + assert.False(t, inMemoryCol.Backfilled(), inMemoryCol.RawName()) + } + } + + // Testing extTimeDetails + for _, extTimeDetailsCol := range []string{"ext_date", "ext_time", "ext_datetime"} { + col, isOk := tableData.inMemoryColumns.GetColumn(extTimeDetailsCol) + assert.True(t, isOk, extTimeDetailsCol) + assert.Equal(t, typing.String, col.KindDetails, extTimeDetailsCol) + assert.Nil(t, col.KindDetails.ExtendedTimeDetails, extTimeDetailsCol) + } - // Testing extTimeDetails - for _, extTimeDetailsCol := range []string{"ext_date", "ext_time", "ext_datetime"} { - col, isOk := tableData.inMemoryColumns.GetColumn(extTimeDetailsCol) - assert.True(t, isOk, extTimeDetailsCol) - assert.Equal(t, typing.String, col.KindDetails, extTimeDetailsCol) - assert.Nil(t, col.KindDetails.ExtendedTimeDetails, extTimeDetailsCol) + assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn("ext_time", typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimeKindType)))) + assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn("ext_date", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType)))) + assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn("ext_datetime", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateTimeKindType)))) + + dateCol, isOk := tableData.inMemoryColumns.GetColumn("ext_date") + assert.True(t, isOk) + assert.NotNil(t, dateCol.KindDetails.ExtendedTimeDetails) + assert.Equal(t, ext.DateKindType, dateCol.KindDetails.ExtendedTimeDetails.Type) + + timeCol, isOk := tableData.inMemoryColumns.GetColumn("ext_time") + assert.True(t, isOk) + assert.NotNil(t, timeCol.KindDetails.ExtendedTimeDetails) + assert.Equal(t, ext.TimeKindType, timeCol.KindDetails.ExtendedTimeDetails.Type) + + dateTimeCol, isOk := tableData.inMemoryColumns.GetColumn("ext_datetime") + assert.True(t, isOk) + assert.NotNil(t, dateTimeCol.KindDetails.ExtendedTimeDetails) + assert.Equal(t, ext.DateTimeKindType, dateTimeCol.KindDetails.ExtendedTimeDetails.Type) + + // Testing extDecimalDetails + // Confirm that before you update, it's invalid. + extDecCol, isOk := tableData.inMemoryColumns.GetColumn("ext_dec") + assert.True(t, isOk) + assert.Equal(t, typing.String, extDecCol.KindDetails) + + extDecimal := typing.EDecimal + extDecimal.ExtendedDecimalDetails = decimal.NewDecimal(ptr.ToInt(30), 2, nil) + assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn("ext_dec", extDecimal))) + // Now it should be ext decimal type + extDecCol, isOk = tableData.inMemoryColumns.GetColumn("ext_dec") + assert.True(t, isOk) + assert.Equal(t, typing.EDecimal.Kind, extDecCol.KindDetails.Kind) + // Check precision and scale too. 
+ assert.Equal(t, 30, *extDecCol.KindDetails.ExtendedDecimalDetails.Precision()) + assert.Equal(t, 2, extDecCol.KindDetails.ExtendedDecimalDetails.Scale()) + + // Testing ext_dec_filled since it's already filled out + extDecColFilled, isOk := tableData.inMemoryColumns.GetColumn("ext_dec_filled") + assert.True(t, isOk) + assert.Equal(t, typing.EDecimal.Kind, extDecColFilled.KindDetails.Kind) + // Check precision and scale too. + assert.Equal(t, 22, *extDecColFilled.KindDetails.ExtendedDecimalDetails.Precision()) + assert.Equal(t, 2, extDecColFilled.KindDetails.ExtendedDecimalDetails.Scale()) + + assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn("ext_dec_filled", extDecimal))) + extDecColFilled, isOk = tableData.inMemoryColumns.GetColumn("ext_dec_filled") + assert.True(t, isOk) + assert.Equal(t, typing.EDecimal.Kind, extDecColFilled.KindDetails.Kind) + // Check precision and scale too. + assert.Equal(t, 22, *extDecColFilled.KindDetails.ExtendedDecimalDetails.Precision()) + assert.Equal(t, 2, extDecColFilled.KindDetails.ExtendedDecimalDetails.Scale()) } + { + tableDataCols := &columns.Columns{} + tableData := &TableData{ + inMemoryColumns: tableDataCols, + } + + tableDataCols.AddColumn(columns.NewColumn(strCol, typing.String)) + + // Testing string precision + stringKindWithPrecision := typing.KindDetails{ + Kind: typing.String.Kind, + OptionalStringPrecision: ptr.ToInt(123), + } - tableData.MergeColumnsFromDestination(columns.NewColumn("ext_time", typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimeKindType))) - tableData.MergeColumnsFromDestination(columns.NewColumn("ext_date", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType))) - tableData.MergeColumnsFromDestination(columns.NewColumn("ext_datetime", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateTimeKindType))) - - dateCol, isOk := tableData.inMemoryColumns.GetColumn("ext_date") - assert.True(t, isOk) - assert.NotNil(t, dateCol.KindDetails.ExtendedTimeDetails) - assert.Equal(t, ext.DateKindType, dateCol.KindDetails.ExtendedTimeDetails.Type) - - timeCol, isOk := tableData.inMemoryColumns.GetColumn("ext_time") - assert.True(t, isOk) - assert.NotNil(t, timeCol.KindDetails.ExtendedTimeDetails) - assert.Equal(t, ext.TimeKindType, timeCol.KindDetails.ExtendedTimeDetails.Type) - - dateTimeCol, isOk := tableData.inMemoryColumns.GetColumn("ext_datetime") - assert.True(t, isOk) - assert.NotNil(t, dateTimeCol.KindDetails.ExtendedTimeDetails) - assert.Equal(t, ext.DateTimeKindType, dateTimeCol.KindDetails.ExtendedTimeDetails.Type) - - // Testing extDecimalDetails - // Confirm that before you update, it's invalid. - extDecCol, isOk := tableData.inMemoryColumns.GetColumn("ext_dec") - assert.True(t, isOk) - assert.Equal(t, typing.String, extDecCol.KindDetails) - - extDecimal := typing.EDecimal - extDecimal.ExtendedDecimalDetails = decimal.NewDecimal(ptr.ToInt(30), 2, nil) - tableData.MergeColumnsFromDestination(columns.NewColumn("ext_dec", extDecimal)) - // Now it should be ext decimal type - extDecCol, isOk = tableData.inMemoryColumns.GetColumn("ext_dec") - assert.True(t, isOk) - assert.Equal(t, typing.EDecimal.Kind, extDecCol.KindDetails.Kind) - // Check precision and scale too. 
- assert.Equal(t, 30, *extDecCol.KindDetails.ExtendedDecimalDetails.Precision()) - assert.Equal(t, 2, extDecCol.KindDetails.ExtendedDecimalDetails.Scale()) - - // Testing ext_dec_filled since it's already filled out - extDecColFilled, isOk := tableData.inMemoryColumns.GetColumn("ext_dec_filled") - assert.True(t, isOk) - assert.Equal(t, typing.EDecimal.Kind, extDecColFilled.KindDetails.Kind) - // Check precision and scale too. - assert.Equal(t, 22, *extDecColFilled.KindDetails.ExtendedDecimalDetails.Precision()) - assert.Equal(t, 2, extDecColFilled.KindDetails.ExtendedDecimalDetails.Scale()) - - tableData.MergeColumnsFromDestination(columns.NewColumn("ext_dec_filled", extDecimal)) - extDecColFilled, isOk = tableData.inMemoryColumns.GetColumn("ext_dec_filled") - assert.True(t, isOk) - assert.Equal(t, typing.EDecimal.Kind, extDecColFilled.KindDetails.Kind) - // Check precision and scale too. - assert.Equal(t, 22, *extDecColFilled.KindDetails.ExtendedDecimalDetails.Precision()) - assert.Equal(t, 2, extDecColFilled.KindDetails.ExtendedDecimalDetails.Scale()) - - // Testing string precision - stringKindWithPrecision := typing.KindDetails{ - Kind: typing.String.Kind, - OptionalStringPrecision: ptr.ToInt(123), + assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn(strCol, stringKindWithPrecision))) + foundStrCol, isOk := tableData.inMemoryColumns.GetColumn(strCol) + assert.True(t, isOk) + assert.Equal(t, typing.String.Kind, foundStrCol.KindDetails.Kind) + assert.Equal(t, 123, *foundStrCol.KindDetails.OptionalStringPrecision) } - tableData.MergeColumnsFromDestination(columns.NewColumn(strCol, stringKindWithPrecision)) - foundStrCol, isOk := tableData.inMemoryColumns.GetColumn(strCol) - assert.True(t, isOk) - assert.Equal(t, typing.String.Kind, foundStrCol.KindDetails.Kind) - assert.Equal(t, 123, *foundStrCol.KindDetails.OptionalStringPrecision) } diff --git a/lib/optimization/table_data.go b/lib/optimization/table_data.go index 89b3eb264..e47839528 100644 --- a/lib/optimization/table_data.go +++ b/lib/optimization/table_data.go @@ -248,9 +248,9 @@ func (t *TableData) ShouldFlush(cfg config.Config) (bool, string) { // Prior to merging, we will need to treat `tableConfig` as the source-of-truth and whenever there's discrepancies // We will prioritize using the values coming from (2) TableConfig. We also cannot simply do a replacement, as we have in-memory columns // That carry metadata for Artie Transfer. They are prefixed with __artie. 
diff --git a/lib/optimization/table_data.go b/lib/optimization/table_data.go
index 89b3eb264..e47839528 100644
--- a/lib/optimization/table_data.go
+++ b/lib/optimization/table_data.go
@@ -248,9 +248,9 @@ func (t *TableData) ShouldFlush(cfg config.Config) (bool, string) {
 // Prior to merging, we will need to treat `tableConfig` as the source-of-truth and whenever there's discrepancies
 // We will prioritize using the values coming from (2) TableConfig. We also cannot simply do a replacement, as we have in-memory columns
 // That carry metadata for Artie Transfer. They are prefixed with __artie.
-func (t *TableData) MergeColumnsFromDestination(destCols ...columns.Column) {
+func (t *TableData) MergeColumnsFromDestination(destCols ...columns.Column) error {
 	if t == nil || len(destCols) == 0 {
-		return
+		return nil
 	}
 
 	for _, inMemoryCol := range t.inMemoryColumns.GetColumns() {
@@ -258,6 +258,10 @@ func (t *TableData) MergeColumnsFromDestination(destCols ...columns.Column) {
 		var found bool
 		for _, destCol := range destCols {
 			if destCol.RawName() == strings.ToLower(inMemoryCol.RawName()) {
+				if destCol.KindDetails.Kind == typing.Invalid.Kind {
+					return fmt.Errorf("column %q is invalid", destCol.RawName())
+				}
+
 				foundColumn = destCol
 				found = true
 				break
@@ -297,4 +301,6 @@ func (t *TableData) MergeColumnsFromDestination(destCols ...columns.Column) {
 			t.inMemoryColumns.UpdateColumn(inMemoryCol)
 		}
 	}
+
+	return nil
 }
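MergeColumnsFromDestination now surfaces a typed failure instead of silently merging a column whose kind never resolved. A small sketch of what that contract change means for a caller; the types below are simplified stand-ins, not the real TableData and Column types:

package main

import (
	"fmt"
)

type column struct {
	name string
	kind string // "invalid" marks a column whose type could not be mapped
}

type tableData struct {
	cols map[string]column
}

// mergeColumnsFromDestination mirrors the new error contract: merging stops at
// the first destination column carrying an unresolved (invalid) kind.
func (t *tableData) mergeColumnsFromDestination(destCols ...column) error {
	for _, destCol := range destCols {
		if destCol.kind == "invalid" {
			return fmt.Errorf("column %q is invalid", destCol.name)
		}
		t.cols[destCol.name] = destCol
	}
	return nil
}

func main() {
	td := &tableData{cols: map[string]column{}}
	if err := td.mergeColumnsFromDestination(column{name: "foo", kind: "invalid"}); err != nil {
		fmt.Println(err) // column "foo" is invalid
	}
}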
name: "redshift, #1 (delta)", diff --git a/lib/typing/bigquery_test.go b/lib/typing/bigquery_test.go index 64337909e..11c126d20 100644 --- a/lib/typing/bigquery_test.go +++ b/lib/typing/bigquery_test.go @@ -1,6 +1,7 @@ package typing import ( + "fmt" "testing" "time" @@ -52,7 +53,12 @@ func TestBigQueryTypeToKind(t *testing.T) { for bqCol, expectedKind := range bqColToExpectedKind { kd, err := DwhTypeToKind(constants.BigQuery, bqCol, "") - assert.NoError(t, err) + if expectedKind.Kind == Invalid.Kind { + assert.ErrorContains(t, err, fmt.Sprintf("unable to map type: %q to dwh type", bqCol)) + } else { + assert.NoError(t, err) + } + assert.Equal(t, expectedKind.Kind, kd.Kind, bqCol) } } diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index 8a2ef5b91..7a130a166 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -145,20 +145,20 @@ func TestColumn_Name(t *testing.T) { { colName: "start", expectedName: "start", - expectedNameEsc: `"start"`, // since this is a reserved word. + expectedNameEsc: `"START"`, // since this is a reserved word. expectedNameEscBq: "`start`", // BQ escapes via backticks. }, { colName: "foo", expectedName: "foo", expectedNameEsc: "foo", - expectedNameEscBq: "foo", + expectedNameEscBq: "`foo`", }, { colName: "bar", expectedName: "bar", expectedNameEsc: "bar", - expectedNameEscBq: "bar", + expectedNameEscBq: "`bar`", }, } @@ -169,7 +169,7 @@ func TestColumn_Name(t *testing.T) { assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName) - assert.Equal(t, testCase.expectedNameEsc, col.Name(false, constants.Snowflake), testCase.colName) + assert.Equal(t, testCase.expectedNameEsc, col.Name(true, constants.Snowflake), testCase.colName) assert.Equal(t, testCase.expectedNameEscBq, col.Name(false, constants.BigQuery), testCase.colName) } } @@ -265,14 +265,14 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) { { name: "happy path", cols: happyPathCols, - expectedColsEsc: []string{"hi", "bye", `"start"`}, - expectedColsEscBq: []string{"hi", "bye", "`start`"}, + expectedColsEsc: []string{"hi", "bye", `"START"`}, + expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"}, }, { name: "happy path + extra col", cols: extraCols, - expectedColsEsc: []string{"hi", "bye", `"start"`}, - expectedColsEscBq: []string{"hi", "bye", "`start`"}, + expectedColsEsc: []string{"hi", "bye", `"START"`}, + expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"}, }, } @@ -281,7 +281,7 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) { columns: testCase.cols, } - assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(false, constants.Snowflake), testCase.name) + assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(true, constants.Snowflake), testCase.name) assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(false, constants.BigQuery), testCase.name) } } @@ -498,13 +498,13 @@ func TestColumnsUpdateQuery(t *testing.T) { name: "struct, string and toast string (bigquery)", columns: lastCaseColTypes, destKind: constants.BigQuery, - expectedString: `a1= CASE WHEN COALESCE(TO_JSON_STRING(cc.a1) != '{"key":"__debezium_unavailable_value"}', true) THEN cc.a1 ELSE c.a1 END,b2= CASE WHEN COALESCE(cc.b2 != '__debezium_unavailable_value', true) THEN cc.b2 ELSE c.b2 END,c3=cc.c3`, + expectedString: "`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '{\"key\":\"__debezium_unavailable_value\"}', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN 
diff --git a/lib/typing/bigquery_test.go b/lib/typing/bigquery_test.go
index 64337909e..11c126d20 100644
--- a/lib/typing/bigquery_test.go
+++ b/lib/typing/bigquery_test.go
@@ -1,6 +1,7 @@
 package typing
 
 import (
+	"fmt"
 	"testing"
 	"time"
 
@@ -52,7 +53,12 @@ func TestBigQueryTypeToKind(t *testing.T) {
 
 	for bqCol, expectedKind := range bqColToExpectedKind {
 		kd, err := DwhTypeToKind(constants.BigQuery, bqCol, "")
-		assert.NoError(t, err)
+		if expectedKind.Kind == Invalid.Kind {
+			assert.ErrorContains(t, err, fmt.Sprintf("unable to map type: %q to dwh type", bqCol))
+		} else {
+			assert.NoError(t, err)
+		}
+
 		assert.Equal(t, expectedKind.Kind, kd.Kind, bqCol)
 	}
 }
diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go
index 8a2ef5b91..7a130a166 100644
--- a/lib/typing/columns/columns_test.go
+++ b/lib/typing/columns/columns_test.go
@@ -145,20 +145,20 @@ func TestColumn_Name(t *testing.T) {
 		{
 			colName:           "start",
 			expectedName:      "start",
-			expectedNameEsc:   `"start"`, // since this is a reserved word.
+			expectedNameEsc:   `"START"`, // since this is a reserved word.
 			expectedNameEscBq: "`start`", // BQ escapes via backticks.
 		},
 		{
 			colName:           "foo",
 			expectedName:      "foo",
 			expectedNameEsc:   "foo",
-			expectedNameEscBq: "foo",
+			expectedNameEscBq: "`foo`",
 		},
 		{
 			colName:           "bar",
 			expectedName:      "bar",
 			expectedNameEsc:   "bar",
-			expectedNameEscBq: "bar",
+			expectedNameEscBq: "`bar`",
 		},
 	}
 
@@ -169,7 +169,7 @@ func TestColumn_Name(t *testing.T) {
 
 		assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName)
 
-		assert.Equal(t, testCase.expectedNameEsc, col.Name(false, constants.Snowflake), testCase.colName)
+		assert.Equal(t, testCase.expectedNameEsc, col.Name(true, constants.Snowflake), testCase.colName)
 		assert.Equal(t, testCase.expectedNameEscBq, col.Name(false, constants.BigQuery), testCase.colName)
 	}
 }
@@ -265,14 +265,14 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) {
 		{
 			name:              "happy path",
 			cols:              happyPathCols,
-			expectedColsEsc:   []string{"hi", "bye", `"start"`},
-			expectedColsEscBq: []string{"hi", "bye", "`start`"},
+			expectedColsEsc:   []string{"hi", "bye", `"START"`},
+			expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"},
 		},
 		{
 			name:              "happy path + extra col",
 			cols:              extraCols,
-			expectedColsEsc:   []string{"hi", "bye", `"start"`},
-			expectedColsEscBq: []string{"hi", "bye", "`start`"},
+			expectedColsEsc:   []string{"hi", "bye", `"START"`},
+			expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"},
 		},
 	}
 
@@ -281,7 +281,7 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) {
 			columns: testCase.cols,
 		}
 
-		assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(false, constants.Snowflake), testCase.name)
+		assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(true, constants.Snowflake), testCase.name)
 		assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(false, constants.BigQuery), testCase.name)
 	}
 }
@@ -498,13 +498,13 @@ func TestColumnsUpdateQuery(t *testing.T) {
 			name:           "struct, string and toast string (bigquery)",
 			columns:        lastCaseColTypes,
 			destKind:       constants.BigQuery,
-			expectedString: `a1= CASE WHEN COALESCE(TO_JSON_STRING(cc.a1) != '{"key":"__debezium_unavailable_value"}', true) THEN cc.a1 ELSE c.a1 END,b2= CASE WHEN COALESCE(cc.b2 != '__debezium_unavailable_value', true) THEN cc.b2 ELSE c.b2 END,c3=cc.c3`,
+			expectedString: "`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '{\"key\":\"__debezium_unavailable_value\"}', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`",
 		},
 		{
 			name:           "struct, string and toast string (bigquery) w/ reserved keywords",
 			columns:        lastCaseEscapeTypes,
 			destKind:       constants.BigQuery,
-			expectedString: fmt.Sprintf(`a1= CASE WHEN COALESCE(TO_JSON_STRING(cc.a1) != '%s', true) THEN cc.a1 ELSE c.a1 END,b2= CASE WHEN COALESCE(cc.b2 != '__debezium_unavailable_value', true) THEN cc.b2 ELSE c.b2 END,c3=cc.c3,%s,%s`,
+			expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s",
 				key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`"),
 			skipDeleteCol: true,
 		},
@@ -512,7 +512,7 @@ func TestColumnsUpdateQuery(t *testing.T) {
 			name:           "struct, string and toast string (bigquery) w/ reserved keywords",
 			columns:        lastCaseEscapeTypes,
 			destKind:       constants.BigQuery,
-			expectedString: fmt.Sprintf(`a1= CASE WHEN COALESCE(TO_JSON_STRING(cc.a1) != '%s', true) THEN cc.a1 ELSE c.a1 END,b2= CASE WHEN COALESCE(cc.b2 != '__debezium_unavailable_value', true) THEN cc.b2 ELSE c.b2 END,c3=cc.c3,%s,%s`,
+			expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s",
 				key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`,__artie_delete=cc.__artie_delete"),
 			skipDeleteCol: false,
 		},
diff --git a/lib/typing/columns/default_test.go b/lib/typing/columns/default_test.go
index 05ac93442..8c277b4df 100644
--- a/lib/typing/columns/default_test.go
+++ b/lib/typing/columns/default_test.go
@@ -127,7 +127,7 @@ func TestColumn_DefaultValue(t *testing.T) {
 			args: &DefaultValueArgs{
 				Escape: true,
 			},
-			expectedValue: "'03:19:24'",
+			expectedValue: "'03:19:24.942'",
 		},
 		{
 			name: "datetime",
@@ -138,7 +138,7 @@ func TestColumn_DefaultValue(t *testing.T) {
 			args: &DefaultValueArgs{
 				Escape: true,
 			},
-			expectedValue: "'2022-09-06T03:19:24Z'",
+			expectedValue: "'2022-09-06T03:19:24.942Z'",
 		},
 	}
diff --git a/lib/typing/columns/wrapper_test.go b/lib/typing/columns/wrapper_test.go
index 7a624ab18..083eb7235 100644
--- a/lib/typing/columns/wrapper_test.go
+++ b/lib/typing/columns/wrapper_test.go
@@ -23,25 +23,25 @@ func TestWrapper_Complete(t *testing.T) {
 		{
 			name:                  "happy",
 			expectedRawName:       "happy",
 			expectedEscapedName:   "happy",
-			expectedEscapedNameBQ: "happy",
+			expectedEscapedNameBQ: "`happy`",
 		},
 		{
 			name:                  "user_id",
 			expectedRawName:       "user_id",
 			expectedEscapedName:   "user_id",
-			expectedEscapedNameBQ: "user_id",
+			expectedEscapedNameBQ: "`user_id`",
 		},
 		{
 			name:                  "group",
 			expectedRawName:       "group",
-			expectedEscapedName:   `"group"`,
+			expectedEscapedName:   `"GROUP"`,
 			expectedEscapedNameBQ: "`group`",
 		},
 	}
 
 	for _, testCase := range testCases {
 		// Snowflake escape
-		w := NewWrapper(NewColumn(testCase.name, typing.Invalid), false, constants.Snowflake)
+		w := NewWrapper(NewColumn(testCase.name, typing.Invalid), true, constants.Snowflake)
 
 		assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name)
 		assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)
diff --git a/lib/typing/ext/variables.go b/lib/typing/ext/variables.go
index b2f7ecb97..2f44d353f 100644
--- a/lib/typing/ext/variables.go
+++ b/lib/typing/ext/variables.go
@@ -4,7 +4,7 @@ import "time"
 
 const (
 	BigQueryDateTimeFormat = "2006-01-02 15:04:05.999999"
-	ISO8601                = "2006-01-02T15:04:05-07:00"
+	ISO8601                = "2006-01-02T15:04:05.999-07:00"
 	PostgresDateFormat     = "2006-01-02"
 	PostgresTimeFormat     = "15:04:05.999999-07" // microsecond precision
 	AdditionalTimeFormat   = "15:04:05.999999Z07"
diff --git a/lib/typing/redshift.go b/lib/typing/redshift.go
index 18c96d794..91f5f5daf 100644
--- a/lib/typing/redshift.go
+++ b/lib/typing/redshift.go
@@ -9,6 +9,7 @@ import (
 )
 
 func redshiftTypeToKind(rawType string, stringPrecision string) KindDetails {
+	// TODO: Check if there are any missing Redshift data types.
 	if strings.HasPrefix(rawType, "numeric") {
 		return ParseNumeric(defaultPrefix, rawType)
 	}
@@ -29,7 +30,7 @@ func redshiftTypeToKind(rawType string, stringPrecision string) KindDetails {
 	switch rawType {
 	case "super":
 		return Struct
-	case "integer", "bigint":
+	case "smallint", "integer", "bigint":
 		return Integer
 	case "double precision":
 		return Float
diff --git a/lib/typing/snowflake_test.go b/lib/typing/snowflake_test.go
index 571dcdbd1..7a5eb1f2b 100644
--- a/lib/typing/snowflake_test.go
+++ b/lib/typing/snowflake_test.go
@@ -41,7 +41,7 @@ func TestSnowflakeTypeToKindFloats(t *testing.T) {
 	{
 		// Invalid because precision nor scale is included.
 		kd, err := DwhTypeToKind(constants.Snowflake, "NUMERIC", "")
-		assert.NoError(t, err)
+		assert.ErrorContains(t, err, `unable to map type: "numeric" to dwh type`)
 		assert.Equal(t, Invalid, kd)
 	}
 	{
@@ -105,12 +105,12 @@ func TestSnowflakeTypeToKindComplex(t *testing.T) {
 func TestSnowflakeTypeToKindErrors(t *testing.T) {
 	{
 		kd, err := DwhTypeToKind(constants.Snowflake, "", "")
-		assert.NoError(t, err)
+		assert.ErrorContains(t, err, `unable to map type: "" to dwh type`)
 		assert.Equal(t, Invalid, kd)
 	}
 	{
 		kd, err := DwhTypeToKind(constants.Snowflake, "abc123", "")
-		assert.NoError(t, err)
+		assert.ErrorContains(t, err, `unable to map type: "abc123" to dwh type`)
 		assert.Equal(t, Invalid, kd)
 	}
 }
diff --git a/lib/typing/typing.go b/lib/typing/typing.go
index 16b635b0d..ac3b1f72c 100644
--- a/lib/typing/typing.go
+++ b/lib/typing/typing.go
@@ -194,17 +194,22 @@ func KindToDWHType(kd KindDetails, dwh constants.DestinationKind, isPk bool) str
 
 func DwhTypeToKind(dwh constants.DestinationKind, dwhType, stringPrecision string) (KindDetails, error) {
 	dwhType = strings.ToLower(dwhType)
-
+	var kd KindDetails
 	switch dwh {
 	case constants.Snowflake:
-		return snowflakeTypeToKind(dwhType), nil
+		kd = snowflakeTypeToKind(dwhType)
 	case constants.BigQuery:
-		return bigQueryTypeToKind(dwhType), nil
+		kd = bigQueryTypeToKind(dwhType)
 	case constants.Redshift:
-		return redshiftTypeToKind(dwhType, stringPrecision), nil
+		kd = redshiftTypeToKind(dwhType, stringPrecision)
 	case constants.MSSQL:
-		return mssqlTypeToKind(dwhType, stringPrecision), nil
+		kd = mssqlTypeToKind(dwhType, stringPrecision)
+	default:
+		return Invalid, fmt.Errorf("unexpected dwh kind, label: %v", dwh)
 	}
-
-	return Invalid, fmt.Errorf("unexpected dwh kind, label: %v", dwh)
+
+	if kd.Kind == Invalid.Kind {
+		return Invalid, fmt.Errorf("unable to map type: %q to dwh type", dwhType)
+	}
+	return kd, nil
 }
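With DwhTypeToKind now failing fast on unmappable types, a nil error guarantees a usable kind, so callers no longer need to compare against an Invalid sentinel. A rough sketch of the new call-site shape under that contract; dwhTypeToKind, syncColumn, and the type names here are invented for the example:

package main

import (
	"errors"
	"fmt"
	"strings"
)

type kind string

const invalid kind = "invalid"

// dwhTypeToKind mimics the new contract: either a real kind or an error,
// never a silently-propagated Invalid value.
func dwhTypeToKind(dwhType string) (kind, error) {
	switch strings.ToLower(dwhType) {
	case "string", "varchar":
		return "string", nil
	case "float64", "double precision":
		return "float", nil
	default:
		return invalid, fmt.Errorf("unable to map type: %q to dwh type", dwhType)
	}
}

func syncColumn(rawType string) error {
	kd, err := dwhTypeToKind(rawType)
	if err != nil {
		// Wrap and bubble up instead of carrying Invalid forward.
		return fmt.Errorf("failed to map column type: %w", err)
	}
	fmt.Println("mapped to", kd)
	return nil
}

func main() {
	_ = syncColumn("varchar")
	if err := syncColumn("geometry"); err != nil {
		fmt.Println(errors.Unwrap(err)) // unable to map type: "geometry" to dwh type
	}
}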