Skip to content

Commit

Permalink
[Snowflake] Reestablish idle sessions (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Mar 2, 2023
1 parent dc7ba0f commit a3c5ee5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 15 deletions.
8 changes: 8 additions & 0 deletions clients/snowflake/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,11 @@ func TableDoesNotExistErr(err error) bool {

return strings.Contains(err.Error(), "does not exist or not authorized")
}

func AuthenticationExpirationErr(err error) bool {
if err == nil {
return false
}

return strings.Contains(err.Error(), "Authentication token has expired")
}
4 changes: 4 additions & 0 deletions clients/snowflake/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@ func TestTableDoesNotExistErr(t *testing.T) {
assert.Equal(t, TableDoesNotExistErr(err), expectation, err)
}
}

func TestAuthenticationExpirationErr(t *testing.T) {
assert.Equal(t, true, AuthenticationExpirationErr(fmt.Errorf("390114: Authentication token has expired. The user must authenticate again.")))
}
41 changes: 26 additions & 15 deletions clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,31 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er

log.WithField("query", query).Debug("executing...")
_, err = s.Exec(query)
if AuthenticationExpirationErr(err) {
log.WithError(err).Warn("authentication has expired, will reload the Snowflake store")
s.ReestablishConnection(ctx)
}

return err
}

func (s *Store) ReestablishConnection(ctx context.Context) {
dsn, err := gosnowflake.DSN(&gosnowflake.Config{
Account: config.GetSettings().Config.Snowflake.AccountID,
User: config.GetSettings().Config.Snowflake.Username,
Password: config.GetSettings().Config.Snowflake.Password,
Warehouse: config.GetSettings().Config.Snowflake.Warehouse,
Region: config.GetSettings().Config.Snowflake.Region,
})

if err != nil {
logger.FromContext(ctx).Fatalf("failed to get snowflake dsn, err: %v", err)
}

s.Store = db.Open(ctx, "snowflake", dsn)
return
}

func LoadSnowflake(ctx context.Context, _store *db.Store) *Store {
if _store != nil {
// Used for tests.
Expand All @@ -108,21 +130,10 @@ func LoadSnowflake(ctx context.Context, _store *db.Store) *Store {
}
}

dsn, err := gosnowflake.DSN(&gosnowflake.Config{
Account: config.GetSettings().Config.Snowflake.AccountID,
User: config.GetSettings().Config.Snowflake.Username,
Password: config.GetSettings().Config.Snowflake.Password,
Warehouse: config.GetSettings().Config.Snowflake.Warehouse,
Region: config.GetSettings().Config.Snowflake.Region,
KeepSessionAlive: true,
})

if err != nil {
logger.FromContext(ctx).Fatalf("failed to get snowflake dsn, err: %v", err)
}

return &Store{
Store: db.Open(ctx, "snowflake", dsn),
s := &Store{
configMap: &types.DwhToTablesConfigMap{},
}

s.ReestablishConnection(ctx)
return s
}

0 comments on commit a3c5ee5

Please sign in to comment.