diff --git a/clients/snowflake/merge.go b/clients/snowflake/merge.go index abff0567c..dd9b50b94 100644 --- a/clients/snowflake/merge.go +++ b/clients/snowflake/merge.go @@ -13,7 +13,7 @@ import ( "github.com/artie-labs/transfer/lib/typing/ext" ) -func merge(tableData *optimization.TableData) (string, error) { +func getMergeStatement(tableData *optimization.TableData) (string, error) { var tableValues []string var cols []string var sflkCols []string diff --git a/clients/snowflake/merge_test.go b/clients/snowflake/merge_test.go index 1b7651606..cc0d82355 100644 --- a/clients/snowflake/merge_test.go +++ b/clients/snowflake/merge_test.go @@ -26,8 +26,8 @@ func (s *SnowflakeTestSuite) TestMergeNoDeleteFlag() { LatestCDCTs: time.Time{}, } - _, err := merge(tableData) - assert.Error(s.T(), err, "merge failed") + _, err := getMergeStatement(tableData) + assert.Error(s.T(), err, "getMergeStatement failed") } @@ -62,8 +62,8 @@ func (s *SnowflakeTestSuite) TestMerge() { LatestCDCTs: time.Time{}, } - mergeSQL, err := merge(tableData) - assert.NoError(s.T(), err, "merge failed") + mergeSQL, err := getMergeStatement(tableData) + assert.NoError(s.T(), err, "getMergeStatement failed") assert.Contains(s.T(), mergeSQL, "robin") assert.Contains(s.T(), mergeSQL, "false") assert.Contains(s.T(), mergeSQL, "1") @@ -114,8 +114,8 @@ func (s *SnowflakeTestSuite) TestMergeWithSingleQuote() { LatestCDCTs: time.Time{}, } - mergeSQL, err := merge(tableData) - assert.NoError(s.T(), err, "merge failed") + mergeSQL, err := getMergeStatement(tableData) + assert.NoError(s.T(), err, "getMergeStatement failed") assert.Contains(s.T(), mergeSQL, `I can\'t fail`) } @@ -147,7 +147,7 @@ func (s *SnowflakeTestSuite) TestMergeJson() { LatestCDCTs: time.Time{}, } - mergeSQL, err := merge(tableData) - assert.NoError(s.T(), err, "merge failed") + mergeSQL, err := getMergeStatement(tableData) + assert.NoError(s.T(), err, "getMergeStatement failed") assert.Contains(s.T(), mergeSQL, `"label": "2\\" pipe"`) -} \ No newline at end of file +} diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go index e658331b3..b7d8cf24a 100644 --- a/clients/snowflake/snowflake.go +++ b/clients/snowflake/snowflake.go @@ -16,6 +16,7 @@ import ( type Store struct { db.Store + testDB bool // Used for testing configMap *types.DwhToTablesConfigMap } @@ -38,6 +39,16 @@ func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap { } func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) error { + err := s.merge(ctx, tableData) + if AuthenticationExpirationErr(err) { + logger.FromContext(ctx).WithError(err).Warn("authentication has expired, will reload the Snowflake store") + s.ReestablishConnection(ctx) + } + + return err +} + +func (s *Store) merge(ctx context.Context, tableData *optimization.TableData) error { if tableData.Rows == 0 { // There's no rows. Let's skip. return nil @@ -88,23 +99,23 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er } tableData.UpdateInMemoryColumns(tableConfig.Columns()) - query, err := merge(tableData) + query, err := getMergeStatement(tableData) if err != nil { - log.WithError(err).Warn("failed to generate the merge query") + log.WithError(err).Warn("failed to generate the getMergeStatement query") return err } log.WithField("query", query).Debug("executing...") _, err = s.Exec(query) - if AuthenticationExpirationErr(err) { - log.WithError(err).Warn("authentication has expired, will reload the Snowflake store") - s.ReestablishConnection(ctx) - } - return err } func (s *Store) ReestablishConnection(ctx context.Context) { + if s.testDB { + // Don't actually re-establish for tests. + return + } + cfg := &gosnowflake.Config{ Account: config.GetSettings().Config.Snowflake.AccountID, User: config.GetSettings().Config.Snowflake.Username, @@ -132,6 +143,7 @@ func LoadSnowflake(ctx context.Context, _store *db.Store) *Store { if _store != nil { // Used for tests. return &Store{ + testDB: true, Store: *_store, configMap: &types.DwhToTablesConfigMap{}, } diff --git a/clients/snowflake/snowflake_test.go b/clients/snowflake/snowflake_test.go index e44cf8e25..74eeaaa42 100644 --- a/clients/snowflake/snowflake_test.go +++ b/clients/snowflake/snowflake_test.go @@ -58,6 +58,52 @@ func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() { assert.NoError(s.T(), err) } +func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() { + columns := map[string]typing.KindDetails{ + "id": typing.Integer, + "name": typing.String, + constants.DeleteColumnMarker: typing.Boolean, + // Add kindDetails to created_at + "created_at": typing.ParseValue(time.Now().Format(time.RFC3339Nano)), + } + + rowsData := make(map[string]map[string]interface{}) + + for i := 0; i < 5; i++ { + rowsData[fmt.Sprintf("pk-%d", i)] = map[string]interface{}{ + "id": fmt.Sprintf("pk-%d", i), + "created_at": time.Now().Format(time.RFC3339Nano), + "name": fmt.Sprintf("Robin-%d", i), + } + } + + topicConfig := kafkalib.TopicConfig{ + Database: "customer", + TableName: "orders", + Schema: "public", + } + + tableData := &optimization.TableData{ + InMemoryColumns: columns, + RowsData: rowsData, + TopicConfig: topicConfig, + PrimaryKey: "id", + Rows: 1, + } + + s.store.configMap.AddTableToConfig(topicConfig.ToFqName(constants.Snowflake), + types.NewDwhTableConfig(columns, nil, false, true)) + + s.fakeStore.ExecReturnsOnCall(0, nil, fmt.Errorf("390114: Authentication token has expired. The user must authenticate again.")) + err := s.store.Merge(context.Background(), tableData) + assert.True(s.T(), AuthenticationExpirationErr(err), err) + + s.fakeStore.ExecReturnsOnCall(1, nil, nil) + assert.Nil(s.T(), s.store.Merge(context.Background(), tableData)) + s.fakeStore.ExecReturns(nil, nil) + assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 2, "called merge") +} + func (s *SnowflakeTestSuite) TestExecuteMerge() { columns := map[string]typing.KindDetails{ "id": typing.Integer,