Skip to content

Commit

Permalink
[Snowflake] Fixing all inactive sessions (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Mar 9, 2023
1 parent af06b77 commit 3fa3256
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 17 deletions.
2 changes: 1 addition & 1 deletion clients/snowflake/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/artie-labs/transfer/lib/typing/ext"
)

func merge(tableData *optimization.TableData) (string, error) {
func getMergeStatement(tableData *optimization.TableData) (string, error) {
var tableValues []string
var cols []string
var sflkCols []string
Expand Down
18 changes: 9 additions & 9 deletions clients/snowflake/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func (s *SnowflakeTestSuite) TestMergeNoDeleteFlag() {
LatestCDCTs: time.Time{},
}

_, err := merge(tableData)
assert.Error(s.T(), err, "merge failed")
_, err := getMergeStatement(tableData)
assert.Error(s.T(), err, "getMergeStatement failed")

}

Expand Down Expand Up @@ -62,8 +62,8 @@ func (s *SnowflakeTestSuite) TestMerge() {
LatestCDCTs: time.Time{},
}

mergeSQL, err := merge(tableData)
assert.NoError(s.T(), err, "merge failed")
mergeSQL, err := getMergeStatement(tableData)
assert.NoError(s.T(), err, "getMergeStatement failed")
assert.Contains(s.T(), mergeSQL, "robin")
assert.Contains(s.T(), mergeSQL, "false")
assert.Contains(s.T(), mergeSQL, "1")
Expand Down Expand Up @@ -114,8 +114,8 @@ func (s *SnowflakeTestSuite) TestMergeWithSingleQuote() {
LatestCDCTs: time.Time{},
}

mergeSQL, err := merge(tableData)
assert.NoError(s.T(), err, "merge failed")
mergeSQL, err := getMergeStatement(tableData)
assert.NoError(s.T(), err, "getMergeStatement failed")
assert.Contains(s.T(), mergeSQL, `I can\'t fail`)
}

Expand Down Expand Up @@ -147,7 +147,7 @@ func (s *SnowflakeTestSuite) TestMergeJson() {
LatestCDCTs: time.Time{},
}

mergeSQL, err := merge(tableData)
assert.NoError(s.T(), err, "merge failed")
mergeSQL, err := getMergeStatement(tableData)
assert.NoError(s.T(), err, "getMergeStatement failed")
assert.Contains(s.T(), mergeSQL, `"label": "2\\" pipe"`)
}
}
26 changes: 19 additions & 7 deletions clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

type Store struct {
db.Store
testDB bool // Used for testing
configMap *types.DwhToTablesConfigMap
}

Expand All @@ -38,6 +39,16 @@ func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap {
}

func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) error {
err := s.merge(ctx, tableData)
if AuthenticationExpirationErr(err) {
logger.FromContext(ctx).WithError(err).Warn("authentication has expired, will reload the Snowflake store")
s.ReestablishConnection(ctx)
}

return err
}

func (s *Store) merge(ctx context.Context, tableData *optimization.TableData) error {
if tableData.Rows == 0 {
// There's no rows. Let's skip.
return nil
Expand Down Expand Up @@ -88,23 +99,23 @@ func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) er
}

tableData.UpdateInMemoryColumns(tableConfig.Columns())
query, err := merge(tableData)
query, err := getMergeStatement(tableData)
if err != nil {
log.WithError(err).Warn("failed to generate the merge query")
log.WithError(err).Warn("failed to generate the getMergeStatement query")
return err
}

log.WithField("query", query).Debug("executing...")
_, err = s.Exec(query)
if AuthenticationExpirationErr(err) {
log.WithError(err).Warn("authentication has expired, will reload the Snowflake store")
s.ReestablishConnection(ctx)
}

return err
}

func (s *Store) ReestablishConnection(ctx context.Context) {
if s.testDB {
// Don't actually re-establish for tests.
return
}

cfg := &gosnowflake.Config{
Account: config.GetSettings().Config.Snowflake.AccountID,
User: config.GetSettings().Config.Snowflake.Username,
Expand Down Expand Up @@ -132,6 +143,7 @@ func LoadSnowflake(ctx context.Context, _store *db.Store) *Store {
if _store != nil {
// Used for tests.
return &Store{
testDB: true,
Store: *_store,
configMap: &types.DwhToTablesConfigMap{},
}
Expand Down
46 changes: 46 additions & 0 deletions clients/snowflake/snowflake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,52 @@ func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() {
assert.NoError(s.T(), err)
}

func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() {
columns := map[string]typing.KindDetails{
"id": typing.Integer,
"name": typing.String,
constants.DeleteColumnMarker: typing.Boolean,
// Add kindDetails to created_at
"created_at": typing.ParseValue(time.Now().Format(time.RFC3339Nano)),
}

rowsData := make(map[string]map[string]interface{})

for i := 0; i < 5; i++ {
rowsData[fmt.Sprintf("pk-%d", i)] = map[string]interface{}{
"id": fmt.Sprintf("pk-%d", i),
"created_at": time.Now().Format(time.RFC3339Nano),
"name": fmt.Sprintf("Robin-%d", i),
}
}

topicConfig := kafkalib.TopicConfig{
Database: "customer",
TableName: "orders",
Schema: "public",
}

tableData := &optimization.TableData{
InMemoryColumns: columns,
RowsData: rowsData,
TopicConfig: topicConfig,
PrimaryKey: "id",
Rows: 1,
}

s.store.configMap.AddTableToConfig(topicConfig.ToFqName(constants.Snowflake),
types.NewDwhTableConfig(columns, nil, false, true))

s.fakeStore.ExecReturnsOnCall(0, nil, fmt.Errorf("390114: Authentication token has expired. The user must authenticate again."))
err := s.store.Merge(context.Background(), tableData)
assert.True(s.T(), AuthenticationExpirationErr(err), err)

s.fakeStore.ExecReturnsOnCall(1, nil, nil)
assert.Nil(s.T(), s.store.Merge(context.Background(), tableData))
s.fakeStore.ExecReturns(nil, nil)
assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 2, "called merge")
}

func (s *SnowflakeTestSuite) TestExecuteMerge() {
columns := map[string]typing.KindDetails{
"id": typing.Integer,
Expand Down

0 comments on commit 3fa3256

Please sign in to comment.