From 6fb875d565e52756c69a4c7a4e0fc24cdbcb8b58 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Tue, 11 Jul 2023 12:20:04 -0700 Subject: [PATCH] Supporting tables named after keywords (#152) --- clients/bigquery/bigquery.go | 9 +- clients/bigquery/merge.go | 38 +++--- clients/redshift/merge.go | 27 ++-- clients/redshift/staging.go | 6 +- clients/snowflake/ddl_test.go | 12 +- clients/snowflake/snowflake_test.go | 20 +-- clients/snowflake/staging.go | 34 ++--- clients/snowflake/staging_test.go | 2 +- clients/snowflake/util.go | 5 +- clients/snowflake/util_test.go | 6 +- clients/utils/table_config.go | 2 +- clients/utils/utils.go | 6 +- lib/cdc/mysql/debezium_suite_test.go | 9 +- lib/cdc/mysql/debezium_test.go | 4 +- lib/cdc/util/relational_event_test.go | 24 ++-- lib/config/config.go | 7 ++ lib/dwh/ddl/ddl.go | 10 +- lib/dwh/ddl/ddl_alter_delete_test.go | 6 +- lib/dwh/ddl/ddl_bq_test.go | 20 +-- lib/dwh/ddl/ddl_sflk_test.go | 25 ++-- lib/dwh/ddl/ddl_suite_test.go | 1 + lib/dwh/dml/merge.go | 19 +-- lib/dwh/dml/merge_bigquery_test.go | 29 ++--- lib/dwh/dml/merge_parts_test.go | 130 ++++++++++---------- lib/dwh/dml/merge_suite_test.go | 23 ++++ lib/dwh/dml/merge_test.go | 88 ++++++------- lib/dwh/dml/merge_valid_test.go | 12 +- lib/dwh/types/table_config.go | 8 +- lib/dwh/types/table_config_test.go | 38 +++--- lib/dwh/types/types_suite_test.go | 24 ++++ lib/dwh/types/types_test.go | 9 +- lib/optimization/event.go | 31 +++-- lib/optimization/event_test.go | 48 ++++---- lib/optimization/event_update_test.go | 54 ++++---- lib/optimization/optimization_suite_test.go | 24 ++++ lib/sql/escape.go | 40 ++++++ lib/sql/escape_test.go | 102 +++++++++++++++ lib/sql/sql_suite_test.go | 24 ++++ lib/typing/columns/columns.go | 43 ++----- lib/typing/columns/columns_suite_test.go | 25 ++++ lib/typing/columns/columns_test.go | 110 ++++++++--------- lib/typing/columns/diff.go | 13 +- lib/typing/columns/diff_test.go | 58 +++++---- lib/typing/columns/wrapper.go | 10 +- lib/typing/columns/wrapper_test.go | 28 ++--- models/event/event_save_test.go | 10 +- 46 files changed, 777 insertions(+), 496 deletions(-) create mode 100644 lib/dwh/dml/merge_suite_test.go create mode 100644 lib/dwh/types/types_suite_test.go create mode 100644 lib/optimization/optimization_suite_test.go create mode 100644 lib/sql/escape.go create mode 100644 lib/sql/escape_test.go create mode 100644 lib/sql/sql_suite_test.go create mode 100644 lib/typing/columns/columns_suite_test.go diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index ecd93adac..05d19074f 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -33,10 +33,11 @@ type Store struct { func (s *Store) getTableConfig(ctx context.Context, tableData *optimization.TableData) (*types.DwhTableConfig, error) { return utils.GetTableConfig(ctx, utils.GetTableCfgArgs{ - Dwh: s, - FqName: tableData.ToFqName(ctx, constants.BigQuery), - ConfigMap: s.configMap, - Query: fmt.Sprintf("SELECT column_name, data_type, description FROM `%s.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` WHERE table_name='%s';", tableData.TopicConfig.Database, tableData.Name()), + Dwh: s, + FqName: tableData.ToFqName(ctx, constants.BigQuery, true), + ConfigMap: s.configMap, + Query: fmt.Sprintf("SELECT column_name, data_type, description FROM `%s.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` WHERE table_name='%s';", + tableData.TopicConfig.Database, tableData.Name(ctx, nil)), ColumnNameLabel: describeNameCol, ColumnTypeLabel: describeTypeCol, ColumnDescLabel: describeCommentCol, diff --git a/clients/bigquery/merge.go b/clients/bigquery/merge.go index b5bfaa1a7..507b743bd 100644 --- a/clients/bigquery/merge.go +++ b/clients/bigquery/merge.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/jitter" @@ -36,11 +38,11 @@ func (r *Row) Save() (map[string]bigquery.Value, string, error) { return r.data, bigquery.NoDedupeID, nil } -func merge(tableData *optimization.TableData) ([]*Row, error) { +func merge(ctx context.Context, tableData *optimization.TableData) ([]*Row, error) { var rows []*Row for _, value := range tableData.RowsData() { data := make(map[string]bigquery.Value) - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(ctx, nil) { colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col) colVal, err := CastColVal(value[col], colKind) if err != nil { @@ -76,13 +78,13 @@ func (s *Store) backfillColumn(ctx context.Context, column columns.Column, fqTab } fqTableName = strings.ToLower(fqTableName) - escapedCol := column.Name(&columns.NameArgs{Escape: true, DestKind: s.Label()}) + escapedCol := column.Name(ctx, &sql.NameArgs{Escape: true, DestKind: s.Label()}) query := fmt.Sprintf(`UPDATE %s SET %s = %v WHERE %s IS NULL;`, // UPDATE table SET col = default_val WHERE col IS NULL fqTableName, escapedCol, defaultVal, escapedCol) logger.FromContext(ctx).WithFields(map[string]interface{}{ - "colName": column.Name(nil), + "colName": column.Name(ctx, nil), "query": query, "table": fqTableName, }).Info("backfilling column") @@ -113,12 +115,12 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er log := logger.FromContext(ctx) // Check if all the columns exist in BigQuery - srcKeysMissing, targetKeysMissing := columns.Diff(tableData.ReadOnlyInMemoryCols(), + srcKeysMissing, targetKeysMissing := columns.Diff(ctx, tableData.ReadOnlyInMemoryCols(), tableConfig.Columns(), tableData.TopicConfig.SoftDelete, tableData.TopicConfig.IncludeArtieUpdatedAt) createAlterTableArgs := ddl.AlterTableArgs{ Dwh: s, Tc: tableConfig, - FqTableName: tableData.ToFqName(ctx, s.Label()), + FqTableName: tableData.ToFqName(ctx, s.Label(), true), CreateTable: tableConfig.CreateTable(), ColumnOp: constants.Add, CdcTime: tableData.LatestCDCTs, @@ -137,7 +139,7 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er deleteAlterTableArgs := ddl.AlterTableArgs{ Dwh: s, Tc: tableConfig, - FqTableName: tableData.ToFqName(ctx, s.Label()), + FqTableName: tableData.ToFqName(ctx, s.Label(), true), CreateTable: false, ColumnOp: constants.Delete, ContainOtherOperations: tableData.ContainOtherOperations(), @@ -155,7 +157,7 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er for colToDelete := range tableConfig.ReadOnlyColumnsToDelete() { var found bool for _, col := range srcKeysMissing { - if found = col.Name(nil) == colToDelete; found { + if found = col.Name(ctx, nil) == colToDelete; found { // Found it. break } @@ -168,13 +170,13 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er } // Infer the right data types from BigQuery before temp table creation. - tableData.UpdateInMemoryColumnsFromDestination(tableConfig.Columns().GetColumns()...) + tableData.UpdateInMemoryColumnsFromDestination(ctx, tableConfig.Columns().GetColumns()...) // Start temporary table creation tempAlterTableArgs := ddl.AlterTableArgs{ Dwh: s, Tc: tableConfig, - FqTableName: fmt.Sprintf("%s_%s", tableData.ToFqName(ctx, s.Label()), tableData.TempTableSuffix()), + FqTableName: fmt.Sprintf("%s_%s", tableData.ToFqName(ctx, s.Label(), false), tableData.TempTableSuffix()), CreateTable: true, TemporaryTable: true, ColumnOp: constants.Add, @@ -193,9 +195,9 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er var attempts int for { - err = s.backfillColumn(ctx, col, tableData.ToFqName(ctx, s.Label())) + err = s.backfillColumn(ctx, col, tableData.ToFqName(ctx, s.Label(), true)) if err == nil { - tableConfig.Columns().UpsertColumn(col.Name(nil), columns.UpsertColumnArg{ + tableConfig.Columns().UpsertColumn(col.Name(ctx, nil), columns.UpsertColumnArg{ Backfilled: ptr.ToBool(true), }) break @@ -208,30 +210,30 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er } else { defaultVal, _ := col.DefaultValue(nil) return fmt.Errorf("failed to backfill col: %v, default value: %v, err: %v", - col.Name(nil), defaultVal, err) + col.Name(ctx, nil), defaultVal, err) } } } // Perform actual merge now - rows, err := merge(tableData) + rows, err := merge(ctx, tableData) if err != nil { log.WithError(err).Warn("failed to generate the merge query") return err } - tableName := fmt.Sprintf("%s_%s", tableData.Name(), tableData.TempTableSuffix()) + tableName := fmt.Sprintf("%s_%s", tableData.Name(ctx, nil), tableData.TempTableSuffix()) err = s.PutTable(ctx, tableData.TopicConfig.Database, tableName, rows) if err != nil { return fmt.Errorf("failed to insert into temp table: %s, error: %v", tableName, err) } - mergeQuery, err := dml.MergeStatement(&dml.MergeArgument{ - FqTableName: tableData.ToFqName(ctx, constants.BigQuery), + mergeQuery, err := dml.MergeStatement(ctx, &dml.MergeArgument{ + FqTableName: tableData.ToFqName(ctx, constants.BigQuery, true), SubQuery: tempAlterTableArgs.FqTableName, IdempotentKey: tableData.TopicConfig.IdempotentKey, - PrimaryKeys: tableData.PrimaryKeys(&columns.NameArgs{ + PrimaryKeys: tableData.PrimaryKeys(ctx, &sql.NameArgs{ Escape: true, DestKind: s.Label(), }), diff --git a/clients/redshift/merge.go b/clients/redshift/merge.go index 0e881719d..26af45837 100644 --- a/clients/redshift/merge.go +++ b/clients/redshift/merge.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/clients/utils" "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/dwh/ddl" @@ -21,7 +23,10 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er } tableConfig, err := s.getTableConfig(ctx, getTableConfigArgs{ - Table: tableData.Name(), + Table: tableData.Name(ctx, &sql.NameArgs{ + Escape: true, + DestKind: s.Label(), + }), Schema: tableData.TopicConfig.Schema, DropDeletedColumns: tableData.TopicConfig.DropDeletedColumns, }) @@ -30,9 +35,9 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er } log := logger.FromContext(ctx) - fqName := tableData.ToFqName(ctx, s.Label()) + fqName := tableData.ToFqName(ctx, s.Label(), true) // Check if all the columns exist in Redshift - srcKeysMissing, targetKeysMissing := columns.Diff(tableData.ReadOnlyInMemoryCols(), tableConfig.Columns(), + srcKeysMissing, targetKeysMissing := columns.Diff(ctx, tableData.ReadOnlyInMemoryCols(), tableConfig.Columns(), tableData.TopicConfig.SoftDelete, tableData.TopicConfig.IncludeArtieUpdatedAt) createAlterTableArgs := ddl.AlterTableArgs{ Dwh: s, @@ -74,7 +79,7 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er for colToDelete := range tableConfig.ReadOnlyColumnsToDelete() { var found bool for _, col := range srcKeysMissing { - if found = col.Name(nil) == colToDelete; found { + if found = col.Name(ctx, nil) == colToDelete; found { // Found it. break } @@ -86,10 +91,10 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er } } - tableData.UpdateInMemoryColumnsFromDestination(tableConfig.Columns().GetColumns()...) + tableData.UpdateInMemoryColumnsFromDestination(ctx, tableConfig.Columns().GetColumns()...) // Temporary tables cannot specify schemas, so we just prefix it instead. - temporaryTableName := fmt.Sprintf("%s_%s", tableData.ToFqName(ctx, s.Label()), tableData.TempTableSuffix()) + temporaryTableName := fmt.Sprintf("%s_%s", tableData.ToFqName(ctx, s.Label(), false), tableData.TempTableSuffix()) if err = s.prepareTempTable(ctx, tableData, tableConfig, temporaryTableName); err != nil { return err } @@ -100,24 +105,24 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er continue } - err = utils.BackfillColumn(ctx, s, col, tableData.ToFqName(ctx, s.Label())) + err = utils.BackfillColumn(ctx, s, col, tableData.ToFqName(ctx, s.Label(), true)) if err != nil { defaultVal, _ := col.DefaultValue(nil) return fmt.Errorf("failed to backfill col: %v, default value: %v, error: %v", - col.Name(nil), defaultVal, err) + col.Name(ctx, nil), defaultVal, err) } - tableConfig.Columns().UpsertColumn(col.Name(nil), columns.UpsertColumnArg{ + tableConfig.Columns().UpsertColumn(col.Name(ctx, nil), columns.UpsertColumnArg{ Backfilled: ptr.ToBool(true), }) } // Prepare merge statement - mergeParts, err := dml.MergeStatementParts(&dml.MergeArgument{ + mergeParts, err := dml.MergeStatementParts(ctx, &dml.MergeArgument{ FqTableName: fqName, SubQuery: temporaryTableName, IdempotentKey: tableData.TopicConfig.IdempotentKey, - PrimaryKeys: tableData.PrimaryKeys(&columns.NameArgs{ + PrimaryKeys: tableData.PrimaryKeys(ctx, &sql.NameArgs{ Escape: true, DestKind: s.Label(), }), diff --git a/clients/redshift/staging.go b/clients/redshift/staging.go index f911998bc..6a15c4cf9 100644 --- a/clients/redshift/staging.go +++ b/clients/redshift/staging.go @@ -38,7 +38,7 @@ func (s *Store) prepareTempTable(ctx context.Context, tableData *optimization.Ta return fmt.Errorf("failed to add comment to table, tableName: %v, err: %v", tempTableName, err) } - fp, err := s.loadTemporaryTable(tableData, tempTableName) + fp, err := s.loadTemporaryTable(ctx, tableData, tempTableName) if err != nil { return fmt.Errorf("failed to load temporary table, err: %v", err) } @@ -72,7 +72,7 @@ func (s *Store) prepareTempTable(ctx context.Context, tableData *optimization.Ta // loadTemporaryTable will write the data into /tmp/newTableName.csv // This way, another function can call this and then invoke a Snowflake PUT. // Returns the file path and potential error -func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableName string) (string, error) { +func (s *Store) loadTemporaryTable(ctx context.Context, tableData *optimization.TableData, newTableName string) (string, error) { filePath := fmt.Sprintf("/tmp/%s.csv", newTableName) file, err := os.Create(filePath) if err != nil { @@ -84,7 +84,7 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableNa writer.Comma = '\t' for _, value := range tableData.RowsData() { var row []string - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(ctx, nil) { colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col) colVal := value[col] // Check diff --git a/clients/snowflake/ddl_test.go b/clients/snowflake/ddl_test.go index 3653f39c6..a88cbd241 100644 --- a/clients/snowflake/ddl_test.go +++ b/clients/snowflake/ddl_test.go @@ -36,12 +36,12 @@ func (s *SnowflakeTestSuite) TestMutateColumnsWithMemoryCacheDeletions() { nameCol := columns.NewColumn("name", typing.String) tc := s.stageStore.configMap.TableConfig(fqName) - val := tc.ShouldDeleteColumn(s.ctx, nameCol.Name(nil), time.Now().Add(-1*6*time.Hour), true) + val := tc.ShouldDeleteColumn(s.ctx, nameCol.Name(s.ctx, nil), time.Now().Add(-1*6*time.Hour), true) assert.False(s.T(), val, "should not try to delete this column") assert.Equal(s.T(), len(s.stageStore.configMap.TableConfig(fqName).ReadOnlyColumnsToDelete()), 1) // Now let's try to add this column back, it should delete it from the cache. - tc.MutateInMemoryColumns(false, constants.Add, nameCol) + tc.MutateInMemoryColumns(s.ctx, false, constants.Add, nameCol) assert.Equal(s.T(), len(s.stageStore.configMap.TableConfig(fqName).ReadOnlyColumnsToDelete()), 0) } @@ -63,23 +63,23 @@ func (s *SnowflakeTestSuite) TestShouldDeleteColumn() { nameCol := columns.NewColumn("name", typing.String) // Let's try to delete name. - allowed := s.stageStore.configMap.TableConfig(fqName).ShouldDeleteColumn(s.ctx, nameCol.Name(nil), + allowed := s.stageStore.configMap.TableConfig(fqName).ShouldDeleteColumn(s.ctx, nameCol.Name(s.ctx, nil), time.Now().Add(-1*(6*time.Hour)), true) assert.Equal(s.T(), allowed, false, "should not be allowed to delete") // Process tried to delete, but it's lagged. - allowed = s.stageStore.configMap.TableConfig(fqName).ShouldDeleteColumn(s.ctx, nameCol.Name(nil), + allowed = s.stageStore.configMap.TableConfig(fqName).ShouldDeleteColumn(s.ctx, nameCol.Name(s.ctx, nil), time.Now().Add(-1*(6*time.Hour)), true) assert.Equal(s.T(), allowed, false, "should not be allowed to delete") // Process now caught up, and is asking if we can delete, should still be no. - allowed = s.stageStore.configMap.TableConfig(fqName).ShouldDeleteColumn(s.ctx, nameCol.Name(nil), time.Now(), true) + allowed = s.stageStore.configMap.TableConfig(fqName).ShouldDeleteColumn(s.ctx, nameCol.Name(s.ctx, nil), time.Now(), true) assert.Equal(s.T(), allowed, false, "should not be allowed to delete still") // Process is finally ahead, has permission to delete now. - allowed = s.stageStore.configMap.TableConfig(fqName).ShouldDeleteColumn(s.ctx, nameCol.Name(nil), + allowed = s.stageStore.configMap.TableConfig(fqName).ShouldDeleteColumn(s.ctx, nameCol.Name(s.ctx, nil), time.Now().Add(2*constants.DeletionConfidencePadding), true) assert.Equal(s.T(), allowed, true, "should now be allowed to delete") diff --git a/clients/snowflake/snowflake_test.go b/clients/snowflake/snowflake_test.go index d5cab02e5..db684699b 100644 --- a/clients/snowflake/snowflake_test.go +++ b/clients/snowflake/snowflake_test.go @@ -47,7 +47,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() { } tableData := optimization.NewTableData(&cols, []string{"id"}, topicConfig, "foo") - assert.Equal(s.T(), topicConfig.TableName, tableData.Name(), "override is working") + assert.Equal(s.T(), topicConfig.TableName, tableData.Name(s.ctx, nil), "override is working") for pk, row := range rowsData { tableData.InsertRow(pk, row, false) @@ -64,7 +64,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() { anotherCols.AddColumn(columns.NewColumn(colName, kindDetails)) } - s.stageStore.configMap.AddTableToConfig(tableData.ToFqName(s.ctx, constants.Snowflake), + s.stageStore.configMap.AddTableToConfig(tableData.ToFqName(s.ctx, constants.Snowflake, true), types.NewDwhTableConfig(&anotherCols, nil, false, true)) err := s.stageStore.Merge(s.ctx, tableData) @@ -109,7 +109,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() { tableData.InsertRow(pk, row, false) } - s.stageStore.configMap.AddTableToConfig(tableData.ToFqName(s.ctx, constants.Snowflake), + s.stageStore.configMap.AddTableToConfig(tableData.ToFqName(s.ctx, constants.Snowflake, true), types.NewDwhTableConfig(&cols, nil, false, true)) s.fakeStageStore.ExecReturnsOnCall(0, nil, fmt.Errorf("390114: Authentication token has expired. The user must authenticate again.")) @@ -161,7 +161,7 @@ func (s *SnowflakeTestSuite) TestExecuteMerge() { var idx int for _, destKind := range []constants.DestinationKind{constants.Snowflake, constants.SnowflakeStages} { - fqName := tableData.ToFqName(s.ctx, destKind) + fqName := tableData.ToFqName(s.ctx, destKind, true) s.stageStore.configMap.AddTableToConfig(fqName, types.NewDwhTableConfig(&cols, nil, false, true)) err := s.stageStore.Merge(s.ctx, tableData) assert.Nil(s.T(), err) @@ -244,7 +244,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() { sflkCols.AddColumn(columns.NewColumn("new", typing.String)) config := types.NewDwhTableConfig(&sflkCols, nil, false, true) - s.stageStore.configMap.AddTableToConfig(tableData.ToFqName(s.ctx, constants.Snowflake), config) + s.stageStore.configMap.AddTableToConfig(tableData.ToFqName(s.ctx, constants.Snowflake, true), config) err := s.stageStore.Merge(s.ctx, tableData) assert.Nil(s.T(), err) @@ -252,10 +252,10 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() { assert.Equal(s.T(), s.fakeStageStore.ExecCallCount(), 5, "called merge") // Check the temp deletion table now. - assert.Equal(s.T(), len(s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake)).ReadOnlyColumnsToDelete()), 1, - s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake)).ReadOnlyColumnsToDelete()) + assert.Equal(s.T(), len(s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake, true)).ReadOnlyColumnsToDelete()), 1, + s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake, true)).ReadOnlyColumnsToDelete()) - _, isOk := s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake)).ReadOnlyColumnsToDelete()["new"] + _, isOk := s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake, true)).ReadOnlyColumnsToDelete()["new"] assert.True(s.T(), isOk) // Now try to execute merge where 1 of the rows have the column now @@ -276,8 +276,8 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() { assert.Equal(s.T(), s.fakeStageStore.ExecCallCount(), 10, "called merge again") // Caught up now, so columns should be 0. - assert.Equal(s.T(), len(s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake)).ReadOnlyColumnsToDelete()), 0, - s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake)).ReadOnlyColumnsToDelete()) + assert.Equal(s.T(), len(s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake, true)).ReadOnlyColumnsToDelete()), 0, + s.stageStore.configMap.TableConfig(tableData.ToFqName(s.ctx, constants.Snowflake, true)).ReadOnlyColumnsToDelete()) } func (s *SnowflakeTestSuite) TestExecuteMergeExitEarly() { diff --git a/clients/snowflake/staging.go b/clients/snowflake/staging.go index aaf2b0b06..23b35d63c 100644 --- a/clients/snowflake/staging.go +++ b/clients/snowflake/staging.go @@ -7,6 +7,8 @@ import ( "os" "strings" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/clients/utils" "github.com/artie-labs/transfer/lib/ptr" @@ -41,7 +43,7 @@ func (s *Store) prepareTempTable(ctx context.Context, tableData *optimization.Ta return fmt.Errorf("failed to create temp table, error: %v", err) } - fp, err := s.loadTemporaryTable(tableData, tempTableName) + fp, err := s.loadTemporaryTable(ctx, tableData, tempTableName) if err != nil { return fmt.Errorf("failed to load temporary table, err: %v", err) } @@ -52,12 +54,12 @@ func (s *Store) prepareTempTable(ctx context.Context, tableData *optimization.Ta _, err = s.Exec(fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)", // Copy into temporary tables (column ...) - tempTableName, strings.Join(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(&columns.NameArgs{ + tempTableName, strings.Join(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(ctx, &sql.NameArgs{ Escape: true, DestKind: s.Label(), }), ","), // Escaped columns, TABLE NAME - escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableName, "%"))) + escapeColumns(ctx, tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableName, "%"))) if err != nil { return fmt.Errorf("failed to load staging file into temporary table, err: %v", err) @@ -73,7 +75,7 @@ func (s *Store) prepareTempTable(ctx context.Context, tableData *optimization.Ta // loadTemporaryTable will write the data into /tmp/newTableName.csv // This way, another function can call this and then invoke a Snowflake PUT. // Returns the file path and potential error -func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableName string) (string, error) { +func (s *Store) loadTemporaryTable(ctx context.Context, tableData *optimization.TableData, newTableName string) (string, error) { filePath := fmt.Sprintf("/tmp/%s.csv", newTableName) file, err := os.Create(filePath) if err != nil { @@ -85,7 +87,7 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableNa writer.Comma = '\t' for _, value := range tableData.RowsData() { var row []string - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(ctx, nil) { colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col) colVal := value[col] // Check @@ -113,7 +115,7 @@ func (s *Store) mergeWithStages(ctx context.Context, tableData *optimization.Tab return nil } - fqName := tableData.ToFqName(ctx, constants.Snowflake) + fqName := tableData.ToFqName(ctx, constants.Snowflake, true) tableConfig, err := s.getTableConfig(ctx, fqName, tableData.TopicConfig.DropDeletedColumns) if err != nil { return err @@ -121,7 +123,7 @@ func (s *Store) mergeWithStages(ctx context.Context, tableData *optimization.Tab log := logger.FromContext(ctx) // Check if all the columns exist in Snowflake - srcKeysMissing, targetKeysMissing := columns.Diff(tableData.ReadOnlyInMemoryCols(), tableConfig.Columns(), + srcKeysMissing, targetKeysMissing := columns.Diff(ctx, tableData.ReadOnlyInMemoryCols(), tableConfig.Columns(), tableData.TopicConfig.SoftDelete, tableData.TopicConfig.IncludeArtieUpdatedAt) createAlterTableArgs := ddl.AlterTableArgs{ Dwh: s, @@ -163,7 +165,7 @@ func (s *Store) mergeWithStages(ctx context.Context, tableData *optimization.Tab for colToDelete := range tableConfig.ReadOnlyColumnsToDelete() { var found bool for _, col := range srcKeysMissing { - if found = col.Name(nil) == colToDelete; found { + if found = col.Name(ctx, nil) == colToDelete; found { // Found it. break } @@ -175,8 +177,8 @@ func (s *Store) mergeWithStages(ctx context.Context, tableData *optimization.Tab } } - tableData.UpdateInMemoryColumnsFromDestination(tableConfig.Columns().GetColumns()...) - temporaryTableName := fmt.Sprintf("%s_%s", tableData.ToFqName(ctx, s.Label()), tableData.TempTableSuffix()) + tableData.UpdateInMemoryColumnsFromDestination(ctx, tableConfig.Columns().GetColumns()...) + temporaryTableName := fmt.Sprintf("%s_%s", tableData.ToFqName(ctx, s.Label(), false), tableData.TempTableSuffix()) if err = s.prepareTempTable(ctx, tableData, tableConfig, temporaryTableName); err != nil { return err } @@ -187,24 +189,24 @@ func (s *Store) mergeWithStages(ctx context.Context, tableData *optimization.Tab continue } - err = utils.BackfillColumn(ctx, s, col, tableData.ToFqName(ctx, s.Label())) + err = utils.BackfillColumn(ctx, s, col, tableData.ToFqName(ctx, s.Label(), true)) if err != nil { defaultVal, _ := col.DefaultValue(nil) return fmt.Errorf("failed to backfill col: %v, default value: %v, error: %v", - col.Name(nil), defaultVal, err) + col.Name(ctx, nil), defaultVal, err) } - tableConfig.Columns().UpsertColumn(col.Name(nil), columns.UpsertColumnArg{ + tableConfig.Columns().UpsertColumn(col.Name(ctx, nil), columns.UpsertColumnArg{ Backfilled: ptr.ToBool(true), }) } // Prepare merge statement - mergeQuery, err := dml.MergeStatement(&dml.MergeArgument{ - FqTableName: tableData.ToFqName(ctx, constants.Snowflake), + mergeQuery, err := dml.MergeStatement(ctx, &dml.MergeArgument{ + FqTableName: tableData.ToFqName(ctx, constants.Snowflake, true), SubQuery: temporaryTableName, IdempotentKey: tableData.TopicConfig.IdempotentKey, - PrimaryKeys: tableData.PrimaryKeys(&columns.NameArgs{ + PrimaryKeys: tableData.PrimaryKeys(ctx, &sql.NameArgs{ Escape: true, DestKind: s.Label(), }), diff --git a/clients/snowflake/staging_test.go b/clients/snowflake/staging_test.go index 75d5e4a75..5d9bc6f0e 100644 --- a/clients/snowflake/staging_test.go +++ b/clients/snowflake/staging_test.go @@ -126,7 +126,7 @@ func (s *SnowflakeTestSuite) TestPrepareTempTable() { func (s *SnowflakeTestSuite) TestLoadTemporaryTable() { tempTableName, tableData := generateTableData(100) - fp, err := s.stageStore.loadTemporaryTable(tableData, tempTableName) + fp, err := s.stageStore.loadTemporaryTable(s.ctx, tableData, tempTableName) assert.NoError(s.T(), err) // Read the CSV and confirm. csvfile, err := os.Open(fp) diff --git a/clients/snowflake/util.go b/clients/snowflake/util.go index 65aa137fd..09ace9cc3 100644 --- a/clients/snowflake/util.go +++ b/clients/snowflake/util.go @@ -1,6 +1,7 @@ package snowflake import ( + "context" "fmt" "strings" @@ -24,9 +25,9 @@ func addPrefixToTableName(fqTableName string, prefix string) string { // escapeColumns will take the columns that are passed in, escape them and return them in the ordered received. // It'll return like this: $1, $2, $3 -func escapeColumns(columns *columns.Columns, delimiter string) string { +func escapeColumns(ctx context.Context, columns *columns.Columns, delimiter string) string { var escapedCols []string - for index, col := range columns.GetColumnsToUpdate(nil) { + for index, col := range columns.GetColumnsToUpdate(ctx, nil) { colKind, _ := columns.GetColumn(col) escapedCol := fmt.Sprintf("$%d", index+1) switch colKind.KindDetails { diff --git a/clients/snowflake/util_test.go b/clients/snowflake/util_test.go index 05e008083..baafb1d8a 100644 --- a/clients/snowflake/util_test.go +++ b/clients/snowflake/util_test.go @@ -46,7 +46,7 @@ func TestAddPrefixToTableName(t *testing.T) { } } -func TestEscapeColumns(t *testing.T) { +func (s *SnowflakeTestSuite) TestEscapeColumns() { type _testCase struct { name string cols *columns.Columns @@ -87,7 +87,7 @@ func TestEscapeColumns(t *testing.T) { } for _, testCase := range testCases { - actualString := escapeColumns(testCase.cols, ",") - assert.Equal(t, testCase.expectedString, actualString, testCase.name) + actualString := escapeColumns(s.ctx, testCase.cols, ",") + assert.Equal(s.T(), testCase.expectedString, actualString, testCase.name) } } diff --git a/clients/utils/table_config.go b/clients/utils/table_config.go index be3f5b9c5..ef74126b9 100644 --- a/clients/utils/table_config.go +++ b/clients/utils/table_config.go @@ -64,7 +64,7 @@ func GetTableConfig(ctx context.Context, args GetTableCfgArgs) (*types.DwhTableC tableMissing = true err = nil } else { - return nil, fmt.Errorf("failed to query %v, err: %v", args.Dwh.Label(), err) + return nil, fmt.Errorf("failed to query %v, err: %v, query: %v", args.Dwh.Label(), err, args.Query) } default: return nil, fmt.Errorf("failed to query %v, err: %v", args.Dwh.Label(), err) diff --git a/clients/utils/utils.go b/clients/utils/utils.go index cf9a6a5b9..68ba7d01a 100644 --- a/clients/utils/utils.go +++ b/clients/utils/utils.go @@ -5,6 +5,8 @@ import ( "fmt" "strings" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/dwh" @@ -31,13 +33,13 @@ func BackfillColumn(ctx context.Context, dwh dwh.DataWarehouse, column columns.C return fmt.Errorf("failed to escape default value, err: %v", err) } - escapedCol := column.Name(&columns.NameArgs{Escape: true, DestKind: dwh.Label()}) + escapedCol := column.Name(ctx, &sql.NameArgs{Escape: true, DestKind: dwh.Label()}) query := fmt.Sprintf(`UPDATE %s SET %s = %v WHERE %s IS NULL;`, // UPDATE table SET col = default_val WHERE col IS NULL fqTableName, escapedCol, defaultVal, escapedCol, ) logger.FromContext(ctx).WithFields(map[string]interface{}{ - "colName": column.Name(nil), + "colName": column.Name(ctx, nil), "query": query, "table": fqTableName, }).Info("backfilling column") diff --git a/lib/cdc/mysql/debezium_suite_test.go b/lib/cdc/mysql/debezium_suite_test.go index eae390d89..fce539ce3 100644 --- a/lib/cdc/mysql/debezium_suite_test.go +++ b/lib/cdc/mysql/debezium_suite_test.go @@ -1,18 +1,25 @@ package mysql import ( - "github.com/stretchr/testify/suite" + "context" "testing" + + "github.com/artie-labs/transfer/lib/config" + + "github.com/stretchr/testify/suite" ) type MySQLTestSuite struct { suite.Suite *Debezium + ctx context.Context } func (m *MySQLTestSuite) SetupTest() { var debezium Debezium m.Debezium = &debezium + m.ctx = context.Background() + m.ctx = config.InjectSettingsIntoContext(m.ctx, &config.Settings{Config: &config.Config{}}) } func TestPostgresTestSuite(t *testing.T) { diff --git a/lib/cdc/mysql/debezium_test.go b/lib/cdc/mysql/debezium_test.go index e36e44f97..c4e3a56f7 100644 --- a/lib/cdc/mysql/debezium_test.go +++ b/lib/cdc/mysql/debezium_test.go @@ -319,7 +319,7 @@ func (m *MySQLTestSuite) TestGetEventFromBytes() { col, isOk := cols.GetColumn("abcdef") assert.True(m.T(), isOk) - assert.Equal(m.T(), "abcdef", col.Name(nil)) + assert.Equal(m.T(), "abcdef", col.Name(m.ctx, nil)) for key := range evtData { if strings.Contains(key, constants.ArtiePrefix) { continue @@ -327,6 +327,6 @@ func (m *MySQLTestSuite) TestGetEventFromBytes() { col, isOk := cols.GetColumn(strings.ToLower(key)) assert.Equal(m.T(), true, isOk, key) - assert.Equal(m.T(), typing.Invalid, col.KindDetails, fmt.Sprintf("colName: %v, evtData key: %v", col.Name(nil), key)) + assert.Equal(m.T(), typing.Invalid, col.KindDetails, fmt.Sprintf("colName: %v, evtData key: %v", col.Name(m.ctx, nil), key)) } } diff --git a/lib/cdc/util/relational_event_test.go b/lib/cdc/util/relational_event_test.go index 3f6982eca..c4733ef10 100644 --- a/lib/cdc/util/relational_event_test.go +++ b/lib/cdc/util/relational_event_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestSource_GetOptionalSchema(t *testing.T) { +func (u *UtilTestSuite) TestSource_GetOptionalSchema() { ctx := context.Background() var schemaEventPayload SchemaEventPayload err := json.Unmarshal([]byte(`{ @@ -56,34 +56,34 @@ func TestSource_GetOptionalSchema(t *testing.T) { "payload": {} }`), &schemaEventPayload) - assert.NoError(t, err) + assert.NoError(u.T(), err) optionalSchema := schemaEventPayload.GetOptionalSchema(ctx) value, isOk := optionalSchema["last_modified"] - assert.True(t, isOk) - assert.Equal(t, value, typing.String) + assert.True(u.T(), isOk) + assert.Equal(u.T(), value, typing.String) cols := schemaEventPayload.GetColumns(ctx) - assert.Equal(t, 6, len(cols.GetColumns())) + assert.Equal(u.T(), 6, len(cols.GetColumns())) col, isOk := cols.GetColumn("boolean_column") - assert.True(t, isOk) + assert.True(u.T(), isOk) defaultVal, err := col.DefaultValue(nil) - assert.NoError(t, err) - assert.Equal(t, false, defaultVal) + assert.NoError(u.T(), err) + assert.Equal(u.T(), false, defaultVal) for _, _col := range cols.GetColumns() { // All the other columns do not have a default value. - if _col.Name(nil) != "boolean_column" { + if _col.Name(u.ctx, nil) != "boolean_column" { defaultVal, err = _col.DefaultValue(nil) - assert.NoError(t, err) - assert.Nil(t, defaultVal, _col.Name(nil)) + assert.NoError(u.T(), err) + assert.Nil(u.T(), defaultVal, _col.Name(u.ctx, nil)) } } // OptionalColumn does not pick up custom data types. _, isOk = optionalSchema["zoned_timestamp_column"] - assert.False(t, isOk) + assert.False(u.T(), isOk) } func TestSource_GetExecutionTime(t *testing.T) { diff --git a/lib/config/config.go b/lib/config/config.go index 1b1996a39..a44466e38 100644 --- a/lib/config/config.go +++ b/lib/config/config.go @@ -81,6 +81,10 @@ type Redshift struct { CredentialsClause string `yaml:"credentialsClause"` } +type SharedDestinationConfig struct { + UppercaseEscapedNames bool `yaml:"uppercaseEscapedNames"` +} + type Snowflake struct { AccountID string `yaml:"account"` Username string `yaml:"username"` @@ -128,6 +132,9 @@ type Config struct { Pubsub *Pubsub Kafka *Kafka + // Shared destination configuration + SharedDestinationConfig SharedDestinationConfig `yaml:"sharedDestinationConfig"` + // Supported destinations BigQuery *BigQuery `yaml:"bigquery"` Snowflake *Snowflake `yaml:"snowflake"` diff --git a/lib/dwh/ddl/ddl.go b/lib/dwh/ddl/ddl.go index 151a8ef52..cfe7aad9c 100644 --- a/lib/dwh/ddl/ddl.go +++ b/lib/dwh/ddl/ddl.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/config/constants" @@ -82,7 +84,7 @@ func AlterTable(ctx context.Context, args AlterTableArgs, cols ...columns.Column } if args.ColumnOp == constants.Delete { - if !args.Tc.ShouldDeleteColumn(ctx, col.Name(nil), args.CdcTime, args.ContainOtherOperations) { + if !args.Tc.ShouldDeleteColumn(ctx, col.Name(ctx, nil), args.CdcTime, args.ContainOtherOperations) { continue } } @@ -90,12 +92,12 @@ func AlterTable(ctx context.Context, args AlterTableArgs, cols ...columns.Column mutateCol = append(mutateCol, col) switch args.ColumnOp { case constants.Add: - colSQLParts = append(colSQLParts, fmt.Sprintf(`%s %s`, col.Name(&columns.NameArgs{ + colSQLParts = append(colSQLParts, fmt.Sprintf(`%s %s`, col.Name(ctx, &sql.NameArgs{ Escape: true, DestKind: args.Dwh.Label(), }), typing.KindToDWHType(col.KindDetails, args.Dwh.Label()))) case constants.Delete: - colSQLParts = append(colSQLParts, fmt.Sprintf(`%s`, col.Name(&columns.NameArgs{ + colSQLParts = append(colSQLParts, fmt.Sprintf(`%s`, col.Name(ctx, &sql.NameArgs{ Escape: true, DestKind: args.Dwh.Label(), }))) @@ -152,7 +154,7 @@ func AlterTable(ctx context.Context, args AlterTableArgs, cols ...columns.Column if err == nil { // createTable = false since it all successfully updated. - args.Tc.MutateInMemoryColumns(false, args.ColumnOp, mutateCol...) + args.Tc.MutateInMemoryColumns(ctx, false, args.ColumnOp, mutateCol...) } return nil diff --git a/lib/dwh/ddl/ddl_alter_delete_test.go b/lib/dwh/ddl/ddl_alter_delete_test.go index fed6e6667..25a74e858 100644 --- a/lib/dwh/ddl/ddl_alter_delete_test.go +++ b/lib/dwh/ddl/ddl_alter_delete_test.go @@ -29,9 +29,9 @@ func (d *DDLTestSuite) TestAlterDelete_Complete() { }, "tableName") originalColumnLength := len(cols.GetColumns()) - bqName := td.ToFqName(d.bqCtx, constants.BigQuery) - redshiftName := td.ToFqName(d.ctx, constants.Redshift) - snowflakeName := td.ToFqName(d.ctx, constants.Snowflake) + bqName := td.ToFqName(d.bqCtx, constants.BigQuery, true) + redshiftName := td.ToFqName(d.ctx, constants.Redshift, true) + snowflakeName := td.ToFqName(d.ctx, constants.Snowflake, true) // Testing 3 scenarios here // 1. DropDeletedColumns = false, ContainOtherOperations = true, don't delete ever. diff --git a/lib/dwh/ddl/ddl_bq_test.go b/lib/dwh/ddl/ddl_bq_test.go index 7c55dc3ae..347d312e1 100644 --- a/lib/dwh/ddl/ddl_bq_test.go +++ b/lib/dwh/ddl/ddl_bq_test.go @@ -6,6 +6,8 @@ import ( "fmt" "time" + artieSQL "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/typing/columns" "github.com/stretchr/testify/assert" @@ -40,8 +42,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() { cols.AddColumn(columns.NewColumn(colName, kindDetails)) } - fqName := td.ToFqName(d.bqCtx, constants.BigQuery) - + fqName := td.ToFqName(d.bqCtx, constants.BigQuery, true) originalColumnLength := len(cols.GetColumns()) d.bigQueryStore.GetConfigMap().AddTableToConfig(fqName, types.NewDwhTableConfig(&cols, nil, false, true)) tc := d.bigQueryStore.GetConfigMap().TableConfig(fqName) @@ -84,7 +85,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() { err := ddl.AlterTable(d.bqCtx, alterTableArgs, column) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.Name(&columns.NameArgs{ + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", fqName, column.Name(d.ctx, &artieSQL.NameArgs{ Escape: true, DestKind: d.bigQueryStore.Label(), })), query) @@ -144,7 +145,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() { err := ddl.AlterTable(d.bqCtx, alterTableArgs, col) assert.NoError(d.T(), err) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.Name(&columns.NameArgs{ + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, col.Name(d.ctx, &artieSQL.NameArgs{ Escape: true, DestKind: d.bigQueryStore.Label(), }), @@ -156,10 +157,10 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() { assert.Equal(d.T(), newColsLen+existingColsLen, len(d.bigQueryStore.GetConfigMap().TableConfig(fqName).Columns().GetColumns()), d.bigQueryStore.GetConfigMap().TableConfig(fqName).Columns()) // Check by iterating over the columns for _, column := range d.bigQueryStore.GetConfigMap().TableConfig(fqName).Columns().GetColumns() { - existingCol, isOk := existingCols.GetColumn(column.Name(nil)) + existingCol, isOk := existingCols.GetColumn(column.Name(d.ctx, nil)) if !isOk { // Check new cols? - existingCol.KindDetails, isOk = newCols[column.Name(nil)] + existingCol.KindDetails, isOk = newCols[column.Name(d.ctx, nil)] } assert.True(d.T(), isOk) @@ -204,7 +205,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() { err := ddl.AlterTable(d.bqCtx, alterTableArgs, column) assert.NoError(d.T(), err) query, _ := d.fakeBigQueryStore.ExecArgsForCall(callIdx) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.Name(&columns.NameArgs{ + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s %s COLUMN %s %s", fqName, constants.Add, column.Name(d.ctx, &artieSQL.NameArgs{ Escape: true, DestKind: d.bigQueryStore.Label(), }), @@ -216,7 +217,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() { assert.Equal(d.T(), existingColsLen, len(d.bigQueryStore.GetConfigMap().TableConfig(fqName).Columns().GetColumns()), d.bigQueryStore.GetConfigMap().TableConfig(fqName).Columns()) // Check by iterating over the columns for _, column := range d.bigQueryStore.GetConfigMap().TableConfig(fqName).Columns().GetColumns() { - existingCol, isOk := existingCols.GetColumn(column.Name(nil)) + existingCol, isOk := existingCols.GetColumn(column.Name(d.ctx, nil)) assert.True(d.T(), isOk) assert.Equal(d.T(), column.KindDetails, existingCol.KindDetails) } @@ -241,8 +242,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuerySafety() { cols.AddColumn(columns.NewColumn(colName, kindDetails)) } - fqName := td.ToFqName(d.bqCtx, constants.BigQuery) - + fqName := td.ToFqName(d.bqCtx, constants.BigQuery, true) originalColumnLength := len(columnNameToKindDetailsMap) d.bigQueryStore.GetConfigMap().AddTableToConfig(fqName, types.NewDwhTableConfig(&cols, nil, false, false)) tc := d.bigQueryStore.GetConfigMap().TableConfig(fqName) diff --git a/lib/dwh/ddl/ddl_sflk_test.go b/lib/dwh/ddl/ddl_sflk_test.go index eb76b8212..bf956b07f 100644 --- a/lib/dwh/ddl/ddl_sflk_test.go +++ b/lib/dwh/ddl/ddl_sflk_test.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/typing/columns" "github.com/stretchr/testify/assert" @@ -40,7 +42,7 @@ func (d *DDLTestSuite) TestAlterComplexObjects() { for i := 0; i < len(cols); i++ { execQuery, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i) - assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s add COLUMN %s %s", fqTable, cols[i].Name(&columns.NameArgs{ + assert.Equal(d.T(), fmt.Sprintf("ALTER TABLE %s add COLUMN %s %s", fqTable, cols[i].Name(d.ctx, &sql.NameArgs{ Escape: true, DestKind: d.snowflakeStagesStore.Label(), }), @@ -110,15 +112,15 @@ func (d *DDLTestSuite) TestAlterTableAdd() { for _, column := range tableConfig.Columns().GetColumns() { var found bool for _, expCol := range cols { - if found = column.Name(nil) == expCol.Name(nil); found { - assert.Equal(d.T(), column.KindDetails, expCol.KindDetails, fmt.Sprintf("wrong col kind, col: %s", column.Name(nil))) + if found = column.Name(d.ctx, nil) == expCol.Name(d.ctx, nil); found { + assert.Equal(d.T(), column.KindDetails, expCol.KindDetails, fmt.Sprintf("wrong col kind, col: %s", column.Name(d.ctx, nil))) break } } assert.True(d.T(), found, fmt.Sprintf("Col not found: %s, actual list: %v, expected list: %v", - column.Name(nil), tableConfig.Columns(), cols)) + column.Name(d.ctx, nil), tableConfig.Columns(), cols)) } } @@ -151,7 +153,7 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() { for col := range tableConfig.ReadOnlyColumnsToDelete() { var found bool for _, expCol := range cols { - if found = col == expCol.Name(nil); found { + if found = col == expCol.Name(d.ctx, nil); found { break } } @@ -162,7 +164,7 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() { } for i := 0; i < len(cols); i++ { - colToActuallyDelete := cols[i].Name(nil) + colToActuallyDelete := cols[i].Name(d.ctx, nil) // Now let's check the timestamp assert.True(d.T(), tableConfig.ReadOnlyColumnsToDelete()[colToActuallyDelete].After(time.Now())) // Now let's actually try to dial the time back, and it should actually try to delete. @@ -173,10 +175,11 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() { assert.Equal(d.T(), i+1, d.fakeSnowflakeStagesStore.ExecCallCount(), "tried to delete one column") execArg, _ := d.fakeSnowflakeStagesStore.ExecArgsForCall(i) - assert.Equal(d.T(), execArg, fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", fqTable, constants.Delete, cols[i].Name(&columns.NameArgs{ - Escape: true, - DestKind: d.snowflakeStagesStore.Label(), - }))) + assert.Equal(d.T(), execArg, fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", fqTable, constants.Delete, cols[i].Name(d.ctx, + &sql.NameArgs{ + Escape: true, + DestKind: d.snowflakeStagesStore.Label(), + }))) } } @@ -217,7 +220,7 @@ func (d *DDLTestSuite) TestAlterTableDelete() { for col := range tableConfig.ReadOnlyColumnsToDelete() { var found bool for _, expCol := range cols { - if found = col == expCol.Name(nil); found { + if found = col == expCol.Name(d.ctx, nil); found { break } } diff --git a/lib/dwh/ddl/ddl_suite_test.go b/lib/dwh/ddl/ddl_suite_test.go index 0bcf4207f..8bcd5326e 100644 --- a/lib/dwh/ddl/ddl_suite_test.go +++ b/lib/dwh/ddl/ddl_suite_test.go @@ -33,6 +33,7 @@ type DDLTestSuite struct { func (d *DDLTestSuite) SetupTest() { ctx := config.InjectSettingsIntoContext(context.Background(), &config.Settings{ VerboseLogging: true, + Config: &config.Config{}, }) bqCtx := config.InjectSettingsIntoContext(context.Background(), &config.Settings{ diff --git a/lib/dwh/dml/merge.go b/lib/dwh/dml/merge.go index 284eff230..b4899ff89 100644 --- a/lib/dwh/dml/merge.go +++ b/lib/dwh/dml/merge.go @@ -1,10 +1,13 @@ package dml import ( + "context" "errors" "fmt" "strings" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/stringutil" "github.com/artie-labs/transfer/lib/typing/columns" @@ -56,7 +59,7 @@ func (m *MergeArgument) Valid() error { return nil } -func MergeStatementParts(m *MergeArgument) ([]string, error) { +func MergeStatementParts(ctx context.Context, m *MergeArgument) ([]string, error) { if err := m.Valid(); err != nil { return nil, err } @@ -84,7 +87,7 @@ func MergeStatementParts(m *MergeArgument) ([]string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - cols := m.ColumnsToTypes.GetColumnsToUpdate(&columns.NameArgs{ + cols := m.ColumnsToTypes.GetColumnsToUpdate(ctx, &sql.NameArgs{ Escape: true, DestKind: m.DestKind, }) @@ -108,7 +111,7 @@ func MergeStatementParts(m *MergeArgument) ([]string, error) { // UPDATE fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s;`, // UPDATE table set col1 = cc. col1 - m.FqTableName, columns.ColumnsUpdateQuery(cols, m.ColumnsToTypes, m.DestKind), + m.FqTableName, columns.ColumnsUpdateQuery(ctx, cols, m.ColumnsToTypes, m.DestKind), // FROM table (temp) WHERE join on PK(s) m.SubQuery, strings.Join(equalitySQLParts, " and "), idempotentClause, ), @@ -152,7 +155,7 @@ func MergeStatementParts(m *MergeArgument) ([]string, error) { // UPDATE fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s AND COALESCE(cc.%s, false) = false;`, // UPDATE table set col1 = cc. col1 - m.FqTableName, columns.ColumnsUpdateQuery(cols, m.ColumnsToTypes, m.DestKind), + m.FqTableName, columns.ColumnsUpdateQuery(ctx, cols, m.ColumnsToTypes, m.DestKind), // FROM staging WHERE join on PK(s) m.SubQuery, strings.Join(equalitySQLParts, " and "), idempotentClause, constants.DeleteColumnMarker, ), @@ -170,7 +173,7 @@ func MergeStatementParts(m *MergeArgument) ([]string, error) { }, nil } -func MergeStatement(m *MergeArgument) (string, error) { +func MergeStatement(ctx context.Context, m *MergeArgument) (string, error) { if err := m.Valid(); err != nil { return "", err } @@ -209,7 +212,7 @@ func MergeStatement(m *MergeArgument) (string, error) { subQuery = m.SubQuery } - cols := m.ColumnsToTypes.GetColumnsToUpdate(&columns.NameArgs{ + cols := m.ColumnsToTypes.GetColumnsToUpdate(ctx, &sql.NameArgs{ Escape: true, DestKind: m.DestKind, }) @@ -229,7 +232,7 @@ func MergeStatement(m *MergeArgument) (string, error) { ); `, m.FqTableName, subQuery, strings.Join(equalitySQLParts, " and "), // Update + Soft Deletion - idempotentClause, columns.ColumnsUpdateQuery(cols, m.ColumnsToTypes, m.DestKind), + idempotentClause, columns.ColumnsUpdateQuery(ctx, cols, m.ColumnsToTypes, m.DestKind), // Insert constants.DeleteColumnMarker, strings.Join(cols, ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ @@ -270,7 +273,7 @@ func MergeStatement(m *MergeArgument) (string, error) { // Delete constants.DeleteColumnMarker, // Update - constants.DeleteColumnMarker, idempotentClause, columns.ColumnsUpdateQuery(cols, m.ColumnsToTypes, m.DestKind), + constants.DeleteColumnMarker, idempotentClause, columns.ColumnsUpdateQuery(ctx, cols, m.ColumnsToTypes, m.DestKind), // Insert constants.DeleteColumnMarker, strings.Join(cols, ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ diff --git a/lib/dwh/dml/merge_bigquery_test.go b/lib/dwh/dml/merge_bigquery_test.go index 3f364fa09..63d31350a 100644 --- a/lib/dwh/dml/merge_bigquery_test.go +++ b/lib/dwh/dml/merge_bigquery_test.go @@ -1,18 +1,13 @@ package dml import ( - "testing" - - "github.com/artie-labs/transfer/lib/typing/columns" - "github.com/artie-labs/transfer/lib/config/constants" - - "github.com/stretchr/testify/assert" - "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/columns" + "github.com/stretchr/testify/assert" ) -func TestMergeStatement_TempTable(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatement_TempTable() { var cols columns.Columns cols.AddColumn(columns.NewColumn("order_id", typing.Integer)) cols.AddColumn(columns.NewColumn("name", typing.String)) @@ -21,19 +16,19 @@ func TestMergeStatement_TempTable(t *testing.T) { mergeArg := &MergeArgument{ FqTableName: "customers.orders", SubQuery: "customers.orders_tmp", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_id", typing.Invalid), nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(m.ctx, columns.NewColumn("order_id", typing.Invalid), nil)}, ColumnsToTypes: cols, DestKind: constants.BigQuery, SoftDelete: false, } - mergeSQL, err := MergeStatement(mergeArg) - assert.NoError(t, err) + mergeSQL, err := MergeStatement(m.ctx, mergeArg) + assert.NoError(m.T(), err) - assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c using customers.orders_tmp as cc on c.order_id = cc.order_id", mergeSQL) + assert.Contains(m.T(), mergeSQL, "MERGE INTO customers.orders c using customers.orders_tmp as cc on c.order_id = cc.order_id", mergeSQL) } -func TestMergeStatement_JSONKey(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatement_JSONKey() { var cols columns.Columns cols.AddColumn(columns.NewColumn("order_oid", typing.Struct)) cols.AddColumn(columns.NewColumn("name", typing.String)) @@ -42,13 +37,13 @@ func TestMergeStatement_JSONKey(t *testing.T) { mergeArg := &MergeArgument{ FqTableName: "customers.orders", SubQuery: "customers.orders_tmp", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_oid", typing.Invalid), nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(m.ctx, columns.NewColumn("order_oid", typing.Invalid), nil)}, ColumnsToTypes: cols, DestKind: constants.BigQuery, SoftDelete: false, } - mergeSQL, err := MergeStatement(mergeArg) - assert.NoError(t, err) - assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c using customers.orders_tmp as cc on TO_JSON_STRING(c.order_oid) = TO_JSON_STRING(cc.order_oid)", mergeSQL) + mergeSQL, err := MergeStatement(m.ctx, mergeArg) + assert.NoError(m.T(), err) + assert.Contains(m.T(), mergeSQL, "MERGE INTO customers.orders c using customers.orders_tmp as cc on TO_JSON_STRING(c.order_oid) = TO_JSON_STRING(cc.order_oid)", mergeSQL) } diff --git a/lib/dwh/dml/merge_parts_test.go b/lib/dwh/dml/merge_parts_test.go index 0bd45fea3..0c40ec684 100644 --- a/lib/dwh/dml/merge_parts_test.go +++ b/lib/dwh/dml/merge_parts_test.go @@ -1,9 +1,10 @@ package dml import ( - "testing" + "context" "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" @@ -11,15 +12,15 @@ import ( "github.com/stretchr/testify/assert" ) -func TestMergeStatementPartsValidation(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatementPartsValidation() { for _, arg := range []*MergeArgument{ {DestKind: constants.Snowflake}, {DestKind: constants.SnowflakeStages}, {DestKind: constants.BigQuery}, } { - parts, err := MergeStatementParts(arg) - assert.Error(t, err) - assert.Nil(t, parts) + parts, err := MergeStatementParts(m.ctx, arg) + assert.Error(m.T(), err) + assert.Nil(m.T(), parts) } } @@ -31,7 +32,7 @@ type result struct { // getBasicColumnsForTest - will return you all the columns within `result` that are needed for tests. // * In here, we'll return if compositeKey=false - id (pk), email, first_name, last_name, created_at, toast_text (TOAST-able) // * Else if compositeKey=true - id(pk), email (pk), first_name, last_name, created_at, toast_text (TOAST-able) -func getBasicColumnsForTest(compositeKey bool) result { +func getBasicColumnsForTest(ctx context.Context, compositeKey bool) result { idCol := columns.NewColumn("id", typing.Float) emailCol := columns.NewColumn("email", typing.String) textToastCol := columns.NewColumn("toast_text", typing.String) @@ -47,13 +48,13 @@ func getBasicColumnsForTest(compositeKey bool) result { cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) var pks []columns.Wrapper - pks = append(pks, columns.NewWrapper(idCol, &columns.NameArgs{ + pks = append(pks, columns.NewWrapper(ctx, idCol, &sql.NameArgs{ Escape: true, DestKind: constants.Redshift, })) if compositeKey { - pks = append(pks, columns.NewWrapper(emailCol, &columns.NameArgs{ + pks = append(pks, columns.NewWrapper(ctx, emailCol, &sql.NameArgs{ Escape: true, DestKind: constants.Redshift, })) @@ -65,11 +66,11 @@ func getBasicColumnsForTest(compositeKey bool) result { } } -func TestMergeStatementPartsSoftDelete(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatementPartsSoftDelete() { fqTableName := "public.tableName" tempTableName := "public.tableName__temp" - res := getBasicColumnsForTest(false) - m := &MergeArgument{ + res := getBasicColumnsForTest(m.ctx, false) + mergeArg := &MergeArgument{ FqTableName: fqTableName, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, @@ -77,35 +78,34 @@ func TestMergeStatementPartsSoftDelete(t *testing.T) { DestKind: constants.Redshift, SoftDelete: true, } + parts, err := MergeStatementParts(m.ctx, mergeArg) + assert.NoError(m.T(), err) + assert.Equal(m.T(), 2, len(parts)) - parts, err := MergeStatementParts(m) - assert.NoError(t, err) - assert.Equal(t, 2, len(parts)) - - assert.Equal(t, + assert.Equal(m.T(), `INSERT INTO public.tableName (id,email,first_name,last_name,created_at,toast_text,__artie_delete) SELECT cc.id,cc.email,cc.first_name,cc.last_name,cc.created_at,cc.toast_text,cc.__artie_delete FROM public.tableName__temp as cc LEFT JOIN public.tableName as c on c.id = cc.id WHERE c.id IS NULL;`, parts[0]) - assert.Equal(t, + assert.Equal(m.T(), `UPDATE public.tableName as c SET id=cc.id,email=cc.email,first_name=cc.first_name,last_name=cc.last_name,created_at=cc.created_at,toast_text= CASE WHEN cc.toast_text != '__debezium_unavailable_value' THEN cc.toast_text ELSE c.toast_text END,__artie_delete=cc.__artie_delete FROM public.tableName__temp as cc WHERE c.id = cc.id;`, parts[1]) - m.IdempotentKey = "created_at" - parts, err = MergeStatementParts(m) + mergeArg.IdempotentKey = "created_at" + parts, err = MergeStatementParts(m.ctx, mergeArg) // Parts[0] for insertion should be identical - assert.Equal(t, + assert.Equal(m.T(), `INSERT INTO public.tableName (id,email,first_name,last_name,created_at,toast_text,__artie_delete) SELECT cc.id,cc.email,cc.first_name,cc.last_name,cc.created_at,cc.toast_text,cc.__artie_delete FROM public.tableName__temp as cc LEFT JOIN public.tableName as c on c.id = cc.id WHERE c.id IS NULL;`, parts[0]) // Parts[1] where we're doing UPDATES will have idempotency key. - assert.Equal(t, + assert.Equal(m.T(), `UPDATE public.tableName as c SET id=cc.id,email=cc.email,first_name=cc.first_name,last_name=cc.last_name,created_at=cc.created_at,toast_text= CASE WHEN cc.toast_text != '__debezium_unavailable_value' THEN cc.toast_text ELSE c.toast_text END,__artie_delete=cc.__artie_delete FROM public.tableName__temp as cc WHERE c.id = cc.id AND cc.created_at >= c.created_at;`, parts[1]) } -func TestMergeStatementPartsSoftDeleteComposite(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatementPartsSoftDeleteComposite() { fqTableName := "public.tableName" tempTableName := "public.tableName__temp" - res := getBasicColumnsForTest(true) - m := &MergeArgument{ + res := getBasicColumnsForTest(m.ctx, true) + mergeArg := &MergeArgument{ FqTableName: fqTableName, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, @@ -114,37 +114,37 @@ func TestMergeStatementPartsSoftDeleteComposite(t *testing.T) { SoftDelete: true, } - parts, err := MergeStatementParts(m) - assert.NoError(t, err) - assert.Equal(t, 2, len(parts)) + parts, err := MergeStatementParts(m.ctx, mergeArg) + assert.NoError(m.T(), err) + assert.Equal(m.T(), 2, len(parts)) - assert.Equal(t, + assert.Equal(m.T(), `INSERT INTO public.tableName (id,email,first_name,last_name,created_at,toast_text,__artie_delete) SELECT cc.id,cc.email,cc.first_name,cc.last_name,cc.created_at,cc.toast_text,cc.__artie_delete FROM public.tableName__temp as cc LEFT JOIN public.tableName as c on c.id = cc.id and c.email = cc.email WHERE c.id IS NULL;`, parts[0]) - assert.Equal(t, + assert.Equal(m.T(), `UPDATE public.tableName as c SET id=cc.id,email=cc.email,first_name=cc.first_name,last_name=cc.last_name,created_at=cc.created_at,toast_text= CASE WHEN cc.toast_text != '__debezium_unavailable_value' THEN cc.toast_text ELSE c.toast_text END,__artie_delete=cc.__artie_delete FROM public.tableName__temp as cc WHERE c.id = cc.id and c.email = cc.email;`, parts[1]) - m.IdempotentKey = "created_at" - parts, err = MergeStatementParts(m) + mergeArg.IdempotentKey = "created_at" + parts, err = MergeStatementParts(m.ctx, mergeArg) // Parts[0] for insertion should be identical - assert.Equal(t, + assert.Equal(m.T(), `INSERT INTO public.tableName (id,email,first_name,last_name,created_at,toast_text,__artie_delete) SELECT cc.id,cc.email,cc.first_name,cc.last_name,cc.created_at,cc.toast_text,cc.__artie_delete FROM public.tableName__temp as cc LEFT JOIN public.tableName as c on c.id = cc.id and c.email = cc.email WHERE c.id IS NULL;`, parts[0]) // Parts[1] where we're doing UPDATES will have idempotency key. - assert.Equal(t, + assert.Equal(m.T(), `UPDATE public.tableName as c SET id=cc.id,email=cc.email,first_name=cc.first_name,last_name=cc.last_name,created_at=cc.created_at,toast_text= CASE WHEN cc.toast_text != '__debezium_unavailable_value' THEN cc.toast_text ELSE c.toast_text END,__artie_delete=cc.__artie_delete FROM public.tableName__temp as cc WHERE c.id = cc.id and c.email = cc.email AND cc.created_at >= c.created_at;`, parts[1]) } -func TestMergeStatementParts(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatementParts() { // Biggest difference with this test are: // 1. We are not saving `__artie_deleted` column // 2. There are 3 SQL queries (INSERT, UPDATE and DELETE) fqTableName := "public.tableName" tempTableName := "public.tableName__temp" - res := getBasicColumnsForTest(false) - m := &MergeArgument{ + res := getBasicColumnsForTest(m.ctx, false) + mergeArg := &MergeArgument{ FqTableName: fqTableName, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, @@ -152,23 +152,23 @@ func TestMergeStatementParts(t *testing.T) { DestKind: constants.Redshift, } - parts, err := MergeStatementParts(m) - assert.NoError(t, err) - assert.Equal(t, 3, len(parts)) + parts, err := MergeStatementParts(m.ctx, mergeArg) + assert.NoError(m.T(), err) + assert.Equal(m.T(), 3, len(parts)) - assert.Equal(t, + assert.Equal(m.T(), `INSERT INTO public.tableName (id,email,first_name,last_name,created_at,toast_text) SELECT cc.id,cc.email,cc.first_name,cc.last_name,cc.created_at,cc.toast_text FROM public.tableName__temp as cc LEFT JOIN public.tableName as c on c.id = cc.id WHERE c.id IS NULL;`, parts[0]) - assert.Equal(t, + assert.Equal(m.T(), `UPDATE public.tableName as c SET id=cc.id,email=cc.email,first_name=cc.first_name,last_name=cc.last_name,created_at=cc.created_at,toast_text= CASE WHEN cc.toast_text != '__debezium_unavailable_value' THEN cc.toast_text ELSE c.toast_text END FROM public.tableName__temp as cc WHERE c.id = cc.id AND COALESCE(cc.__artie_delete, false) = false;`, parts[1]) - assert.Equal(t, + assert.Equal(m.T(), `DELETE FROM public.tableName WHERE (id) IN (SELECT cc.id FROM public.tableName__temp as cc WHERE cc.__artie_delete = true);`, parts[2]) - m = &MergeArgument{ + mergeArg = &MergeArgument{ FqTableName: fqTableName, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, @@ -177,28 +177,28 @@ func TestMergeStatementParts(t *testing.T) { IdempotentKey: "created_at", } - parts, err = MergeStatementParts(m) - assert.NoError(t, err) - assert.Equal(t, 3, len(parts)) + parts, err = MergeStatementParts(m.ctx, mergeArg) + assert.NoError(m.T(), err) + assert.Equal(m.T(), 3, len(parts)) - assert.Equal(t, + assert.Equal(m.T(), `INSERT INTO public.tableName (id,email,first_name,last_name,created_at,toast_text) SELECT cc.id,cc.email,cc.first_name,cc.last_name,cc.created_at,cc.toast_text FROM public.tableName__temp as cc LEFT JOIN public.tableName as c on c.id = cc.id WHERE c.id IS NULL;`, parts[0]) - assert.Equal(t, + assert.Equal(m.T(), `UPDATE public.tableName as c SET id=cc.id,email=cc.email,first_name=cc.first_name,last_name=cc.last_name,created_at=cc.created_at,toast_text= CASE WHEN cc.toast_text != '__debezium_unavailable_value' THEN cc.toast_text ELSE c.toast_text END FROM public.tableName__temp as cc WHERE c.id = cc.id AND cc.created_at >= c.created_at AND COALESCE(cc.__artie_delete, false) = false;`, parts[1]) - assert.Equal(t, + assert.Equal(m.T(), `DELETE FROM public.tableName WHERE (id) IN (SELECT cc.id FROM public.tableName__temp as cc WHERE cc.__artie_delete = true);`, parts[2]) } -func TestMergeStatementPartsCompositeKey(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatementPartsCompositeKey() { fqTableName := "public.tableName" tempTableName := "public.tableName__temp" - res := getBasicColumnsForTest(true) - m := &MergeArgument{ + res := getBasicColumnsForTest(m.ctx, true) + mergeArg := &MergeArgument{ FqTableName: fqTableName, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, @@ -206,23 +206,23 @@ func TestMergeStatementPartsCompositeKey(t *testing.T) { DestKind: constants.Redshift, } - parts, err := MergeStatementParts(m) - assert.NoError(t, err) - assert.Equal(t, 3, len(parts)) + parts, err := MergeStatementParts(m.ctx, mergeArg) + assert.NoError(m.T(), err) + assert.Equal(m.T(), 3, len(parts)) - assert.Equal(t, + assert.Equal(m.T(), `INSERT INTO public.tableName (id,email,first_name,last_name,created_at,toast_text) SELECT cc.id,cc.email,cc.first_name,cc.last_name,cc.created_at,cc.toast_text FROM public.tableName__temp as cc LEFT JOIN public.tableName as c on c.id = cc.id and c.email = cc.email WHERE c.id IS NULL;`, parts[0]) - assert.Equal(t, + assert.Equal(m.T(), `UPDATE public.tableName as c SET id=cc.id,email=cc.email,first_name=cc.first_name,last_name=cc.last_name,created_at=cc.created_at,toast_text= CASE WHEN cc.toast_text != '__debezium_unavailable_value' THEN cc.toast_text ELSE c.toast_text END FROM public.tableName__temp as cc WHERE c.id = cc.id and c.email = cc.email AND COALESCE(cc.__artie_delete, false) = false;`, parts[1]) - assert.Equal(t, + assert.Equal(m.T(), `DELETE FROM public.tableName WHERE (id,email) IN (SELECT cc.id,cc.email FROM public.tableName__temp as cc WHERE cc.__artie_delete = true);`, parts[2]) - m = &MergeArgument{ + mergeArg = &MergeArgument{ FqTableName: fqTableName, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, @@ -231,19 +231,19 @@ func TestMergeStatementPartsCompositeKey(t *testing.T) { IdempotentKey: "created_at", } - parts, err = MergeStatementParts(m) - assert.NoError(t, err) - assert.Equal(t, 3, len(parts)) + parts, err = MergeStatementParts(m.ctx, mergeArg) + assert.NoError(m.T(), err) + assert.Equal(m.T(), 3, len(parts)) - assert.Equal(t, + assert.Equal(m.T(), `INSERT INTO public.tableName (id,email,first_name,last_name,created_at,toast_text) SELECT cc.id,cc.email,cc.first_name,cc.last_name,cc.created_at,cc.toast_text FROM public.tableName__temp as cc LEFT JOIN public.tableName as c on c.id = cc.id and c.email = cc.email WHERE c.id IS NULL;`, parts[0]) - assert.Equal(t, + assert.Equal(m.T(), `UPDATE public.tableName as c SET id=cc.id,email=cc.email,first_name=cc.first_name,last_name=cc.last_name,created_at=cc.created_at,toast_text= CASE WHEN cc.toast_text != '__debezium_unavailable_value' THEN cc.toast_text ELSE c.toast_text END FROM public.tableName__temp as cc WHERE c.id = cc.id and c.email = cc.email AND cc.created_at >= c.created_at AND COALESCE(cc.__artie_delete, false) = false;`, parts[1]) - assert.Equal(t, + assert.Equal(m.T(), `DELETE FROM public.tableName WHERE (id,email) IN (SELECT cc.id,cc.email FROM public.tableName__temp as cc WHERE cc.__artie_delete = true);`, parts[2]) } diff --git a/lib/dwh/dml/merge_suite_test.go b/lib/dwh/dml/merge_suite_test.go new file mode 100644 index 000000000..f0c2b9887 --- /dev/null +++ b/lib/dwh/dml/merge_suite_test.go @@ -0,0 +1,23 @@ +package dml + +import ( + "context" + "testing" + + "github.com/artie-labs/transfer/lib/config" + "github.com/stretchr/testify/suite" +) + +type MergeTestSuite struct { + suite.Suite + ctx context.Context +} + +func (m *MergeTestSuite) SetupTest() { + m.ctx = context.Background() + m.ctx = config.InjectSettingsIntoContext(m.ctx, &config.Settings{Config: &config.Config{}}) +} + +func TestMergeTestSuite(t *testing.T) { + suite.Run(t, new(MergeTestSuite)) +} diff --git a/lib/dwh/dml/merge_test.go b/lib/dwh/dml/merge_test.go index bbf000c16..3e4639afe 100644 --- a/lib/dwh/dml/merge_test.go +++ b/lib/dwh/dml/merge_test.go @@ -3,9 +3,10 @@ package dml import ( "fmt" "strings" - "testing" "time" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/config/constants" @@ -13,7 +14,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestMergeStatementSoftDelete(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatementSoftDelete() { // No idempotent key fqTable := "database.schema.table" cols := []string{ @@ -38,26 +39,26 @@ func TestMergeStatementSoftDelete(t *testing.T) { _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) for _, idempotentKey := range []string{"", "updated_at"} { - mergeSQL, err := MergeStatement(&MergeArgument{ + mergeSQL, err := MergeStatement(m.ctx, &MergeArgument{ FqTableName: fqTable, SubQuery: subQuery, IdempotentKey: idempotentKey, - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(m.ctx, columns.NewColumn("id", typing.Invalid), nil)}, ColumnsToTypes: _cols, DestKind: constants.Snowflake, SoftDelete: true, }) - assert.NoError(t, err) - assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) + assert.NoError(m.T(), err) + assert.True(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) // Soft deletion flag being passed. - assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("%s=cc.%s", constants.DeleteColumnMarker, constants.DeleteColumnMarker)), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("%s=cc.%s", constants.DeleteColumnMarker, constants.DeleteColumnMarker)), mergeSQL) - assert.Equal(t, len(idempotentKey) > 0, strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at"))) + assert.Equal(m.T(), len(idempotentKey) > 0, strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at"))) } } -func TestMergeStatement(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatement() { // No idempotent key fqTable := "database.schema.table" colToTypes := map[string]typing.KindDetails{ @@ -84,29 +85,29 @@ func TestMergeStatement(t *testing.T) { // select cc.foo, cc.bar from (values (12, 34), (44, 55)) as cc(foo, bar); subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) - mergeSQL, err := MergeStatement(&MergeArgument{ + mergeSQL, err := MergeStatement(m.ctx, &MergeArgument{ FqTableName: fqTable, SubQuery: subQuery, IdempotentKey: "", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(m.ctx, columns.NewColumn("id", typing.Invalid), nil)}, ColumnsToTypes: _cols, DestKind: constants.Snowflake, SoftDelete: false, }) - assert.NoError(t, err) - assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) - assert.False(t, strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) + assert.NoError(m.T(), err) + assert.True(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) + assert.False(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) // Check primary keys clause - assert.True(t, strings.Contains(mergeSQL, "as cc on c.id = cc.id"), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, "as cc on c.id = cc.id"), mergeSQL) // Check setting for update - assert.True(t, strings.Contains(mergeSQL, `SET id=cc.id,bar=cc.bar,updated_at=cc.updated_at,"start"=cc."start"`), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, `SET id=cc.id,bar=cc.bar,updated_at=cc.updated_at,"start"=cc."start"`), mergeSQL) // Check for INSERT - assert.True(t, strings.Contains(mergeSQL, `id,bar,updated_at,"start"`), mergeSQL) - assert.True(t, strings.Contains(mergeSQL, `cc.id,cc.bar,cc.updated_at,cc."start"`), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, `id,bar,updated_at,"start"`), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, `cc.id,cc.bar,cc.updated_at,cc."start"`), mergeSQL) } -func TestMergeStatementIdempotentKey(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatementIdempotentKey() { fqTable := "database.schema.table" cols := []string{ "id", @@ -129,21 +130,21 @@ func TestMergeStatementIdempotentKey(t *testing.T) { _cols.AddColumn(columns.NewColumn("id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - mergeSQL, err := MergeStatement(&MergeArgument{ + mergeSQL, err := MergeStatement(m.ctx, &MergeArgument{ FqTableName: fqTable, SubQuery: subQuery, IdempotentKey: "updated_at", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(m.ctx, columns.NewColumn("id", typing.Invalid), nil)}, ColumnsToTypes: _cols, DestKind: constants.Snowflake, SoftDelete: false, }) - assert.NoError(t, err) - assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) - assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) + assert.NoError(m.T(), err) + assert.True(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) } -func TestMergeStatementCompositeKey(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatementCompositeKey() { fqTable := "database.schema.table" cols := []string{ "id", @@ -168,24 +169,23 @@ func TestMergeStatementCompositeKey(t *testing.T) { _cols.AddColumn(columns.NewColumn("another_id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - mergeSQL, err := MergeStatement(&MergeArgument{ + mergeSQL, err := MergeStatement(m.ctx, &MergeArgument{ FqTableName: fqTable, SubQuery: subQuery, IdempotentKey: "updated_at", - PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), nil), - columns.NewWrapper(columns.NewColumn("another_id", typing.Invalid), nil)}, + PrimaryKeys: []columns.Wrapper{columns.NewWrapper(m.ctx, columns.NewColumn("id", typing.Invalid), nil), + columns.NewWrapper(m.ctx, columns.NewColumn("another_id", typing.Invalid), nil)}, ColumnsToTypes: _cols, DestKind: constants.Snowflake, SoftDelete: false, }) - assert.NoError(t, err) - assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) - assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) - - assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("cc on c.id = cc.id and c.another_id = cc.another_id"))) + assert.NoError(m.T(), err) + assert.True(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) + assert.True(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("cc on c.id = cc.id and c.another_id = cc.another_id"))) } -func TestMergeStatementEscapePrimaryKeys(t *testing.T) { +func (m *MergeTestSuite) TestMergeStatementEscapePrimaryKeys() { // No idempotent key fqTable := "database.schema.table" colToTypes := map[string]typing.KindDetails{ @@ -212,16 +212,16 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) { // select cc.foo, cc.bar from (values (12, 34), (44, 55)) as cc(foo, bar); subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) - mergeSQL, err := MergeStatement(&MergeArgument{ + mergeSQL, err := MergeStatement(m.ctx, &MergeArgument{ FqTableName: fqTable, SubQuery: subQuery, IdempotentKey: "", PrimaryKeys: []columns.Wrapper{ - columns.NewWrapper(columns.NewColumn("id", typing.Invalid), &columns.NameArgs{ + columns.NewWrapper(m.ctx, columns.NewColumn("id", typing.Invalid), &sql.NameArgs{ Escape: true, DestKind: constants.Snowflake, }), - columns.NewWrapper(columns.NewColumn("group", typing.Invalid), &columns.NameArgs{ + columns.NewWrapper(m.ctx, columns.NewColumn("group", typing.Invalid), &sql.NameArgs{ Escape: true, DestKind: constants.Snowflake, }), @@ -230,15 +230,15 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) { DestKind: constants.Snowflake, SoftDelete: false, }) - assert.NoError(t, err) - assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) - assert.False(t, strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) + assert.NoError(m.T(), err) + assert.True(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) + assert.False(m.T(), strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) // Check primary keys clause - assert.True(t, strings.Contains(mergeSQL, `as cc on c.id = cc.id and c."group" = cc."group"`), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, `as cc on c.id = cc.id and c."group" = cc."group"`), mergeSQL) // Check setting for update - assert.True(t, strings.Contains(mergeSQL, `SET id=cc.id,"group"=cc."group",updated_at=cc.updated_at,"start"=cc."start"`), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, `SET id=cc.id,"group"=cc."group",updated_at=cc.updated_at,"start"=cc."start"`), mergeSQL) // Check for INSERT - assert.True(t, strings.Contains(mergeSQL, `id,"group",updated_at,"start"`), mergeSQL) - assert.True(t, strings.Contains(mergeSQL, `cc.id,cc."group",cc.updated_at,cc."start"`), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, `id,"group",updated_at,"start"`), mergeSQL) + assert.True(m.T(), strings.Contains(mergeSQL, `cc.id,cc."group",cc.updated_at,cc."start"`), mergeSQL) } diff --git a/lib/dwh/dml/merge_valid_test.go b/lib/dwh/dml/merge_valid_test.go index 4fcc630c2..f97088f46 100644 --- a/lib/dwh/dml/merge_valid_test.go +++ b/lib/dwh/dml/merge_valid_test.go @@ -1,15 +1,13 @@ package dml import ( - "testing" - "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" "github.com/stretchr/testify/assert" ) -func TestMergeArgument_Valid(t *testing.T) { +func (m *MergeTestSuite) TestMergeArgument_Valid() { type _testCase struct { name string mergeArg *MergeArgument @@ -18,7 +16,7 @@ func TestMergeArgument_Valid(t *testing.T) { } primaryKeys := []columns.Wrapper{ - columns.NewWrapper(columns.NewColumn("id", typing.Integer), nil), + columns.NewWrapper(m.ctx, columns.NewColumn("id", typing.Integer), nil), } var cols columns.Columns @@ -89,10 +87,10 @@ func TestMergeArgument_Valid(t *testing.T) { for _, testCase := range testCases { actualErr := testCase.mergeArg.Valid() if testCase.expectedError { - assert.Error(t, actualErr, testCase.name) - assert.Equal(t, testCase.expectErrorMessage, actualErr.Error(), testCase.name) + assert.Error(m.T(), actualErr, testCase.name) + assert.Equal(m.T(), testCase.expectErrorMessage, actualErr.Error(), testCase.name) } else { - assert.NoError(t, actualErr, testCase.name) + assert.NoError(m.T(), actualErr, testCase.name) } } } diff --git a/lib/dwh/types/table_config.go b/lib/dwh/types/table_config.go index 1d929c199..53ee6c7c1 100644 --- a/lib/dwh/types/table_config.go +++ b/lib/dwh/types/table_config.go @@ -56,7 +56,7 @@ func (d *DwhTableConfig) Columns() *columns.Columns { return d.columns } -func (d *DwhTableConfig) MutateInMemoryColumns(createTable bool, columnOp constants.ColumnOperation, cols ...columns.Column) { +func (d *DwhTableConfig) MutateInMemoryColumns(ctx context.Context, createTable bool, columnOp constants.ColumnOperation, cols ...columns.Column) { d.Lock() defer d.Unlock() switch columnOp { @@ -64,15 +64,15 @@ func (d *DwhTableConfig) MutateInMemoryColumns(createTable bool, columnOp consta for _, col := range cols { d.columns.AddColumn(col) // Delete from the permissions table, if exists. - delete(d.columnsToDelete, col.Name(nil)) + delete(d.columnsToDelete, col.Name(ctx, nil)) } d.createTable = createTable case constants.Delete: for _, col := range cols { // Delete from the permissions and in-memory table - d.columns.DeleteColumn(col.Name(nil)) - delete(d.columnsToDelete, col.Name(nil)) + d.columns.DeleteColumn(col.Name(ctx, nil)) + delete(d.columnsToDelete, col.Name(ctx, nil)) } } } diff --git a/lib/dwh/types/table_config_test.go b/lib/dwh/types/table_config_test.go index ab1bd569e..7c588fefc 100644 --- a/lib/dwh/types/table_config_test.go +++ b/lib/dwh/types/table_config_test.go @@ -18,7 +18,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestDwhTableConfig_ShouldDeleteColumn(t *testing.T) { +func (t *TypesTestSuite) TestDwhTableConfig_ShouldDeleteColumn() { ctx := config.InjectSettingsIntoContext(context.Background(), &config.Settings{ VerboseLogging: false, }) @@ -28,27 +28,27 @@ func TestDwhTableConfig_ShouldDeleteColumn(t *testing.T) { dwhTableConfig := NewDwhTableConfig(&columns.Columns{}, nil, false, false) for i := 0; i < 100; i++ { results := dwhTableConfig.ShouldDeleteColumn(ctx, "hello", time.Now().UTC(), true) - assert.False(t, results) - assert.Equal(t, len(dwhTableConfig.ReadOnlyColumnsToDelete()), 0) + assert.False(t.T(), results) + assert.Equal(t.T(), len(dwhTableConfig.ReadOnlyColumnsToDelete()), 0) } // 2. DropDeletedColumns = true and ContainsOtherOperations = false, so don't delete dwhTableConfig = NewDwhTableConfig(&columns.Columns{}, nil, false, true) for i := 0; i < 100; i++ { results := dwhTableConfig.ShouldDeleteColumn(ctx, "hello", time.Now().UTC(), false) - assert.False(t, results) - assert.Equal(t, len(dwhTableConfig.ReadOnlyColumnsToDelete()), 0) + assert.False(t.T(), results) + assert.Equal(t.T(), len(dwhTableConfig.ReadOnlyColumnsToDelete()), 0) } // 3. DropDeletedColumns = true and ContainsOtherOperations = true, now check CDC time to delete. dwhTableConfig = NewDwhTableConfig(&columns.Columns{}, nil, false, true) for i := 0; i < 100; i++ { results := dwhTableConfig.ShouldDeleteColumn(ctx, "hello", time.Now().UTC(), true) - assert.False(t, results) - assert.Equal(t, len(dwhTableConfig.ReadOnlyColumnsToDelete()), 1) + assert.False(t.T(), results) + assert.Equal(t.T(), len(dwhTableConfig.ReadOnlyColumnsToDelete()), 1) } - assert.True(t, dwhTableConfig.ShouldDeleteColumn(ctx, "hello", time.Now().UTC().Add(2*constants.DeletionConfidencePadding), true)) + assert.True(t.T(), dwhTableConfig.ShouldDeleteColumn(ctx, "hello", time.Now().UTC().Add(2*constants.DeletionConfidencePadding), true)) } // TestDwhTableConfig_ColumnsConcurrency this file is meant to test the concurrency methods of .Columns() @@ -82,19 +82,19 @@ func TestDwhTableConfig_ColumnsConcurrency(t *testing.T) { wg.Wait() } -func TestDwhTableConfig_MutateInMemoryColumns(t *testing.T) { +func (t *TypesTestSuite) TestDwhTableConfig_MutateInMemoryColumns() { tc := NewDwhTableConfig(&columns.Columns{}, nil, false, false) for _, col := range []string{"a", "b", "c", "d", "e"} { - tc.MutateInMemoryColumns(false, constants.Add, columns.NewColumn(col, typing.String)) + tc.MutateInMemoryColumns(t.ctx, false, constants.Add, columns.NewColumn(col, typing.String)) } - assert.Equal(t, 5, len(tc.columns.GetColumns())) + assert.Equal(t.T(), 5, len(tc.columns.GetColumns())) var wg sync.WaitGroup for _, addCol := range []string{"aa", "bb", "cc", "dd", "ee", "ff"} { wg.Add(1) go func(colName string) { defer wg.Done() - tc.MutateInMemoryColumns(false, constants.Add, columns.NewColumn(colName, typing.String)) + tc.MutateInMemoryColumns(t.ctx, false, constants.Add, columns.NewColumn(colName, typing.String)) }(addCol) } @@ -102,15 +102,15 @@ func TestDwhTableConfig_MutateInMemoryColumns(t *testing.T) { wg.Add(1) go func(colName string) { defer wg.Done() - tc.MutateInMemoryColumns(false, constants.Delete, columns.NewColumn(colName, typing.Invalid)) + tc.MutateInMemoryColumns(t.ctx, false, constants.Delete, columns.NewColumn(colName, typing.Invalid)) }(removeCol) } wg.Wait() - assert.Equal(t, 6, len(tc.columns.GetColumns())) + assert.Equal(t.T(), 6, len(tc.columns.GetColumns())) } -func TestDwhTableConfig_ReadOnlyColumnsToDelete(t *testing.T) { +func (t *TypesTestSuite) TestDwhTableConfig_ReadOnlyColumnsToDelete() { colsToDelete := make(map[string]time.Time) for _, colToDelete := range []string{"a", "b", "c", "d"} { colsToDelete[colToDelete] = time.Now() @@ -124,13 +124,13 @@ func TestDwhTableConfig_ReadOnlyColumnsToDelete(t *testing.T) { time.Sleep(time.Duration(jitter.JitterMs(50, 1)) * time.Millisecond) defer wg.Done() actualColsToDelete := tc.ReadOnlyColumnsToDelete() - assert.Equal(t, colsToDelete, actualColsToDelete) + assert.Equal(t.T(), colsToDelete, actualColsToDelete) }() } wg.Wait() } -func TestDwhTableConfig_ClearColumnsToDeleteByColName(t *testing.T) { +func (t *TypesTestSuite) TestDwhTableConfig_ClearColumnsToDeleteByColName() { colsToDelete := make(map[string]time.Time) for _, colToDelete := range []string{"a", "b", "c", "d"} { colsToDelete[colToDelete] = time.Now() @@ -138,7 +138,7 @@ func TestDwhTableConfig_ClearColumnsToDeleteByColName(t *testing.T) { tc := NewDwhTableConfig(nil, colsToDelete, false, false) var wg sync.WaitGroup - assert.Equal(t, 4, len(tc.columnsToDelete)) + assert.Equal(t.T(), 4, len(tc.columnsToDelete)) for _, colToDelete := range []string{"a", "b", "c", "d"} { for i := 0; i < 100; i++ { wg.Add(1) @@ -151,5 +151,5 @@ func TestDwhTableConfig_ClearColumnsToDeleteByColName(t *testing.T) { } wg.Wait() - assert.Equal(t, 0, len(tc.columnsToDelete)) + assert.Equal(t.T(), 0, len(tc.columnsToDelete)) } diff --git a/lib/dwh/types/types_suite_test.go b/lib/dwh/types/types_suite_test.go new file mode 100644 index 000000000..22269fcb0 --- /dev/null +++ b/lib/dwh/types/types_suite_test.go @@ -0,0 +1,24 @@ +package types + +import ( + "context" + "testing" + + "github.com/artie-labs/transfer/lib/config" + + "github.com/stretchr/testify/suite" +) + +type TypesTestSuite struct { + suite.Suite + ctx context.Context +} + +func (t *TypesTestSuite) SetupTest() { + t.ctx = context.Background() + t.ctx = config.InjectSettingsIntoContext(t.ctx, &config.Settings{Config: &config.Config{}}) +} + +func TestTypesTestSuite(t *testing.T) { + suite.Run(t, new(TypesTestSuite)) +} diff --git a/lib/dwh/types/types_test.go b/lib/dwh/types/types_test.go index 2a4112ec4..eaa3559b1 100644 --- a/lib/dwh/types/types_test.go +++ b/lib/dwh/types/types_test.go @@ -2,7 +2,6 @@ package types import ( "sync" - "testing" "time" "github.com/artie-labs/transfer/lib/typing/columns" @@ -31,17 +30,17 @@ func generateDwhTableCfg() *DwhTableConfig { } } -func TestDwhToTablesConfigMap_TableConfigBasic(t *testing.T) { +func (t *TypesTestSuite) TestDwhToTablesConfigMap_TableConfigBasic() { dwh := &DwhToTablesConfigMap{} dwhTableConfig := generateDwhTableCfg() fqName := "database.schema.tableName" dwh.AddTableToConfig(fqName, dwhTableConfig) - assert.Equal(t, *dwhTableConfig, *dwh.TableConfig(fqName)) + assert.Equal(t.T(), *dwhTableConfig, *dwh.TableConfig(fqName)) } // TestDwhToTablesConfigMap_Concurrency - has a bunch of concurrent go-routines that are rapidly adding and reading from the tableConfig. -func TestDwhToTablesConfigMap_Concurrency(t *testing.T) { +func (t *TypesTestSuite) TestDwhToTablesConfigMap_Concurrency() { dwh := &DwhToTablesConfigMap{} fqName := "db.schema.table" dwhTableCfg := generateDwhTableCfg() @@ -63,7 +62,7 @@ func TestDwhToTablesConfigMap_Concurrency(t *testing.T) { defer wg.Done() for i := 0; i < 1000; i++ { time.Sleep(time.Duration(jitter.JitterMs(5, 1)) * time.Millisecond) - assert.Equal(t, *dwhTableCfg, *dwh.TableConfig(fqName)) + assert.Equal(t.T(), *dwhTableCfg, *dwh.TableConfig(fqName)) } }() diff --git a/lib/optimization/event.go b/lib/optimization/event.go index f31a4adfd..37a8c48e8 100644 --- a/lib/optimization/event.go +++ b/lib/optimization/event.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/stringutil" @@ -49,18 +51,18 @@ func (t *TableData) ContainOtherOperations() bool { return t.containOtherOperations } -func (t *TableData) PrimaryKeys(args *columns.NameArgs) []columns.Wrapper { +func (t *TableData) PrimaryKeys(ctx context.Context, args *sql.NameArgs) []columns.Wrapper { var primaryKeysEscaped []columns.Wrapper for _, pk := range t.primaryKeys { col := columns.NewColumn(pk, typing.Invalid) - primaryKeysEscaped = append(primaryKeysEscaped, columns.NewWrapper(col, args)) + primaryKeysEscaped = append(primaryKeysEscaped, columns.NewWrapper(ctx, col, args)) } return primaryKeysEscaped } -func (t *TableData) Name() string { - return t.name +func (t *TableData) Name(ctx context.Context, args *sql.NameArgs) string { + return sql.EscapeName(ctx, t.name, args) } func (t *TableData) SetInMemoryColumns(columns *columns.Columns) { @@ -142,17 +144,26 @@ func (t *TableData) RowsData() map[string]map[string]interface{} { return _rowsData } -func (t *TableData) ToFqName(ctx context.Context, kind constants.DestinationKind) string { +func (t *TableData) ToFqName(ctx context.Context, kind constants.DestinationKind, escape bool) string { switch kind { case constants.Redshift: // Redshift is Postgres compatible, so when establishing a connection, we'll specify a database. // Thus, we only need to specify schema and table name here. - return fmt.Sprintf("%s.%s", t.TopicConfig.Schema, t.Name()) + return fmt.Sprintf("%s.%s", t.TopicConfig.Schema, t.Name(ctx, &sql.NameArgs{ + Escape: escape, + DestKind: kind, + })) case constants.BigQuery: // The fully qualified name for BigQuery is: project_id.dataset.tableName. - return fmt.Sprintf("%s.%s.%s", config.FromContext(ctx).Config.BigQuery.ProjectID, t.TopicConfig.Database, t.Name()) + return fmt.Sprintf("%s.%s.%s", config.FromContext(ctx).Config.BigQuery.ProjectID, t.TopicConfig.Database, t.Name(ctx, &sql.NameArgs{ + Escape: escape, + DestKind: kind, + })) default: - return fmt.Sprintf("%s.%s.%s", t.TopicConfig.Database, t.TopicConfig.Schema, t.Name()) + return fmt.Sprintf("%s.%s.%s", t.TopicConfig.Database, t.TopicConfig.Schema, t.Name(ctx, &sql.NameArgs{ + Escape: escape, + DestKind: kind, + })) } } @@ -179,7 +190,7 @@ func (t *TableData) ShouldFlush(ctx context.Context) bool { // Prior to merging, we will need to treat `tableConfig` as the source-of-truth and whenever there's discrepancies // We will prioritize using the values coming from (2) TableConfig. We also cannot simply do a replacement, as we have in-memory columns // That carry metadata for Artie Transfer. They are prefixed with __artie. -func (t *TableData) UpdateInMemoryColumnsFromDestination(cols ...columns.Column) { +func (t *TableData) UpdateInMemoryColumnsFromDestination(ctx context.Context, cols ...columns.Column) { if t == nil { return } @@ -188,7 +199,7 @@ func (t *TableData) UpdateInMemoryColumnsFromDestination(cols ...columns.Column) var foundColumn columns.Column var found bool for _, col := range cols { - if col.Name(nil) == strings.ToLower(inMemoryCol.Name(nil)) { + if col.Name(ctx, nil) == strings.ToLower(inMemoryCol.Name(ctx, nil)) { foundColumn = col found = true break diff --git a/lib/optimization/event_test.go b/lib/optimization/event_test.go index 5931999e5..b9f4fe98d 100644 --- a/lib/optimization/event_test.go +++ b/lib/optimization/event_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNewTableData_TableName(t *testing.T) { +func (o *OptimizationTestSuite) TestNewTableData_TableName() { type _testCase struct { name string tableName string @@ -71,16 +71,16 @@ func TestNewTableData_TableName(t *testing.T) { TableName: testCase.overrideName, Schema: testCase.schema, }, testCase.tableName) - assert.Equal(t, testCase.expectedName, td.Name(), testCase.name) - assert.Equal(t, testCase.expectedName, td.name, testCase.name) - assert.Equal(t, testCase.expectedSnowflakeFqName, td.ToFqName(ctx, constants.SnowflakeStages)) - assert.Equal(t, testCase.expectedSnowflakeFqName, td.ToFqName(ctx, constants.Snowflake)) - assert.Equal(t, testCase.expectedBigQueryFqName, td.ToFqName(ctx, constants.BigQuery)) - assert.Equal(t, testCase.expectedBigQueryFqName, td.ToFqName(ctx, constants.BigQuery)) + assert.Equal(o.T(), testCase.expectedName, td.Name(o.ctx, nil), testCase.name) + assert.Equal(o.T(), testCase.expectedName, td.name, testCase.name) + assert.Equal(o.T(), testCase.expectedSnowflakeFqName, td.ToFqName(ctx, constants.SnowflakeStages, true)) + assert.Equal(o.T(), testCase.expectedSnowflakeFqName, td.ToFqName(ctx, constants.Snowflake, true)) + assert.Equal(o.T(), testCase.expectedBigQueryFqName, td.ToFqName(ctx, constants.BigQuery, true)) + assert.Equal(o.T(), testCase.expectedBigQueryFqName, td.ToFqName(ctx, constants.BigQuery, true)) } } -func TestTableData_ReadOnlyInMemoryCols(t *testing.T) { +func (o *OptimizationTestSuite) TestTableData_ReadOnlyInMemoryCols() { // Making sure the columns are actually read only. var cols columns.Columns cols.AddColumn(columns.NewColumn("name", typing.String)) @@ -91,13 +91,13 @@ func TestTableData_ReadOnlyInMemoryCols(t *testing.T) { // Check if last_name actually exists. _, isOk := td.ReadOnlyInMemoryCols().GetColumn("last_name") - assert.False(t, isOk) + assert.False(o.T(), isOk) // Check length is 1. - assert.Equal(t, 1, len(td.ReadOnlyInMemoryCols().GetColumns())) + assert.Equal(o.T(), 1, len(td.ReadOnlyInMemoryCols().GetColumns())) } -func TestTableData_UpdateInMemoryColumns(t *testing.T) { +func (o *OptimizationTestSuite) TestTableData_UpdateInMemoryColumns() { var _cols columns.Columns for colName, colKind := range map[string]typing.KindDetails{ "FOO": typing.String, @@ -113,10 +113,10 @@ func TestTableData_UpdateInMemoryColumns(t *testing.T) { } extCol, isOk := tableData.ReadOnlyInMemoryCols().GetColumn("do_not_change_format") - assert.True(t, isOk) + assert.True(o.T(), isOk) extCol.KindDetails.ExtendedTimeDetails.Format = time.RFC3339Nano - tableData.inMemoryColumns.UpdateColumn(columns.NewColumn(extCol.Name(nil), extCol.KindDetails)) + tableData.inMemoryColumns.UpdateColumn(columns.NewColumn(extCol.Name(o.ctx, nil), extCol.KindDetails)) for name, colKindDetails := range map[string]typing.KindDetails{ "foo": typing.String, @@ -124,30 +124,30 @@ func TestTableData_UpdateInMemoryColumns(t *testing.T) { "bar": typing.Boolean, "do_not_change_format": typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateTimeKindType), } { - tableData.UpdateInMemoryColumnsFromDestination(columns.NewColumn(name, colKindDetails)) + tableData.UpdateInMemoryColumnsFromDestination(o.ctx, columns.NewColumn(name, colKindDetails)) } // It's saved back in the original format. _, isOk = tableData.ReadOnlyInMemoryCols().GetColumn("foo") - assert.False(t, isOk) + assert.False(o.T(), isOk) _, isOk = tableData.ReadOnlyInMemoryCols().GetColumn("FOO") - assert.True(t, isOk) + assert.True(o.T(), isOk) col, isOk := tableData.ReadOnlyInMemoryCols().GetColumn("CHANGE_me") - assert.True(t, isOk) - assert.Equal(t, ext.DateTime.Type, col.KindDetails.ExtendedTimeDetails.Type) + assert.True(o.T(), isOk) + assert.Equal(o.T(), ext.DateTime.Type, col.KindDetails.ExtendedTimeDetails.Type) // It went from invalid to boolean. col, isOk = tableData.ReadOnlyInMemoryCols().GetColumn("bar") - assert.True(t, isOk) - assert.Equal(t, typing.Boolean, col.KindDetails) + assert.True(o.T(), isOk) + assert.Equal(o.T(), typing.Boolean, col.KindDetails) col, isOk = tableData.ReadOnlyInMemoryCols().GetColumn("do_not_change_format") - assert.True(t, isOk) - assert.Equal(t, col.KindDetails.Kind, typing.ETime.Kind) - assert.Equal(t, col.KindDetails.ExtendedTimeDetails.Type, ext.DateTimeKindType, "correctly mapped type") - assert.Equal(t, col.KindDetails.ExtendedTimeDetails.Format, time.RFC3339Nano, "format has been preserved") + assert.True(o.T(), isOk) + assert.Equal(o.T(), col.KindDetails.Kind, typing.ETime.Kind) + assert.Equal(o.T(), col.KindDetails.ExtendedTimeDetails.Type, ext.DateTimeKindType, "correctly mapped type") + assert.Equal(o.T(), col.KindDetails.ExtendedTimeDetails.Format, time.RFC3339Nano, "format has been preserved") } func TestTableData_ShouldFlushRowLength(t *testing.T) { diff --git a/lib/optimization/event_update_test.go b/lib/optimization/event_update_test.go index 4b7fa7d2a..f4236fb2a 100644 --- a/lib/optimization/event_update_test.go +++ b/lib/optimization/event_update_test.go @@ -1,15 +1,13 @@ package optimization import ( - "testing" - "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/typing/ext" "github.com/stretchr/testify/assert" ) -func TestTableData_UpdateInMemoryColumnsFromDestination(t *testing.T) { +func (o *OptimizationTestSuite) TestTableData_UpdateInMemoryColumnsFromDestination() { tableDataCols := &columns.Columns{} tableDataCols.AddColumn(columns.NewColumn("name", typing.String)) tableDataCols.AddColumn(columns.NewColumn("bool_backfill", typing.Boolean)) @@ -30,57 +28,57 @@ func TestTableData_UpdateInMemoryColumnsFromDestination(t *testing.T) { } // Testing to make sure we don't copy over non-existent columns - tableData.UpdateInMemoryColumnsFromDestination(nonExistentCols...) + tableData.UpdateInMemoryColumnsFromDestination(o.ctx, nonExistentCols...) for _, nonExistentTableCol := range nonExistentTableCols { _, isOk := tableData.inMemoryColumns.GetColumn(nonExistentTableCol) - assert.False(t, isOk, nonExistentTableCol) + assert.False(o.T(), isOk, nonExistentTableCol) } // Testing to make sure we're copying the kindDetails over. - tableData.UpdateInMemoryColumnsFromDestination(columns.NewColumn("prev_invalid", typing.String)) + tableData.UpdateInMemoryColumnsFromDestination(o.ctx, columns.NewColumn("prev_invalid", typing.String)) prevInvalidCol, isOk := tableData.inMemoryColumns.GetColumn("prev_invalid") - assert.True(t, isOk) - assert.Equal(t, typing.String, prevInvalidCol.KindDetails) + assert.True(o.T(), isOk) + assert.Equal(o.T(), typing.String, prevInvalidCol.KindDetails) // Testing backfill for _, inMemoryCol := range tableData.inMemoryColumns.GetColumns() { - assert.False(t, inMemoryCol.Backfilled(), inMemoryCol.Name(nil)) + assert.False(o.T(), inMemoryCol.Backfilled(), inMemoryCol.Name(o.ctx, nil)) } backfilledCol := columns.NewColumn("bool_backfill", typing.Boolean) backfilledCol.SetBackfilled(true) - tableData.UpdateInMemoryColumnsFromDestination(backfilledCol) + tableData.UpdateInMemoryColumnsFromDestination(o.ctx, backfilledCol) for _, inMemoryCol := range tableData.inMemoryColumns.GetColumns() { - if inMemoryCol.Name(nil) == backfilledCol.Name(nil) { - assert.True(t, inMemoryCol.Backfilled(), inMemoryCol.Name(nil)) + if inMemoryCol.Name(o.ctx, nil) == backfilledCol.Name(o.ctx, nil) { + assert.True(o.T(), inMemoryCol.Backfilled(), inMemoryCol.Name(o.ctx, nil)) } else { - assert.False(t, inMemoryCol.Backfilled(), inMemoryCol.Name(nil)) + assert.False(o.T(), inMemoryCol.Backfilled(), inMemoryCol.Name(o.ctx, nil)) } } // Testing extTimeDetails for _, extTimeDetailsCol := range []string{"ext_date", "ext_time", "ext_datetime"} { col, isOk := tableData.inMemoryColumns.GetColumn(extTimeDetailsCol) - assert.True(t, isOk, extTimeDetailsCol) - assert.Equal(t, typing.String, col.KindDetails, extTimeDetailsCol) - assert.Nil(t, col.KindDetails.ExtendedTimeDetails, extTimeDetailsCol) + assert.True(o.T(), isOk, extTimeDetailsCol) + assert.Equal(o.T(), typing.String, col.KindDetails, extTimeDetailsCol) + assert.Nil(o.T(), col.KindDetails.ExtendedTimeDetails, extTimeDetailsCol) } - tableData.UpdateInMemoryColumnsFromDestination(columns.NewColumn("ext_date", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType))) - tableData.UpdateInMemoryColumnsFromDestination(columns.NewColumn("ext_time", typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimeKindType))) - tableData.UpdateInMemoryColumnsFromDestination(columns.NewColumn("ext_datetime", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateTimeKindType))) + tableData.UpdateInMemoryColumnsFromDestination(o.ctx, columns.NewColumn("ext_date", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType))) + tableData.UpdateInMemoryColumnsFromDestination(o.ctx, columns.NewColumn("ext_time", typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimeKindType))) + tableData.UpdateInMemoryColumnsFromDestination(o.ctx, columns.NewColumn("ext_datetime", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateTimeKindType))) dateCol, isOk := tableData.inMemoryColumns.GetColumn("ext_date") - assert.True(t, isOk) - assert.NotNil(t, dateCol.KindDetails.ExtendedTimeDetails) - assert.Equal(t, ext.DateKindType, dateCol.KindDetails.ExtendedTimeDetails.Type) + assert.True(o.T(), isOk) + assert.NotNil(o.T(), dateCol.KindDetails.ExtendedTimeDetails) + assert.Equal(o.T(), ext.DateKindType, dateCol.KindDetails.ExtendedTimeDetails.Type) timeCol, isOk := tableData.inMemoryColumns.GetColumn("ext_time") - assert.True(t, isOk) - assert.NotNil(t, timeCol.KindDetails.ExtendedTimeDetails) - assert.Equal(t, ext.TimeKindType, timeCol.KindDetails.ExtendedTimeDetails.Type) + assert.True(o.T(), isOk) + assert.NotNil(o.T(), timeCol.KindDetails.ExtendedTimeDetails) + assert.Equal(o.T(), ext.TimeKindType, timeCol.KindDetails.ExtendedTimeDetails.Type) dateTimeCol, isOk := tableData.inMemoryColumns.GetColumn("ext_datetime") - assert.True(t, isOk) - assert.NotNil(t, dateTimeCol.KindDetails.ExtendedTimeDetails) - assert.Equal(t, ext.DateTimeKindType, dateTimeCol.KindDetails.ExtendedTimeDetails.Type) + assert.True(o.T(), isOk) + assert.NotNil(o.T(), dateTimeCol.KindDetails.ExtendedTimeDetails) + assert.Equal(o.T(), ext.DateTimeKindType, dateTimeCol.KindDetails.ExtendedTimeDetails.Type) } diff --git a/lib/optimization/optimization_suite_test.go b/lib/optimization/optimization_suite_test.go new file mode 100644 index 000000000..cb4e7268d --- /dev/null +++ b/lib/optimization/optimization_suite_test.go @@ -0,0 +1,24 @@ +package optimization + +import ( + "context" + "testing" + + "github.com/artie-labs/transfer/lib/config" + + "github.com/stretchr/testify/suite" +) + +type OptimizationTestSuite struct { + suite.Suite + ctx context.Context +} + +func (o *OptimizationTestSuite) SetupTest() { + o.ctx = context.Background() + o.ctx = config.InjectSettingsIntoContext(o.ctx, &config.Settings{Config: &config.Config{}}) +} + +func TestOptimizationTestSuite(t *testing.T) { + suite.Run(t, new(OptimizationTestSuite)) +} diff --git a/lib/sql/escape.go b/lib/sql/escape.go new file mode 100644 index 000000000..2b71748dc --- /dev/null +++ b/lib/sql/escape.go @@ -0,0 +1,40 @@ +package sql + +import ( + "context" + "fmt" + "strings" + + "github.com/artie-labs/transfer/lib/config" + + "github.com/artie-labs/transfer/lib/array" + "github.com/artie-labs/transfer/lib/config/constants" +) + +type NameArgs struct { + Escape bool + DestKind constants.DestinationKind +} + +func EscapeName(ctx context.Context, name string, args *NameArgs) string { + var escape bool + if args != nil { + escape = args.Escape + } + + if escape && array.StringContains(constants.ReservedKeywords, name) { + if config.FromContext(ctx).Config.SharedDestinationConfig.UppercaseEscapedNames { + name = strings.ToUpper(name) + } + + if args != nil && args.DestKind == constants.BigQuery { + // BigQuery needs backticks to escape. + return fmt.Sprintf("`%s`", name) + } else { + // Snowflake uses quotes. + return fmt.Sprintf(`"%s"`, name) + } + } + + return name +} diff --git a/lib/sql/escape_test.go b/lib/sql/escape_test.go new file mode 100644 index 000000000..416dfad08 --- /dev/null +++ b/lib/sql/escape_test.go @@ -0,0 +1,102 @@ +package sql + +import ( + "github.com/artie-labs/transfer/lib/config" + "github.com/artie-labs/transfer/lib/config/constants" + "github.com/stretchr/testify/assert" +) + +func (s *SqlTestSuite) TestEscapeName() { + type _testCase struct { + name string + nameToEscape string + args *NameArgs + expectedName string + expectedNameWhenUpperCfg string + } + + testCases := []_testCase{ + { + name: "args = nil", + nameToEscape: "order", + expectedName: "order", + expectedNameWhenUpperCfg: "order", + }, + { + name: "escape = false", + args: &NameArgs{}, + nameToEscape: "order", + expectedName: "order", + expectedNameWhenUpperCfg: "order", + }, + { + name: "escape = true, snowflake", + args: &NameArgs{ + Escape: true, + DestKind: constants.Snowflake, + }, + nameToEscape: "order", + expectedName: `"order"`, + expectedNameWhenUpperCfg: `"ORDER"`, + }, + { + name: "escape = true, snowflake #2", + args: &NameArgs{ + Escape: true, + DestKind: constants.Snowflake, + }, + nameToEscape: "hello", + expectedName: `hello`, + expectedNameWhenUpperCfg: "hello", + }, + { + name: "escape = true, redshift", + args: &NameArgs{ + Escape: true, + DestKind: constants.Redshift, + }, + nameToEscape: "order", + expectedName: `"order"`, + expectedNameWhenUpperCfg: `"ORDER"`, + }, + { + name: "escape = true, redshift #2", + args: &NameArgs{ + Escape: true, + DestKind: constants.Redshift, + }, + nameToEscape: "hello", + expectedName: `hello`, + expectedNameWhenUpperCfg: "hello", + }, + { + name: "escape = true, bigquery", + args: &NameArgs{ + Escape: true, + DestKind: constants.BigQuery, + }, + nameToEscape: "order", + expectedName: "`order`", + expectedNameWhenUpperCfg: "`ORDER`", + }, + { + name: "escape = true, bigquery, #2", + args: &NameArgs{ + Escape: true, + DestKind: constants.BigQuery, + }, + nameToEscape: "hello", + expectedName: "hello", + expectedNameWhenUpperCfg: "hello", + }, + } + + for _, testCase := range testCases { + actualName := EscapeName(s.ctx, testCase.nameToEscape, testCase.args) + assert.Equal(s.T(), testCase.expectedName, actualName, testCase.name) + + upperCtx := config.InjectSettingsIntoContext(s.ctx, &config.Settings{Config: &config.Config{SharedDestinationConfig: config.SharedDestinationConfig{UppercaseEscapedNames: true}}}) + actualUpperName := EscapeName(upperCtx, testCase.nameToEscape, testCase.args) + assert.Equal(s.T(), testCase.expectedNameWhenUpperCfg, actualUpperName, testCase.name) + } +} diff --git a/lib/sql/sql_suite_test.go b/lib/sql/sql_suite_test.go new file mode 100644 index 000000000..90d782bf8 --- /dev/null +++ b/lib/sql/sql_suite_test.go @@ -0,0 +1,24 @@ +package sql + +import ( + "context" + "testing" + + "github.com/artie-labs/transfer/lib/config" + + "github.com/stretchr/testify/suite" +) + +type SqlTestSuite struct { + suite.Suite + ctx context.Context +} + +func (s *SqlTestSuite) SetupTest() { + s.ctx = context.Background() + s.ctx = config.InjectSettingsIntoContext(s.ctx, &config.Settings{Config: &config.Config{}}) +} + +func TestSqlTestSuite(t *testing.T) { + suite.Run(t, new(SqlTestSuite)) +} diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index af16e8838..3674d8cdc 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -1,16 +1,15 @@ package columns import ( + "context" "fmt" "strings" "sync" + "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/stringutil" - "github.com/artie-labs/transfer/lib/typing" - - "github.com/artie-labs/transfer/lib/array" - "github.com/artie-labs/transfer/lib/config/constants" ) // EscapeName - will lowercase columns and escape spaces. @@ -88,31 +87,11 @@ func (c *Column) ShouldBackfill() bool { return c.defaultValue != nil && c.backfilled == false } -type NameArgs struct { - Escape bool - DestKind constants.DestinationKind -} - // Name will give you c.name // However, if you pass in escape, we will escape if the column name is part of the reserved words from destinations. // If so, it'll change from `start` => `"start"` as suggested by Snowflake. -func (c *Column) Name(args *NameArgs) string { - var escape bool - if args != nil { - escape = args.Escape - } - - if escape && array.StringContains(constants.ReservedKeywords, c.name) { - if args != nil && args.DestKind == constants.BigQuery { - // BigQuery needs backticks to escape. - return fmt.Sprintf("`%s`", c.name) - } else { - // Snowflake uses quotes. - return fmt.Sprintf(`"%s"`, c.name) - } - } - - return c.name +func (c *Column) Name(ctx context.Context, args *sql.NameArgs) string { + return sql.EscapeName(ctx, c.name, args) } type Columns struct { @@ -120,9 +99,9 @@ type Columns struct { sync.RWMutex } -func (c *Columns) EscapeName(args *NameArgs) { +func (c *Columns) EscapeName(ctx context.Context, args *sql.NameArgs) { for idx := range c.columns { - c.columns[idx].name = c.columns[idx].Name(args) + c.columns[idx].name = c.columns[idx].Name(ctx, args) } return @@ -210,7 +189,7 @@ func (c *Columns) GetColumn(name string) (Column, bool) { // GetColumnsToUpdate will filter all the `Invalid` columns so that we do not update it. // It also has an option to escape the returned columns or not. This is used mostly for the SQL MERGE queries. -func (c *Columns) GetColumnsToUpdate(args *NameArgs) []string { +func (c *Columns) GetColumnsToUpdate(ctx context.Context, args *sql.NameArgs) []string { if c == nil { return []string{} } @@ -224,7 +203,7 @@ func (c *Columns) GetColumnsToUpdate(args *NameArgs) []string { continue } - cols = append(cols, col.Name(args)) + cols = append(cols, col.Name(ctx, args)) } return cols @@ -276,8 +255,8 @@ func (c *Columns) DeleteColumn(name string) { // columnsToTypes - given that list, provide the types (separate list because this list may contain invalid columns // bigQueryTypeCasting - We'll need to escape the column comparison if the column's a struct. // It then returns a list of strings like: cc.first_name=c.first_name,cc.last_name=c.last_name,cc.email=c.email -func ColumnsUpdateQuery(columns []string, columnsToTypes Columns, destKind constants.DestinationKind) string { - columnsToTypes.EscapeName(&NameArgs{ +func ColumnsUpdateQuery(ctx context.Context, columns []string, columnsToTypes Columns, destKind constants.DestinationKind) string { + columnsToTypes.EscapeName(ctx, &sql.NameArgs{ Escape: true, DestKind: destKind, }) diff --git a/lib/typing/columns/columns_suite_test.go b/lib/typing/columns/columns_suite_test.go new file mode 100644 index 000000000..acb7877d8 --- /dev/null +++ b/lib/typing/columns/columns_suite_test.go @@ -0,0 +1,25 @@ +package columns + +import ( + "context" + "testing" + + "github.com/artie-labs/transfer/lib/config" + "github.com/stretchr/testify/suite" +) + +type ColumnsTestSuite struct { + suite.Suite + ctx context.Context +} + +func (c *ColumnsTestSuite) SetupTest() { + c.ctx = config.InjectSettingsIntoContext(context.Background(), &config.Settings{ + VerboseLogging: false, + Config: &config.Config{}, + }) +} + +func TestColumnsTestSuite(t *testing.T) { + suite.Run(t, new(ColumnsTestSuite)) +} diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index 5ef519ee1..0f9c4a90e 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -2,18 +2,16 @@ package columns import ( "fmt" - "testing" - "github.com/artie-labs/transfer/lib/ptr" - - "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/config/constants" - + "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/typing" "github.com/stretchr/testify/assert" ) -func TestEscapeName(t *testing.T) { +func (c *ColumnsTestSuite) TestEscapeName() { type _testCase struct { name string expectedName string @@ -36,11 +34,11 @@ func TestEscapeName(t *testing.T) { for _, testCase := range testCases { actualName := EscapeName(testCase.name) - assert.Equal(t, testCase.expectedName, actualName, testCase.name) + assert.Equal(c.T(), testCase.expectedName, actualName, testCase.name) } } -func TestColumn_ShouldSkip(t *testing.T) { +func (c *ColumnsTestSuite) TestColumn_ShouldSkip() { type _testCase struct { name string col *Column @@ -68,11 +66,11 @@ func TestColumn_ShouldSkip(t *testing.T) { } for _, testCase := range testCases { - assert.Equal(t, testCase.expectedResult, testCase.col.ShouldSkip(), testCase.name) + assert.Equal(c.T(), testCase.expectedResult, testCase.col.ShouldSkip(), testCase.name) } } -func TestColumn_ShouldBackfill(t *testing.T) { +func (c *ColumnsTestSuite) TestColumn_ShouldBackfill() { type _testCase struct { name string column *Column @@ -130,11 +128,11 @@ func TestColumn_ShouldBackfill(t *testing.T) { } for _, testCase := range testCases { - assert.Equal(t, testCase.expectShouldBackfill, testCase.column.ShouldBackfill(), testCase.name) + assert.Equal(c.T(), testCase.expectShouldBackfill, testCase.column.ShouldBackfill(), testCase.name) } } -func TestUnescapeColumnName(t *testing.T) { +func (c *ColumnsTestSuite) TestUnescapeColumnName() { type _testCase struct { escapedName string expectedBigQueryName string @@ -159,14 +157,14 @@ func TestUnescapeColumnName(t *testing.T) { } for _, testCase := range testCases { - assert.Equal(t, testCase.expectedBigQueryName, UnescapeColumnName(testCase.escapedName, constants.BigQuery)) - assert.Equal(t, testCase.expectedSnowflakeName, UnescapeColumnName(testCase.escapedName, constants.Snowflake)) - assert.Equal(t, testCase.expectedSnowflakeName, UnescapeColumnName(testCase.escapedName, constants.SnowflakeStages)) - assert.Equal(t, testCase.expectedOtherName, UnescapeColumnName(testCase.escapedName, "")) + assert.Equal(c.T(), testCase.expectedBigQueryName, UnescapeColumnName(testCase.escapedName, constants.BigQuery)) + assert.Equal(c.T(), testCase.expectedSnowflakeName, UnescapeColumnName(testCase.escapedName, constants.Snowflake)) + assert.Equal(c.T(), testCase.expectedSnowflakeName, UnescapeColumnName(testCase.escapedName, constants.SnowflakeStages)) + assert.Equal(c.T(), testCase.expectedOtherName, UnescapeColumnName(testCase.escapedName, "")) } } -func TestColumn_Name(t *testing.T) { +func (c *ColumnsTestSuite) TestColumn_Name() { type _testCase struct { colName string expectedName string @@ -198,27 +196,27 @@ func TestColumn_Name(t *testing.T) { } for _, testCase := range testCases { - c := &Column{ + col := &Column{ name: testCase.colName, } - assert.Equal(t, testCase.expectedName, c.Name(nil), testCase.colName) - assert.Equal(t, testCase.expectedName, c.Name(&NameArgs{ + assert.Equal(c.T(), testCase.expectedName, col.Name(c.ctx, nil), testCase.colName) + assert.Equal(c.T(), testCase.expectedName, col.Name(c.ctx, &sql.NameArgs{ Escape: false, }), testCase.colName) - assert.Equal(t, testCase.expectedNameEsc, c.Name(&NameArgs{ + assert.Equal(c.T(), testCase.expectedNameEsc, col.Name(c.ctx, &sql.NameArgs{ Escape: true, DestKind: constants.Snowflake, }), testCase.colName) - assert.Equal(t, testCase.expectedNameEscBq, c.Name(&NameArgs{ + assert.Equal(c.T(), testCase.expectedNameEscBq, col.Name(c.ctx, &sql.NameArgs{ Escape: true, DestKind: constants.BigQuery, }), testCase.colName) } } -func TestColumns_GetColumnsToUpdate(t *testing.T) { +func (c *ColumnsTestSuite) TestColumns_GetColumnsToUpdate() { type _testCase struct { name string cols []Column @@ -274,24 +272,24 @@ func TestColumns_GetColumnsToUpdate(t *testing.T) { columns: testCase.cols, } - assert.Equal(t, testCase.expectedCols, cols.GetColumnsToUpdate(nil), testCase.name) - assert.Equal(t, testCase.expectedCols, cols.GetColumnsToUpdate(&NameArgs{ + assert.Equal(c.T(), testCase.expectedCols, cols.GetColumnsToUpdate(c.ctx, nil), testCase.name) + assert.Equal(c.T(), testCase.expectedCols, cols.GetColumnsToUpdate(c.ctx, &sql.NameArgs{ Escape: false, }), testCase.name) - assert.Equal(t, testCase.expectedColsEsc, cols.GetColumnsToUpdate(&NameArgs{ + assert.Equal(c.T(), testCase.expectedColsEsc, cols.GetColumnsToUpdate(c.ctx, &sql.NameArgs{ Escape: true, DestKind: constants.Snowflake, }), testCase.name) - assert.Equal(t, testCase.expectedColsEscBq, cols.GetColumnsToUpdate(&NameArgs{ + assert.Equal(c.T(), testCase.expectedColsEscBq, cols.GetColumnsToUpdate(c.ctx, &sql.NameArgs{ Escape: true, DestKind: constants.BigQuery, }), testCase.name) } } -func TestColumns_UpsertColumns(t *testing.T) { +func (c *ColumnsTestSuite) TestColumns_UpsertColumns() { keys := []string{"a", "b", "c", "d", "e"} var cols Columns for _, key := range keys { @@ -303,7 +301,7 @@ func TestColumns_UpsertColumns(t *testing.T) { // Now inspect prior to change. for _, col := range cols.GetColumns() { - assert.False(t, col.ToastColumn) + assert.False(c.T(), col.ToastColumn) } // Now selectively update only a, b @@ -314,43 +312,43 @@ func TestColumns_UpsertColumns(t *testing.T) { // Now inspect. col, _ := cols.GetColumn(key) - assert.True(t, col.ToastColumn) + assert.True(c.T(), col.ToastColumn) } cols.UpsertColumn("zzz", UpsertColumnArg{}) zzzCol, _ := cols.GetColumn("zzz") - assert.False(t, zzzCol.ToastColumn) - assert.False(t, zzzCol.primaryKey) - assert.Equal(t, zzzCol.KindDetails, typing.Invalid) + assert.False(c.T(), zzzCol.ToastColumn) + assert.False(c.T(), zzzCol.primaryKey) + assert.Equal(c.T(), zzzCol.KindDetails, typing.Invalid) cols.UpsertColumn("aaa", UpsertColumnArg{ ToastCol: ptr.ToBool(true), PrimaryKey: ptr.ToBool(true), }) aaaCol, _ := cols.GetColumn("aaa") - assert.True(t, aaaCol.ToastColumn) - assert.True(t, aaaCol.primaryKey) - assert.Equal(t, aaaCol.KindDetails, typing.Invalid) + assert.True(c.T(), aaaCol.ToastColumn) + assert.True(c.T(), aaaCol.primaryKey) + assert.Equal(c.T(), aaaCol.KindDetails, typing.Invalid) length := len(cols.columns) for i := 0; i < 500; i++ { cols.UpsertColumn("", UpsertColumnArg{}) } - assert.Equal(t, length, len(cols.columns)) + assert.Equal(c.T(), length, len(cols.columns)) } -func TestColumns_Add_Duplicate(t *testing.T) { +func (c *ColumnsTestSuite) TestColumns_Add_Duplicate() { var cols Columns duplicateColumns := []Column{{name: "foo"}, {name: "foo"}, {name: "foo"}, {name: "foo"}, {name: "foo"}, {name: "foo"}} for _, duplicateColumn := range duplicateColumns { cols.AddColumn(duplicateColumn) } - assert.Equal(t, len(cols.GetColumns()), 1, "AddColumn() de-duplicates") + assert.Equal(c.T(), len(cols.GetColumns()), 1, "AddColumn() de-duplicates") } -func TestColumns_Mutation(t *testing.T) { +func (c *ColumnsTestSuite) TestColumns_Mutation() { var cols Columns colsToAdd := []Column{{name: "foo", KindDetails: typing.String, defaultValue: "bar"}, {name: "bar", KindDetails: typing.Struct}} // Insert @@ -358,14 +356,14 @@ func TestColumns_Mutation(t *testing.T) { cols.AddColumn(colToAdd) } - assert.Equal(t, len(cols.GetColumns()), 2) + assert.Equal(c.T(), len(cols.GetColumns()), 2) fooCol, isOk := cols.GetColumn("foo") - assert.True(t, isOk) - assert.Equal(t, typing.String, fooCol.KindDetails) + assert.True(c.T(), isOk) + assert.Equal(c.T(), typing.String, fooCol.KindDetails) barCol, isOk := cols.GetColumn("bar") - assert.True(t, isOk) - assert.Equal(t, typing.Struct, barCol.KindDetails) + assert.True(c.T(), isOk) + assert.Equal(c.T(), typing.Struct, barCol.KindDetails) // Update cols.UpdateColumn(Column{ @@ -380,23 +378,23 @@ func TestColumns_Mutation(t *testing.T) { }) fooCol, isOk = cols.GetColumn("foo") - assert.True(t, isOk) - assert.Equal(t, typing.Integer, fooCol.KindDetails) - assert.Equal(t, nil, fooCol.defaultValue) + assert.True(c.T(), isOk) + assert.Equal(c.T(), typing.Integer, fooCol.KindDetails) + assert.Equal(c.T(), nil, fooCol.defaultValue) barCol, isOk = cols.GetColumn("bar") - assert.True(t, isOk) - assert.Equal(t, typing.Boolean, barCol.KindDetails) - assert.Equal(t, "123", barCol.defaultValue) + assert.True(c.T(), isOk) + assert.Equal(c.T(), typing.Boolean, barCol.KindDetails) + assert.Equal(c.T(), "123", barCol.defaultValue) // Delete cols.DeleteColumn("foo") - assert.Equal(t, len(cols.GetColumns()), 1) + assert.Equal(c.T(), len(cols.GetColumns()), 1) cols.DeleteColumn("bar") - assert.Equal(t, len(cols.GetColumns()), 0) + assert.Equal(c.T(), len(cols.GetColumns()), 0) } -func TestColumnsUpdateQuery(t *testing.T) { +func (c *ColumnsTestSuite) TestColumnsUpdateQuery() { type testCase struct { name string columns []string @@ -522,7 +520,7 @@ func TestColumnsUpdateQuery(t *testing.T) { } for _, _testCase := range testCases { - actualQuery := ColumnsUpdateQuery(_testCase.columns, _testCase.columnsToTypes, _testCase.destKind) - assert.Equal(t, _testCase.expectedString, actualQuery, _testCase.name) + actualQuery := ColumnsUpdateQuery(c.ctx, _testCase.columns, _testCase.columnsToTypes, _testCase.destKind) + assert.Equal(c.T(), _testCase.expectedString, actualQuery, _testCase.name) } } diff --git a/lib/typing/columns/diff.go b/lib/typing/columns/diff.go index da90aae50..06f8699a8 100644 --- a/lib/typing/columns/diff.go +++ b/lib/typing/columns/diff.go @@ -1,6 +1,7 @@ package columns import ( + "context" "strings" "github.com/artie-labs/transfer/lib/config/constants" @@ -27,12 +28,12 @@ func shouldSkipColumn(colName string, softDelete bool, includeArtieUpdatedAt boo // Diff - when given 2 maps, a source and target // It will provide a diff in the form of 2 variables -func Diff(columnsInSource *Columns, columnsInDestination *Columns, softDelete bool, includeArtieUpdatedAt bool) ([]Column, []Column) { +func Diff(ctx context.Context, columnsInSource *Columns, columnsInDestination *Columns, softDelete bool, includeArtieUpdatedAt bool) ([]Column, []Column) { src := CloneColumns(columnsInSource) targ := CloneColumns(columnsInDestination) var colsToDelete []Column for _, col := range src.GetColumns() { - _, isOk := targ.GetColumn(col.Name(nil)) + _, isOk := targ.GetColumn(col.Name(ctx, nil)) if isOk { colsToDelete = append(colsToDelete, col) @@ -41,13 +42,13 @@ func Diff(columnsInSource *Columns, columnsInDestination *Columns, softDelete bo // We cannot delete inside a for-loop that is iterating over src.GetColumns() because we are messing up the array order. for _, colToDelete := range colsToDelete { - src.DeleteColumn(colToDelete.Name(nil)) - targ.DeleteColumn(colToDelete.Name(nil)) + src.DeleteColumn(colToDelete.Name(ctx, nil)) + targ.DeleteColumn(colToDelete.Name(ctx, nil)) } var targetColumnsMissing Columns for _, col := range src.GetColumns() { - if shouldSkipColumn(col.Name(nil), softDelete, includeArtieUpdatedAt) { + if shouldSkipColumn(col.Name(ctx, nil), softDelete, includeArtieUpdatedAt) { continue } @@ -56,7 +57,7 @@ func Diff(columnsInSource *Columns, columnsInDestination *Columns, softDelete bo var sourceColumnsMissing Columns for _, col := range targ.GetColumns() { - if shouldSkipColumn(col.Name(nil), softDelete, includeArtieUpdatedAt) { + if shouldSkipColumn(col.Name(ctx, nil), softDelete, includeArtieUpdatedAt) { continue } diff --git a/lib/typing/columns/diff_test.go b/lib/typing/columns/diff_test.go index 62499fd06..468242861 100644 --- a/lib/typing/columns/diff_test.go +++ b/lib/typing/columns/diff_test.go @@ -1,8 +1,6 @@ package columns import ( - "testing" - "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/typing" @@ -11,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestShouldSkipColumn(t *testing.T) { +func (c *ColumnsTestSuite) TestShouldSkipColumn() { type _testCase struct { name string colName string @@ -55,11 +53,11 @@ func TestShouldSkipColumn(t *testing.T) { for _, testCase := range testCases { actualResult := shouldSkipColumn(testCase.colName, testCase.softDelete, testCase.includeArtieUpdatedAt) - assert.Equal(t, testCase.expectedResult, actualResult, testCase.name) + assert.Equal(c.T(), testCase.expectedResult, actualResult, testCase.name) } } -func TestDiff_VariousNils(t *testing.T) { +func (c *ColumnsTestSuite) TestDiff_VariousNils() { type _testCase struct { name string sourceCols *Columns @@ -111,22 +109,22 @@ func TestDiff_VariousNils(t *testing.T) { } for _, testCase := range testCases { - actualSrcKeysMissing, actualTargKeysMissing := Diff(testCase.sourceCols, testCase.targCols, false, false) - assert.Equal(t, testCase.expectedSrcKeyLength, len(actualSrcKeysMissing), testCase.name) - assert.Equal(t, testCase.expectedTargKeyLength, len(actualTargKeysMissing), testCase.name) + actualSrcKeysMissing, actualTargKeysMissing := Diff(c.ctx, testCase.sourceCols, testCase.targCols, false, false) + assert.Equal(c.T(), testCase.expectedSrcKeyLength, len(actualSrcKeysMissing), testCase.name) + assert.Equal(c.T(), testCase.expectedTargKeyLength, len(actualTargKeysMissing), testCase.name) } } -func TestDiffBasic(t *testing.T) { +func (c *ColumnsTestSuite) TestDiffBasic() { var source Columns source.AddColumn(NewColumn("a", typing.Integer)) - srcKeyMissing, targKeyMissing := Diff(&source, &source, false, false) - assert.Equal(t, len(srcKeyMissing), 0) - assert.Equal(t, len(targKeyMissing), 0) + srcKeyMissing, targKeyMissing := Diff(c.ctx, &source, &source, false, false) + assert.Equal(c.T(), len(srcKeyMissing), 0) + assert.Equal(c.T(), len(targKeyMissing), 0) } -func TestDiffDelta1(t *testing.T) { +func (c *ColumnsTestSuite) TestDiffDelta1() { var sourceCols Columns var targCols Columns for colName, kindDetails := range map[string]typing.KindDetails{ @@ -145,12 +143,12 @@ func TestDiffDelta1(t *testing.T) { targCols.AddColumn(NewColumn(colName, kindDetails)) } - srcKeyMissing, targKeyMissing := Diff(&sourceCols, &targCols, false, false) - assert.Equal(t, len(srcKeyMissing), 2, srcKeyMissing) // Missing aa, cc - assert.Equal(t, len(targKeyMissing), 2, targKeyMissing) // Missing aa, cc + srcKeyMissing, targKeyMissing := Diff(c.ctx, &sourceCols, &targCols, false, false) + assert.Equal(c.T(), len(srcKeyMissing), 2, srcKeyMissing) // Missing aa, cc + assert.Equal(c.T(), len(targKeyMissing), 2, targKeyMissing) // Missing aa, cc } -func TestDiffDelta2(t *testing.T) { +func (c *ColumnsTestSuite) TestDiffDelta2() { var sourceCols Columns var targetCols Columns @@ -177,12 +175,12 @@ func TestDiffDelta2(t *testing.T) { targetCols.AddColumn(NewColumn(colName, kindDetails)) } - srcKeyMissing, targKeyMissing := Diff(&sourceCols, &targetCols, false, false) - assert.Equal(t, len(srcKeyMissing), 1, srcKeyMissing) // Missing dd - assert.Equal(t, len(targKeyMissing), 3, targKeyMissing) // Missing a, c, d + srcKeyMissing, targKeyMissing := Diff(c.ctx, &sourceCols, &targetCols, false, false) + assert.Equal(c.T(), len(srcKeyMissing), 1, srcKeyMissing) // Missing dd + assert.Equal(c.T(), len(targKeyMissing), 3, targKeyMissing) // Missing a, c, d } -func TestDiffDeterministic(t *testing.T) { +func (c *ColumnsTestSuite) TestDiffDeterministic() { retMap := map[string]bool{} var sourceCols Columns @@ -192,35 +190,35 @@ func TestDiffDeterministic(t *testing.T) { sourceCols.AddColumn(NewColumn("name", typing.String)) for i := 0; i < 500; i++ { - keysMissing, targetKeysMissing := Diff(&sourceCols, &targCols, false, false) - assert.Equal(t, 0, len(keysMissing), keysMissing) + keysMissing, targetKeysMissing := Diff(c.ctx, &sourceCols, &targCols, false, false) + assert.Equal(c.T(), 0, len(keysMissing), keysMissing) var key string for _, targetKeyMissing := range targetKeysMissing { - key += targetKeyMissing.Name(nil) + key += targetKeyMissing.Name(c.ctx, nil) } retMap[key] = false } - assert.Equal(t, 1, len(retMap), retMap) + assert.Equal(c.T(), 1, len(retMap), retMap) } -func TestCopyColMap(t *testing.T) { +func (c *ColumnsTestSuite) TestCopyColMap() { var cols Columns cols.AddColumn(NewColumn("hello", typing.String)) cols.AddColumn(NewColumn("created_at", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateTimeKindType))) cols.AddColumn(NewColumn("updated_at", typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateTimeKindType))) copiedCols := CloneColumns(&cols) - assert.Equal(t, *copiedCols, cols) + assert.Equal(c.T(), *copiedCols, cols) //Delete a row from copiedCols copiedCols.columns = append(copiedCols.columns[1:]) - assert.NotEqual(t, *copiedCols, cols) + assert.NotEqual(c.T(), *copiedCols, cols) } -func TestCloneColumns(t *testing.T) { +func (c *ColumnsTestSuite) TestCloneColumns() { type _testCase struct { name string cols *Columns @@ -263,6 +261,6 @@ func TestCloneColumns(t *testing.T) { for _, testCase := range testCases { actualCols := CloneColumns(testCase.cols) - assert.Equal(t, *testCase.expectedCols, *actualCols, testCase.name) + assert.Equal(c.T(), *testCase.expectedCols, *actualCols, testCase.name) } } diff --git a/lib/typing/columns/wrapper.go b/lib/typing/columns/wrapper.go index 6bdaeb652..9c939442f 100644 --- a/lib/typing/columns/wrapper.go +++ b/lib/typing/columns/wrapper.go @@ -1,14 +1,20 @@ package columns +import ( + "context" + + "github.com/artie-labs/transfer/lib/sql" +) + type Wrapper struct { name string escapedName string } -func NewWrapper(col Column, args *NameArgs) Wrapper { +func NewWrapper(ctx context.Context, col Column, args *sql.NameArgs) Wrapper { return Wrapper{ name: col.name, - escapedName: col.Name(args), + escapedName: col.Name(ctx, args), } } diff --git a/lib/typing/columns/wrapper_test.go b/lib/typing/columns/wrapper_test.go index 3cb6babfc..4a0b49e43 100644 --- a/lib/typing/columns/wrapper_test.go +++ b/lib/typing/columns/wrapper_test.go @@ -1,7 +1,7 @@ package columns import ( - "testing" + "github.com/artie-labs/transfer/lib/sql" "github.com/stretchr/testify/assert" @@ -10,7 +10,7 @@ import ( "github.com/artie-labs/transfer/lib/typing" ) -func TestWrapper_Complete(t *testing.T) { +func (c *ColumnsTestSuite) TestWrapper_Complete() { type _testCase struct { name string expectedRawName string @@ -41,38 +41,38 @@ func TestWrapper_Complete(t *testing.T) { for _, testCase := range testCases { // Snowflake escape - w := NewWrapper(NewColumn(testCase.name, typing.Invalid), &NameArgs{ + w := NewWrapper(c.ctx, NewColumn(testCase.name, typing.Invalid), &sql.NameArgs{ Escape: true, DestKind: constants.Snowflake, }) - assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name) - assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) + assert.Equal(c.T(), testCase.expectedEscapedName, w.EscapedName(), testCase.name) + assert.Equal(c.T(), testCase.expectedRawName, w.RawName(), testCase.name) // BigQuery escape - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), &NameArgs{ + w = NewWrapper(c.ctx, NewColumn(testCase.name, typing.Invalid), &sql.NameArgs{ Escape: true, DestKind: constants.BigQuery, }) - assert.Equal(t, testCase.expectedEscapedNameBQ, w.EscapedName(), testCase.name) - assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) + assert.Equal(c.T(), testCase.expectedEscapedNameBQ, w.EscapedName(), testCase.name) + assert.Equal(c.T(), testCase.expectedRawName, w.RawName(), testCase.name) for _, destKind := range []constants.DestinationKind{constants.Snowflake, constants.SnowflakeStages, constants.BigQuery} { - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), &NameArgs{ + w = NewWrapper(c.ctx, NewColumn(testCase.name, typing.Invalid), &sql.NameArgs{ Escape: false, DestKind: destKind, }) - assert.Equal(t, testCase.expectedRawName, w.EscapedName(), testCase.name) - assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) + assert.Equal(c.T(), testCase.expectedRawName, w.EscapedName(), testCase.name) + assert.Equal(c.T(), testCase.expectedRawName, w.RawName(), testCase.name) } // Same if nil - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), nil) + w = NewWrapper(c.ctx, NewColumn(testCase.name, typing.Invalid), nil) - assert.Equal(t, testCase.expectedRawName, w.EscapedName(), testCase.name) - assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) + assert.Equal(c.T(), testCase.expectedRawName, w.EscapedName(), testCase.name) + assert.Equal(c.T(), testCase.expectedRawName, w.RawName(), testCase.name) } } diff --git a/models/event/event_save_test.go b/models/event/event_save_test.go index e53082859..f0bbfddb6 100644 --- a/models/event/event_save_test.go +++ b/models/event/event_save_test.go @@ -49,7 +49,7 @@ func (e *EventsTestSuite) TestSaveEvent() { // Check the in-memory DB columns. var found int for _, col := range optimization.ReadOnlyInMemoryCols().GetColumns() { - if col.Name(nil) == expectedLowerCol || col.Name(nil) == anotherLowerCol { + if col.Name(e.ctx, nil) == expectedLowerCol || col.Name(e.ctx, nil) == anotherLowerCol { found += 1 } @@ -178,16 +178,16 @@ func (e *EventsTestSuite) TestEvent_SaveColumnsNoData() { td := models.GetMemoryDB(e.ctx).GetOrCreateTableData("non_existent") var prevKey string for _, col := range td.ReadOnlyInMemoryCols().GetColumns() { - if col.Name(nil) == constants.DeleteColumnMarker { + if col.Name(e.ctx, nil) == constants.DeleteColumnMarker { continue } if prevKey == "" { - prevKey = col.Name(nil) + prevKey = col.Name(e.ctx, nil) continue } - currentKeyParsed, err := strconv.Atoi(col.Name(nil)) + currentKeyParsed, err := strconv.Atoi(col.Name(e.ctx, nil)) assert.NoError(e.T(), err) prevKeyParsed, err := strconv.Atoi(prevKey) @@ -201,7 +201,7 @@ func (e *EventsTestSuite) TestEvent_SaveColumnsNoData() { evt.Columns.AddColumn(columns.NewColumn("foo", typing.Invalid)) var index int for idx, col := range evt.Columns.GetColumns() { - if col.Name(nil) == "foo" { + if col.Name(e.ctx, nil) == "foo" { index = idx } }