Skip to content

Commit

Permalink
Thread context.Context for config (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Apr 11, 2023
1 parent 54478c9 commit 953a9ab
Show file tree
Hide file tree
Showing 22 changed files with 249 additions and 194 deletions.
6 changes: 4 additions & 2 deletions clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ func LoadBigQuery(ctx context.Context, _store *db.Store) *Store {
}
}

if credPath := config.GetSettings().Config.BigQuery.PathToCredentials; credPath != "" {
settings := config.FromContext(ctx)

if credPath := settings.Config.BigQuery.PathToCredentials; credPath != "" {
// If the credPath is set, let's set it into the env var.
logger.FromContext(ctx).Debug("writing the path to BQ credentials to env var for google auth")
err := os.Setenv(GooglePathToCredentialsEnvKey, credPath)
Expand All @@ -53,7 +55,7 @@ func LoadBigQuery(ctx context.Context, _store *db.Store) *Store {

return &Store{
Store: db.Open(ctx, "bigquery", fmt.Sprintf("bigquery://%s/%s",
config.GetSettings().Config.BigQuery.ProjectID, config.GetSettings().Config.BigQuery.DefaultDataset)),
settings.Config.BigQuery.ProjectID, settings.Config.BigQuery.DefaultDataset)),
configMap: &types.DwhToTablesConfigMap{},
}
}
5 changes: 1 addition & 4 deletions clients/snowflake/ddl_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package snowflake

