diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go index 58403df1e..1f6c79087 100644 --- a/clients/snowflake/snowflake.go +++ b/clients/snowflake/snowflake.go @@ -58,8 +58,9 @@ func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap { func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) error { err := s.mergeWithStages(ctx, tableData) if AuthenticationExpirationErr(err) { - logger.FromContext(ctx).WithError(err).Warn("authentication has expired, will reload the Snowflake store") + logger.FromContext(ctx).WithError(err).Warn("authentication has expired, will reload the Snowflake store and retry merging") s.ReestablishConnection(ctx) + return s.Merge(ctx, tableData) } return err diff --git a/clients/snowflake/snowflake_test.go b/clients/snowflake/snowflake_test.go index 3c04933a0..37874528d 100644 --- a/clients/snowflake/snowflake_test.go +++ b/clients/snowflake/snowflake_test.go @@ -114,14 +114,11 @@ func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() { s.fakeStageStore.ExecReturnsOnCall(0, nil, fmt.Errorf("390114: Authentication token has expired. The user must authenticate again.")) err := s.stageStore.Merge(s.ctx, tableData) - assert.True(s.T(), AuthenticationExpirationErr(err), err) - - s.fakeStageStore.ExecReturnsOnCall(1, nil, nil) - assert.Nil(s.T(), s.stageStore.Merge(s.ctx, tableData)) - s.fakeStageStore.ExecReturns(nil, nil) + assert.NoError(s.T(), err, "transient errors like auth errors will be retried") // 5 regular ones and then 1 additional one to re-establish auth. - assert.Equal(s.T(), s.fakeStageStore.ExecCallCount(), 6, "called merge") + baseline := 5 + assert.Equal(s.T(), s.fakeStageStore.ExecCallCount(), baseline+1, "called merge") } func (s *SnowflakeTestSuite) TestExecuteMerge() {