Skip to content

Commit

Permalink
[Snowflake] Address idle action item follow ups (#907)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Sep 17, 2024
1 parent 46eab86 commit 2440919
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 132 deletions.
11 changes: 0 additions & 11 deletions clients/snowflake/errors.go

This file was deleted.

34 changes: 0 additions & 34 deletions clients/snowflake/errors_test.go

This file was deleted.

52 changes: 17 additions & 35 deletions clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@ import (
"github.com/artie-labs/transfer/lib/typing"
)

const maxRetries = 10

// Store is the Snowflake client type in this package. It embeds db.Store and
// carries the per-table config map plus the config it was loaded with.
type Store struct {
db.Store
// testDB is set to true by LoadSnowflake when a pre-built store is injected;
// reestablishConnection returns early without reconnecting when it is set.
testDB bool // Used for testing
// configMap is the value handed back by GetConfigMap.
configMap *types.DwhToTablesConfigMap
// config is the config passed to LoadSnowflake; its Snowflake section is
// converted into the driver config used to build the connection DSN.
config config.Config
}
Expand Down Expand Up @@ -70,30 +67,6 @@ func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap {
return s.configMap
}

// reestablishConnection opens a new Snowflake connection from s.config and
// stores it in the embedded s.Store.
//
// It does nothing and returns nil when s.testDB is set. It returns an error if
// the Snowflake config cannot be converted, the DSN cannot be built, or
// db.Open fails; in every error case s.Store is left untouched.
func (s *Store) reestablishConnection() error {
if s.testDB {
// Don't actually re-establish for tests.
return nil
}

// Convert the Snowflake section of the config into the driver config that gosnowflake.DSN accepts.
cfg, err := s.config.Snowflake.ToConfig()
if err != nil {
return fmt.Errorf("failed to get snowflake config: %w", err)
}

dsn, err := gosnowflake.DSN(cfg)
if err != nil {
return fmt.Errorf("failed to get Snowflake DSN: %w", err)
}

store, err := db.Open("snowflake", dsn)
if err != nil {
return err
}
// Only swap in the new store once it has opened successfully.
// NOTE(review): the previous s.Store is not closed here - confirm whether it needs to be.
s.Store = store
return nil
}

// Dedupe takes a table and will remove duplicates based on the primary key(s).
// These queries are inspired and modified from: https://stackoverflow.com/a/71515946
func (s *Store) Dedupe(tableID sql.TableIdentifier, primaryKeys []string, includeArtieUpdatedAt bool) error {
Expand All @@ -106,21 +79,30 @@ func LoadSnowflake(cfg config.Config, _store *db.Store) (*Store, error) {
if _store != nil {
// Used for tests.
return &Store{
testDB: true,
configMap: &types.DwhToTablesConfigMap{},
config: cfg,

Store: *_store,
Store: *_store,
}, nil
}

s := &Store{
configMap: &types.DwhToTablesConfigMap{},
config: cfg,
snowflakeCfg, err := cfg.Snowflake.ToConfig()
if err != nil {
return nil, fmt.Errorf("failed to get Snowflake config: %w", err)
}

dsn, err := gosnowflake.DSN(snowflakeCfg)
if err != nil {
return nil, fmt.Errorf("failed to get Snowflake DSN: %w", err)
}

if err := s.reestablishConnection(); err != nil {
store, err := db.Open("snowflake", dsn)
if err != nil {
return nil, err
}
return s, nil

return &Store{
configMap: &types.DwhToTablesConfigMap{},
config: cfg,
Store: store,
}, nil
}
9 changes: 2 additions & 7 deletions clients/snowflake/snowflake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,8 @@ func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() {

s.stageStore.configMap.AddTableToConfig(s.identifierFor(tableData), types.NewDwhTableConfig(&cols, nil, false, true))

s.fakeStageStore.ExecReturnsOnCall(0, nil, fmt.Errorf("390114: Authentication token has expired. The user must authenticate again."))
err := s.stageStore.Merge(tableData)
assert.NoError(s.T(), err, "transient errors like auth errors will be retried")

// 5 regular ones and then 1 additional one to re-establish auth and another one for dropping the temporary table
baseline := 5
assert.Equal(s.T(), baseline+2, s.fakeStageStore.ExecCallCount(), "called merge")
assert.NoError(s.T(), s.stageStore.Merge(tableData))
assert.Equal(s.T(), 5, s.fakeStageStore.ExecCallCount())
}

func (s *SnowflakeTestSuite) TestExecuteMerge() {
Expand Down
52 changes: 7 additions & 45 deletions clients/snowflake/writes.go
Original file line number Diff line number Diff line change
@@ -1,41 +1,20 @@
package snowflake

import (
"log/slog"

"github.com/artie-labs/transfer/clients/shared"
"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/logger"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)

// Append writes tableData by delegating to shared.Append with a Snowflake
// specific COPY clause. The second (bool) parameter is ignored.
//
// As written, the call is attempted up to maxRetries times. A further attempt
// is made only when the previous error satisfies IsAuthExpiredError, in which
// case the connection is re-established first. The error from the last attempt
// is returned.
//
// NOTE(review): this text comes from a commit diff view, so the removed retry
// loop and its replacement (the final return statement) both appear in the
// body; the replacement is unreachable here.
func (s *Store) Append(tableData *optimization.TableData, _ bool) error {
var err error
for i := 0; i < maxRetries; i++ {
// On every pass after the first, decide whether the previous error warrants another attempt.
if i > 0 {
// TODO: remove
if IsAuthExpiredError(err) {
slog.Warn("Authentication has expired, will reload the Snowflake store and retry appending", slog.Any("err", err))
if connErr := s.reestablishConnection(); connErr != nil {
// TODO: Remove this panic and return an error instead. Ensure the callers of [Append] handle this properly.
logger.Panic("Failed to reestablish connection", slog.Any("err", connErr))
}
} else {
// Previous attempt did not report an expired-auth error, so stop retrying.
// NOTE(review): presumably this is also the exit taken after a successful attempt (err == nil) - confirm against IsAuthExpiredError.
break
}
}

// TODO: For history mode - in the future, we could also have a separate stage name for history mode so we can enable parallel processing.
err = shared.Append(s, tableData, types.AdditionalSettings{
AdditionalCopyClause: `FILE_FORMAT = (TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE) PURGE = TRUE`,
})
}

return err
// NOTE(review): unreachable as written - this is the single-call form without the retry loop.
// TODO: For history mode - in the future, we could also have a separate stage name for history mode so we can enable parallel processing.
return shared.Append(s, tableData, types.AdditionalSettings{
AdditionalCopyClause: `FILE_FORMAT = (TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE) PURGE = TRUE`,
})
}

func (s *Store) additionalEqualityStrings(tableData *optimization.TableData) []string {
Expand All @@ -47,24 +26,7 @@ func (s *Store) additionalEqualityStrings(tableData *optimization.TableData) []s
}

// Merge merges tableData by delegating to shared.Merge, passing the extra
// equality strings produced by s.additionalEqualityStrings.
//
// As written, the call is attempted up to maxRetries times. A further attempt
// is made only when the previous error satisfies IsAuthExpiredError, in which
// case the connection is re-established first. The error from the last attempt
// is returned.
//
// NOTE(review): this text comes from a commit diff view, so the removed retry
// loop and its replacement (the final return statement) both appear in the
// body; the replacement is unreachable here.
func (s *Store) Merge(tableData *optimization.TableData) error {
var err error
for i := 0; i < maxRetries; i++ {
// On every pass after the first, decide whether the previous error warrants another attempt.
if i > 0 {
// TODO: Remove
if IsAuthExpiredError(err) {
slog.Warn("Authentication has expired, will reload the Snowflake store and retry merging", slog.Any("err", err))
if connErr := s.reestablishConnection(); connErr != nil {
// TODO: Remove this panic and return an error instead. Ensure the callers of [Merge] handle this properly.
logger.Panic("Failed to reestablish connection", slog.Any("err", connErr))
}
} else {
// Previous attempt did not report an expired-auth error, so stop retrying.
// NOTE(review): presumably this is also the exit taken after a successful attempt (err == nil) - confirm against IsAuthExpiredError.
break
}
}

err = shared.Merge(s, tableData, types.MergeOpts{
AdditionalEqualityStrings: s.additionalEqualityStrings(tableData),
})
}
return err
// NOTE(review): unreachable as written - this is the single-call form without the retry loop.
return shared.Merge(s, tableData, types.MergeOpts{
AdditionalEqualityStrings: s.additionalEqualityStrings(tableData),
})
}

0 comments on commit 2440919

Please sign in to comment.