import (
"context"
"fmt"
"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/dwh/types"
Expand Down Expand Up @@ -108,11 +107,9 @@ func (s *SnowflakeTestSuite) TestManipulateShouldDeleteColumn() {
func (s *SnowflakeTestSuite) TestGetTableConfig() {
// If the table does not exist, snowflakeTableConfig should say so.
fqName := "customers.public.orders22"
ctx := context.Background()

s.fakeStore.QueryReturns(nil, fmt.Errorf("Table '%s' does not exist or not authorized", fqName))

tableConfig, err := s.store.getTableConfig(ctx, fqName, false)
tableConfig, err := s.store.getTableConfig(s.ctx, fqName, false)
assert.NotNil(s.T(), tableConfig, "config is nil")
assert.NoError(s.T(), err)

Expand Down
16 changes: 9 additions & 7 deletions clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,19 @@ func (s *Store) ReestablishConnection(ctx context.Context) {
return
}

settings := config.FromContext(ctx)

cfg := &gosnowflake.Config{
Account: config.GetSettings().Config.Snowflake.AccountID,
User: config.GetSettings().Config.Snowflake.Username,
Password: config.GetSettings().Config.Snowflake.Password,
Warehouse: config.GetSettings().Config.Snowflake.Warehouse,
Region: config.GetSettings().Config.Snowflake.Region,
Account: settings.Config.Snowflake.AccountID,
User: settings.Config.Snowflake.Username,
Password: settings.Config.Snowflake.Password,
Warehouse: settings.Config.Snowflake.Warehouse,
Region: settings.Config.Snowflake.Region,
}

if config.GetSettings().Config.Snowflake.Host != "" {
if settings.Config.Snowflake.Host != "" {
// If the host is specified
cfg.Host = config.GetSettings().Config.Snowflake.Host
cfg.Host = settings.Config.Snowflake.Host
cfg.Region = ""
}

Expand Down
9 changes: 7 additions & 2 deletions clients/snowflake/snowflake_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package snowflake

import (
"context"
"github.com/artie-labs/transfer/lib/config"
"testing"

"github.com/stretchr/testify/suite"
Expand All @@ -14,14 +15,18 @@ type SnowflakeTestSuite struct {
suite.Suite
fakeStore *mocks.FakeStore
store *Store
ctx context.Context
}

func (s *SnowflakeTestSuite) SetupTest() {
ctx := context.Background()
s.ctx = config.InjectSettingsIntoContext(context.Background(), &config.Settings{
VerboseLogging: false,
})

s.fakeStore = &mocks.FakeStore{}
store := db.Store(s.fakeStore)
s.store = LoadSnowflake(ctx, &store)
s.store = LoadSnowflake(s.ctx, &store)

}

func TestSnowflakeTestSuite(t *testing.T) {
Expand Down
15 changes: 7 additions & 8 deletions clients/snowflake/snowflake_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package snowflake

import (
"context"
"fmt"
"time"

Expand Down Expand Up @@ -53,7 +52,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() {
constants.DeleteColumnMarker: typing.Boolean,
}, nil, false, true))

err := s.store.Merge(context.Background(), tableData)
err := s.store.Merge(s.ctx, tableData)
assert.Equal(s.T(), tableData.InMemoryColumns["first_name"], typing.String)
assert.NoError(s.T(), err)
}
Expand Down Expand Up @@ -95,11 +94,11 @@ func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() {
types.NewDwhTableConfig(columns, nil, false, true))

s.fakeStore.ExecReturnsOnCall(0, nil, fmt.Errorf("390114: Authentication token has expired. The user must authenticate again."))
err := s.store.Merge(context.Background(), tableData)
err := s.store.Merge(s.ctx, tableData)
assert.True(s.T(), AuthenticationExpirationErr(err), err)

s.fakeStore.ExecReturnsOnCall(1, nil, nil)
assert.Nil(s.T(), s.store.Merge(context.Background(), tableData))
assert.Nil(s.T(), s.store.Merge(s.ctx, tableData))
s.fakeStore.ExecReturns(nil, nil)
assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 2, "called merge")
}
Expand Down Expand Up @@ -139,7 +138,7 @@ func (s *SnowflakeTestSuite) TestExecuteMerge() {

s.store.configMap.AddTableToConfig(topicConfig.ToFqName(constants.Snowflake),
types.NewDwhTableConfig(columns, nil, false, true))
err := s.store.Merge(context.Background(), tableData)
err := s.store.Merge(s.ctx, tableData)
assert.Nil(s.T(), err)
s.fakeStore.ExecReturns(nil, nil)
assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 1, "called merge")
Expand Down Expand Up @@ -193,7 +192,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() {
config := types.NewDwhTableConfig(sflkColumns, nil, false, true)
s.store.configMap.AddTableToConfig(topicConfig.ToFqName(constants.Snowflake), config)

err := s.store.Merge(context.Background(), tableData)
err := s.store.Merge(s.ctx, tableData)
assert.Nil(s.T(), err)
s.fakeStore.ExecReturns(nil, nil)
assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 1, "called merge")
Expand All @@ -215,7 +214,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() {
break
}

err = s.store.Merge(context.Background(), tableData)
err = s.store.Merge(s.ctx, tableData)
assert.NoError(s.T(), err)
s.fakeStore.ExecReturns(nil, nil)
assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 2, "called merge again")
Expand All @@ -226,7 +225,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() {
}

func (s *SnowflakeTestSuite) TestExecuteMergeExitEarly() {
err := s.store.Merge(context.Background(), &optimization.TableData{
err := s.store.Merge(s.ctx, &optimization.TableData{
InMemoryColumns: nil,
RowsData: nil,
PrimaryKey: "",
Expand Down
5 changes: 5 additions & 0 deletions lib/cdc/format/format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package format

import (
"context"
"github.com/artie-labs/transfer/lib/config"
"os"
"os/exec"
"testing"
Expand All @@ -13,6 +14,10 @@ import (

func TestGetFormatParser(t *testing.T) {
ctx := context.Background()
ctx = config.InjectSettingsIntoContext(ctx, &config.Settings{
VerboseLogging: true,
})

validFormats := []string{constants.DBZPostgresAltFormat, constants.DBZPostgresFormat, constants.DBZMongoFormat}
for _, validFormat := range validFormats {
assert.NotNil(t, GetFormatParser(ctx, validFormat))
Expand Down
67 changes: 67 additions & 0 deletions lib/config/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package config

import (
"context"
"github.com/jessevdk/go-flags"
"log"
)

const settingsKey = "_settings"

type Settings struct {
Config *Config
VerboseLogging bool
}

// InjectSettingsIntoContext is used for tests ONLY
func InjectSettingsIntoContext(ctx context.Context, settings *Settings) context.Context {
return context.WithValue(ctx, settingsKey, settings)
}

func FromContext(ctx context.Context) *Settings {
settingsVal := ctx.Value(settingsKey)
if settingsVal == nil {
log.Fatalf("failed to grab settings from context")
}

settings, isOk := settingsVal.(*Settings)
if !isOk {
log.Fatalf("settings in context is not of *config.Settings type")
}

return settings
}

// InitializeCfgIntoContext will take the flags and then parse
// loadConfig is optional for testing purposes.
func InitializeCfgIntoContext(ctx context.Context, args []string, loadConfig bool) context.Context {
var opts struct {
ConfigFilePath string `short:"c" long:"config" description:"path to the config file"`
Verbose bool `short:"v" long:"verbose" description:"debug logging" optional:"true"`
}

_, err := flags.ParseArgs(&opts, args)
if err != nil {
log.Fatalf("failed to parse args, err: %v", err)
}

var config *Config
if loadConfig {
config, err = readFileToConfig(opts.ConfigFilePath)
if err != nil {
log.Fatalf("failed to parse config file. Please check your config, err: %v", err)
}

err = config.Validate()
if err != nil {
log.Fatalf("Failed to validate config, err: %v", err)
}
}

settings := &Settings{
Config: config,
VerboseLogging: opts.Verbose,
}

return context.WithValue(ctx, settingsKey, settings)
}
19 changes: 19 additions & 0 deletions lib/config/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package config

import (
"context"
"github.com/stretchr/testify/assert"
"testing"
)

func TestParseArgs(t *testing.T) {
ctx := InitializeCfgIntoContext(context.Background(), []string{}, false)
settings := FromContext(ctx)

assert.Equal(t, settings.VerboseLogging, false)
assert.Nil(t, settings.Config)

ctx = InitializeCfgIntoContext(context.Background(), []string{"-v"}, false)
settings = FromContext(ctx)
assert.Equal(t, settings.VerboseLogging, true)
}
54 changes: 0 additions & 54 deletions lib/config/flags.go

This file was deleted.

17 changes: 0 additions & 17 deletions lib/config/flags_test.go

This file was deleted.

12 changes: 7 additions & 5 deletions lib/dwh/utils/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ import (
"github.com/artie-labs/transfer/lib/mocks"
)

func DataWarehouse(ctx context.Context) dwh.DataWarehouse {
switch config.GetSettings().Config.Output {
func DataWarehouse(ctx context.Context, store *db.Store) dwh.DataWarehouse {
settings := config.FromContext(ctx)

switch settings.Config.Output {
case "test":
// TODO - In the future, we can create a fake store that follows the MERGE syntax for SQL standard.
// Also, the fake library not only needs to support MERGE, but needs to be able to make it easy for us to return
Expand All @@ -25,13 +27,13 @@ func DataWarehouse(ctx context.Context) dwh.DataWarehouse {
})
return snowflake.LoadSnowflake(ctx, &store)
case "snowflake":
return snowflake.LoadSnowflake(ctx, nil)
return snowflake.LoadSnowflake(ctx, store)
case "bigquery":
return bigquery.LoadBigQuery(ctx, nil)
return bigquery.LoadBigQuery(ctx, store)
}

logger.FromContext(ctx).WithFields(map[string]interface{}{
"source": config.GetSettings().Config.Output,
"source": settings.Config.Output,
}).Fatal("No valid output sources specified.")

return nil
Expand Down
Loading

0 comments on commit 953a9ab

Please sign in to comment.