diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index 7389e4f7e..ed1f70a12 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -42,7 +42,9 @@ func LoadBigQuery(ctx context.Context, _store *db.Store) *Store { } } - if credPath := config.GetSettings().Config.BigQuery.PathToCredentials; credPath != "" { + settings := config.FromContext(ctx) + + if credPath := settings.Config.BigQuery.PathToCredentials; credPath != "" { // If the credPath is set, let's set it into the env var. logger.FromContext(ctx).Debug("writing the path to BQ credentials to env var for google auth") err := os.Setenv(GooglePathToCredentialsEnvKey, credPath) @@ -53,7 +55,7 @@ func LoadBigQuery(ctx context.Context, _store *db.Store) *Store { return &Store{ Store: db.Open(ctx, "bigquery", fmt.Sprintf("bigquery://%s/%s", - config.GetSettings().Config.BigQuery.ProjectID, config.GetSettings().Config.BigQuery.DefaultDataset)), + settings.Config.BigQuery.ProjectID, settings.Config.BigQuery.DefaultDataset)), configMap: &types.DwhToTablesConfigMap{}, } } diff --git a/clients/snowflake/ddl_test.go b/clients/snowflake/ddl_test.go index b4a4038a5..a6d6b3efc 100644 --- a/clients/snowflake/ddl_test.go +++ b/clients/snowflake/ddl_test.go @@ -1,7 +1,6 @@ package snowflake import ( - "context" "fmt" "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/dwh/types" @@ -108,11 +107,9 @@ func (s *SnowflakeTestSuite) TestManipulateShouldDeleteColumn() { func (s *SnowflakeTestSuite) TestGetTableConfig() { // If the table does not exist, snowflakeTableConfig should say so. fqName := "customers.public.orders22" - ctx := context.Background() - s.fakeStore.QueryReturns(nil, fmt.Errorf("Table '%s' does not exist or not authorized", fqName)) - tableConfig, err := s.store.getTableConfig(ctx, fqName, false) + tableConfig, err := s.store.getTableConfig(s.ctx, fqName, false) assert.NotNil(s.T(), tableConfig, "config is nil") assert.NoError(s.T(), err) diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go index b7d8cf24a..a38845db1 100644 --- a/clients/snowflake/snowflake.go +++ b/clients/snowflake/snowflake.go @@ -116,17 +116,19 @@ func (s *Store) ReestablishConnection(ctx context.Context) { return } + settings := config.FromContext(ctx) + cfg := &gosnowflake.Config{ - Account: config.GetSettings().Config.Snowflake.AccountID, - User: config.GetSettings().Config.Snowflake.Username, - Password: config.GetSettings().Config.Snowflake.Password, - Warehouse: config.GetSettings().Config.Snowflake.Warehouse, - Region: config.GetSettings().Config.Snowflake.Region, + Account: settings.Config.Snowflake.AccountID, + User: settings.Config.Snowflake.Username, + Password: settings.Config.Snowflake.Password, + Warehouse: settings.Config.Snowflake.Warehouse, + Region: settings.Config.Snowflake.Region, } - if config.GetSettings().Config.Snowflake.Host != "" { + if settings.Config.Snowflake.Host != "" { // If the host is specified - cfg.Host = config.GetSettings().Config.Snowflake.Host + cfg.Host = settings.Config.Snowflake.Host cfg.Region = "" } diff --git a/clients/snowflake/snowflake_suite_test.go b/clients/snowflake/snowflake_suite_test.go index 3f468a462..fba5f9dac 100644 --- a/clients/snowflake/snowflake_suite_test.go +++ b/clients/snowflake/snowflake_suite_test.go @@ -2,6 +2,7 @@ package snowflake import ( "context" + "github.com/artie-labs/transfer/lib/config" "testing" "github.com/stretchr/testify/suite" @@ -14,14 +15,18 @@ type SnowflakeTestSuite struct { suite.Suite fakeStore *mocks.FakeStore store *Store + ctx context.Context } func (s *SnowflakeTestSuite) SetupTest() { - ctx := context.Background() + s.ctx = config.InjectSettingsIntoContext(context.Background(), &config.Settings{ + VerboseLogging: false, + }) s.fakeStore = &mocks.FakeStore{} store := db.Store(s.fakeStore) - s.store = LoadSnowflake(ctx, &store) + s.store = LoadSnowflake(s.ctx, &store) + } func TestSnowflakeTestSuite(t *testing.T) { diff --git a/clients/snowflake/snowflake_test.go b/clients/snowflake/snowflake_test.go index 74eeaaa42..a9933b1db 100644 --- a/clients/snowflake/snowflake_test.go +++ b/clients/snowflake/snowflake_test.go @@ -1,7 +1,6 @@ package snowflake import ( - "context" "fmt" "time" @@ -53,7 +52,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() { constants.DeleteColumnMarker: typing.Boolean, }, nil, false, true)) - err := s.store.Merge(context.Background(), tableData) + err := s.store.Merge(s.ctx, tableData) assert.Equal(s.T(), tableData.InMemoryColumns["first_name"], typing.String) assert.NoError(s.T(), err) } @@ -95,11 +94,11 @@ func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() { types.NewDwhTableConfig(columns, nil, false, true)) s.fakeStore.ExecReturnsOnCall(0, nil, fmt.Errorf("390114: Authentication token has expired. The user must authenticate again.")) - err := s.store.Merge(context.Background(), tableData) + err := s.store.Merge(s.ctx, tableData) assert.True(s.T(), AuthenticationExpirationErr(err), err) s.fakeStore.ExecReturnsOnCall(1, nil, nil) - assert.Nil(s.T(), s.store.Merge(context.Background(), tableData)) + assert.Nil(s.T(), s.store.Merge(s.ctx, tableData)) s.fakeStore.ExecReturns(nil, nil) assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 2, "called merge") } @@ -139,7 +138,7 @@ func (s *SnowflakeTestSuite) TestExecuteMerge() { s.store.configMap.AddTableToConfig(topicConfig.ToFqName(constants.Snowflake), types.NewDwhTableConfig(columns, nil, false, true)) - err := s.store.Merge(context.Background(), tableData) + err := s.store.Merge(s.ctx, tableData) assert.Nil(s.T(), err) s.fakeStore.ExecReturns(nil, nil) assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 1, "called merge") @@ -193,7 +192,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() { config := types.NewDwhTableConfig(sflkColumns, nil, false, true) s.store.configMap.AddTableToConfig(topicConfig.ToFqName(constants.Snowflake), config) - err := s.store.Merge(context.Background(), tableData) + err := s.store.Merge(s.ctx, tableData) assert.Nil(s.T(), err) s.fakeStore.ExecReturns(nil, nil) assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 1, "called merge") @@ -215,7 +214,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() { break } - err = s.store.Merge(context.Background(), tableData) + err = s.store.Merge(s.ctx, tableData) assert.NoError(s.T(), err) s.fakeStore.ExecReturns(nil, nil) assert.Equal(s.T(), s.fakeStore.ExecCallCount(), 2, "called merge again") @@ -226,7 +225,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() { } func (s *SnowflakeTestSuite) TestExecuteMergeExitEarly() { - err := s.store.Merge(context.Background(), &optimization.TableData{ + err := s.store.Merge(s.ctx, &optimization.TableData{ InMemoryColumns: nil, RowsData: nil, PrimaryKey: "", diff --git a/lib/cdc/format/format_test.go b/lib/cdc/format/format_test.go index a23ef8ab3..7e8acdae5 100644 --- a/lib/cdc/format/format_test.go +++ b/lib/cdc/format/format_test.go @@ -2,6 +2,7 @@ package format import ( "context" + "github.com/artie-labs/transfer/lib/config" "os" "os/exec" "testing" @@ -13,6 +14,10 @@ import ( func TestGetFormatParser(t *testing.T) { ctx := context.Background() + ctx = config.InjectSettingsIntoContext(ctx, &config.Settings{ + VerboseLogging: true, + }) + validFormats := []string{constants.DBZPostgresAltFormat, constants.DBZPostgresFormat, constants.DBZMongoFormat} for _, validFormat := range validFormats { assert.NotNil(t, GetFormatParser(ctx, validFormat)) diff --git a/lib/config/context.go b/lib/config/context.go new file mode 100644 index 000000000..19f90e003 --- /dev/null +++ b/lib/config/context.go @@ -0,0 +1,67 @@ +package config + +import ( + "context" + "github.com/jessevdk/go-flags" + "log" +) + +const settingsKey = "_settings" + +type Settings struct { + Config *Config + VerboseLogging bool +} + +// InjectSettingsIntoContext is used for tests ONLY +func InjectSettingsIntoContext(ctx context.Context, settings *Settings) context.Context { + return context.WithValue(ctx, settingsKey, settings) +} + +func FromContext(ctx context.Context) *Settings { + settingsVal := ctx.Value(settingsKey) + if settingsVal == nil { + log.Fatalf("failed to grab settings from context") + } + + settings, isOk := settingsVal.(*Settings) + if !isOk { + log.Fatalf("settings in context is not of *config.Settings type") + } + + return settings +} + +// InitializeCfgIntoContext will take the flags and then parse +// loadConfig is optional for testing purposes. +func InitializeCfgIntoContext(ctx context.Context, args []string, loadConfig bool) context.Context { + var opts struct { + ConfigFilePath string `short:"c" long:"config" description:"path to the config file"` + Verbose bool `short:"v" long:"verbose" description:"debug logging" optional:"true"` + } + + _, err := flags.ParseArgs(&opts, args) + if err != nil { + log.Fatalf("failed to parse args, err: %v", err) + } + + var config *Config + if loadConfig { + config, err = readFileToConfig(opts.ConfigFilePath) + if err != nil { + log.Fatalf("failed to parse config file. Please check your config, err: %v", err) + } + + err = config.Validate() + if err != nil { + log.Fatalf("Failed to validate config, err: %v", err) + } + } + + settings := &Settings{ + Config: config, + VerboseLogging: opts.Verbose, + } + + return context.WithValue(ctx, settingsKey, settings) +} diff --git a/lib/config/context_test.go b/lib/config/context_test.go new file mode 100644 index 000000000..3d296f58d --- /dev/null +++ b/lib/config/context_test.go @@ -0,0 +1,19 @@ +package config + +import ( + "context" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseArgs(t *testing.T) { + ctx := InitializeCfgIntoContext(context.Background(), []string{}, false) + settings := FromContext(ctx) + + assert.Equal(t, settings.VerboseLogging, false) + assert.Nil(t, settings.Config) + + ctx = InitializeCfgIntoContext(context.Background(), []string{"-v"}, false) + settings = FromContext(ctx) + assert.Equal(t, settings.VerboseLogging, true) +} diff --git a/lib/config/flags.go b/lib/config/flags.go deleted file mode 100644 index fe1d6827c..000000000 --- a/lib/config/flags.go +++ /dev/null @@ -1,54 +0,0 @@ -package config - -import ( - "log" - - "github.com/jessevdk/go-flags" -) - -type Settings struct { - Config *Config - VerboseLogging bool -} - -var settings *Settings - -func GetSettings() *Settings { - if settings == nil { - log.Fatal("Settings is empty, we need to initialize.") - } - - return settings -} - -// ParseArgs will take the flags and then parse -// loadConfig is optional for testing purposes. -func ParseArgs(args []string, loadConfig bool) { - var opts struct { - ConfigFilePath string `short:"c" long:"config" description:"path to the config file"` - Verbose bool `short:"v" long:"verbose" description:"debug logging" optional:"true"` - } - - _, err := flags.ParseArgs(&opts, args) - if err != nil { - log.Fatalf("Failed to parse args, err: %v", err) - } - - var config *Config - if loadConfig { - config, err = readFileToConfig(opts.ConfigFilePath) - if err != nil { - log.Fatalf("Failed to parse config file. Please check your config, err: %v", err) - } - - err = config.Validate() - if err != nil { - log.Fatalf("Failed to validate config, err: %v", err) - } - } - - settings = &Settings{ - Config: config, - VerboseLogging: opts.Verbose, - } -} diff --git a/lib/config/flags_test.go b/lib/config/flags_test.go deleted file mode 100644 index d65c29fe7..000000000 --- a/lib/config/flags_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package config - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestParseArgs(t *testing.T) { - ParseArgs([]string{}, false) - settings := GetSettings() - - assert.Equal(t, settings.VerboseLogging, false) - assert.Nil(t, settings.Config) - - ParseArgs([]string{"-v"}, false) - assert.Equal(t, GetSettings().VerboseLogging, true) -} diff --git a/lib/dwh/utils/load.go b/lib/dwh/utils/load.go index c0a1bf0ef..926e2e5d1 100644 --- a/lib/dwh/utils/load.go +++ b/lib/dwh/utils/load.go @@ -12,8 +12,10 @@ import ( "github.com/artie-labs/transfer/lib/mocks" ) -func DataWarehouse(ctx context.Context) dwh.DataWarehouse { - switch config.GetSettings().Config.Output { +func DataWarehouse(ctx context.Context, store *db.Store) dwh.DataWarehouse { + settings := config.FromContext(ctx) + + switch settings.Config.Output { case "test": // TODO - In the future, we can create a fake store that follows the MERGE syntax for SQL standard. // Also, the fake library not only needs to support MERGE, but needs to be able to make it easy for us to return @@ -25,13 +27,13 @@ func DataWarehouse(ctx context.Context) dwh.DataWarehouse { }) return snowflake.LoadSnowflake(ctx, &store) case "snowflake": - return snowflake.LoadSnowflake(ctx, nil) + return snowflake.LoadSnowflake(ctx, store) case "bigquery": - return bigquery.LoadBigQuery(ctx, nil) + return bigquery.LoadBigQuery(ctx, store) } logger.FromContext(ctx).WithFields(map[string]interface{}{ - "source": config.GetSettings().Config.Output, + "source": settings.Config.Output, }).Fatal("No valid output sources specified.") return nil diff --git a/lib/logger/context.go b/lib/logger/context.go index bf5be4c67..9a1d59e00 100644 --- a/lib/logger/context.go +++ b/lib/logger/context.go @@ -2,25 +2,26 @@ package logger import ( "context" + "github.com/artie-labs/transfer/lib/config" "github.com/sirupsen/logrus" ) const loggerKey = "_log" -func InjectLoggerIntoCtx(logger *logrus.Logger, ctx context.Context) context.Context { - return context.WithValue(ctx, loggerKey, logger) +func InjectLoggerIntoCtx(ctx context.Context) context.Context { + return context.WithValue(ctx, loggerKey, new(config.FromContext(ctx))) } func FromContext(ctx context.Context) *logrus.Logger { logVal := ctx.Value(loggerKey) if logVal == nil { // Inject this back into context, so we don't need to initialize this again - return FromContext(InjectLoggerIntoCtx(NewLogger(nil), ctx)) + return FromContext(InjectLoggerIntoCtx(ctx)) } log, isOk := logVal.(*logrus.Logger) if !isOk { - return FromContext(InjectLoggerIntoCtx(NewLogger(nil), ctx)) + return FromContext(InjectLoggerIntoCtx(ctx)) } return log diff --git a/lib/logger/context_test.go b/lib/logger/context_test.go index bc828593b..e1e524663 100644 --- a/lib/logger/context_test.go +++ b/lib/logger/context_test.go @@ -2,39 +2,20 @@ package logger import ( "context" + "github.com/artie-labs/transfer/lib/config" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "testing" ) func TestLogger(t *testing.T) { - ctx := context.Background() + ctx := config.InjectSettingsIntoContext(context.Background(), &config.Settings{ + VerboseLogging: true, + }) + log := &logrus.Logger{ Level: logrus.DebugLevel, } - assert.Equal(t, log.Level, FromContext(InjectLoggerIntoCtx(log, ctx)).Level) -} - -func TestLoggerNil(t *testing.T) { - assert.NotNil(t, FromContext(context.Background())) -} - -func TestLoggerWrongType(t *testing.T) { - ctx := context.WithValue(context.Background(), loggerKey, "foo") - assert.NotNil(t, FromContext(ctx)) -} - -func TestLoggerSubsequent(t *testing.T) { - // Start with nil - // Get the logger, then update the level - // Fetch the logger again and ensure level is the same. - ctx := context.Background() - log := FromContext(ctx) - assert.NotNil(t, log) - - log.Level = logrus.DebugLevel - ctx = InjectLoggerIntoCtx(log, ctx) - - assert.Equal(t, FromContext(ctx).Level, log.Level) + assert.Equal(t, log.Level, FromContext(InjectLoggerIntoCtx(ctx)).Level) } diff --git a/lib/logger/log.go b/lib/logger/log.go index c032e544e..8273b566c 100644 --- a/lib/logger/log.go +++ b/lib/logger/log.go @@ -9,7 +9,7 @@ import ( "github.com/artie-labs/transfer/lib/config" ) -func NewLogger(settings *config.Settings) *logrus.Logger { +func new(settings *config.Settings) *logrus.Logger { log := logrus.New() log.SetOutput(os.Stdout) diff --git a/lib/telemetry/metrics/datadog_test.go b/lib/telemetry/metrics/datadog_test.go index 382ca44c1..6137b62fa 100644 --- a/lib/telemetry/metrics/datadog_test.go +++ b/lib/telemetry/metrics/datadog_test.go @@ -1,31 +1,28 @@ package metrics import ( - "context" "github.com/stretchr/testify/assert" - "testing" ) -func TestGetSampleRate(t *testing.T) { - assert.Equal(t, getSampleRate("foo"), float64(DefaultSampleRate)) - assert.Equal(t, getSampleRate(1.25), float64(DefaultSampleRate)) - assert.Equal(t, getSampleRate(1), float64(1)) - assert.Equal(t, getSampleRate(0.33), 0.33) - assert.Equal(t, getSampleRate(0), float64(DefaultSampleRate)) - assert.Equal(t, getSampleRate(-0.55), float64(DefaultSampleRate)) +func (m *MetricsTestSuite) TestGetSampleRate() { + assert.Equal(m.T(), getSampleRate("foo"), float64(DefaultSampleRate)) + assert.Equal(m.T(), getSampleRate(1.25), float64(DefaultSampleRate)) + assert.Equal(m.T(), getSampleRate(1), float64(1)) + assert.Equal(m.T(), getSampleRate(0.33), 0.33) + assert.Equal(m.T(), getSampleRate(0), float64(DefaultSampleRate)) + assert.Equal(m.T(), getSampleRate(-0.55), float64(DefaultSampleRate)) } -func TestGetTags(t *testing.T) { - assert.Equal(t, getTags(nil), []string{}) - assert.Equal(t, getTags([]string{}), []string{}) - assert.Equal(t, getTags([]interface{}{"env:bar", "a:b"}), []string{"env:bar", "a:b"}) +func (m *MetricsTestSuite) TestGetTags() { + assert.Equal(m.T(), getTags(nil), []string{}) + assert.Equal(m.T(), getTags([]string{}), []string{}) + assert.Equal(m.T(), getTags([]interface{}{"env:bar", "a:b"}), []string{"env:bar", "a:b"}) } -func TestNewDatadogClient(t *testing.T) { - ctx := context.Background() +func (m *MetricsTestSuite) TestNewDatadogClient() { var err error - ctx, err = NewDatadogClient(ctx, map[string]interface{}{ + m.ctx, err = NewDatadogClient(m.ctx, map[string]interface{}{ Tags: []string{ "env:production", }, @@ -34,10 +31,10 @@ func TestNewDatadogClient(t *testing.T) { // Cannot test datadogAddr (addr is private) }) - assert.NoError(t, err, err) - mtr := FromContext(ctx).(*statsClient) + assert.NoError(m.T(), err, err) + mtr := FromContext(m.ctx).(*statsClient) - assert.Equal(t, mtr.rate, 0.255, mtr.rate) - assert.Equal(t, mtr.client.Namespace, "dusty.", mtr.client.Namespace) - assert.Equal(t, mtr.client.Tags, []string{"env:production"}, mtr.client.Tags) + assert.Equal(m.T(), mtr.rate, 0.255, mtr.rate) + assert.Equal(m.T(), mtr.client.Namespace, "dusty.", mtr.client.Namespace) + assert.Equal(m.T(), mtr.client.Tags, []string{"env:production"}, mtr.client.Tags) } diff --git a/lib/telemetry/metrics/metrics_suite_test.go b/lib/telemetry/metrics/metrics_suite_test.go new file mode 100644 index 000000000..43e0ceded --- /dev/null +++ b/lib/telemetry/metrics/metrics_suite_test.go @@ -0,0 +1,24 @@ +package metrics + +import ( + "context" + "github.com/artie-labs/transfer/lib/config" + "github.com/stretchr/testify/suite" + "testing" +) + +type MetricsTestSuite struct { + suite.Suite + ctx context.Context +} + +func (m *MetricsTestSuite) SetupTest() { + m.ctx = config.InjectSettingsIntoContext(context.Background(), &config.Settings{ + Config: &config.Config{}, + VerboseLogging: false, + }) +} + +func TestMetricsTestSuite(t *testing.T) { + suite.Run(t, new(MetricsTestSuite)) +} diff --git a/lib/telemetry/metrics/stats.go b/lib/telemetry/metrics/stats.go index c8e6f9674..c59bf9702 100644 --- a/lib/telemetry/metrics/stats.go +++ b/lib/telemetry/metrics/stats.go @@ -2,6 +2,7 @@ package metrics import ( "context" + "github.com/artie-labs/transfer/lib/config" "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/logger" @@ -21,7 +22,11 @@ func exporterKindValid(kind constants.ExporterKind) bool { return valid } -func LoadExporter(ctx context.Context, kind constants.ExporterKind, settings map[string]interface{}) context.Context { +func LoadExporter(ctx context.Context) context.Context { + settings := config.FromContext(ctx) + kind := settings.Config.Telemetry.Metrics.Provider + ddSettings := settings.Config.Telemetry.Metrics.Settings + if !exporterKindValid(kind) { logger.FromContext(ctx).WithFields(map[string]interface{}{ "exporterKind": kind, @@ -31,7 +36,7 @@ func LoadExporter(ctx context.Context, kind constants.ExporterKind, settings map switch kind { case constants.Datadog: var exportErr error - ctx, exportErr = NewDatadogClient(ctx, settings) + ctx, exportErr = NewDatadogClient(ctx, ddSettings) if exportErr != nil { logger.FromContext(ctx).WithField("provider", kind).Error(exportErr) } else { diff --git a/lib/telemetry/metrics/stats_test.go b/lib/telemetry/metrics/stats_test.go index fa9013381..8506c98cd 100644 --- a/lib/telemetry/metrics/stats_test.go +++ b/lib/telemetry/metrics/stats_test.go @@ -3,13 +3,12 @@ package metrics import ( "context" "fmt" + "github.com/artie-labs/transfer/lib/config" "github.com/artie-labs/transfer/lib/config/constants" - "testing" - "github.com/stretchr/testify/assert" ) -func TestExporterKindValid(t *testing.T) { +func (m *MetricsTestSuite) TestExporterKindValid() { exporterKindToResultsMap := map[constants.ExporterKind]bool{ constants.Datadog: true, constants.ExporterKind("daaaa"): false, @@ -18,12 +17,12 @@ func TestExporterKindValid(t *testing.T) { } for exporterKind, expectedResults := range exporterKindToResultsMap { - assert.Equal(t, expectedResults, exporterKindValid(exporterKind), + assert.Equal(m.T(), expectedResults, exporterKindValid(exporterKind), fmt.Sprintf("kind: %v should have been %v", exporterKind, expectedResults)) } } -func TestLoadExporter(t *testing.T) { +func (m *MetricsTestSuite) TestLoadExporter() { // Datadog should not be a NullMetricsProvider exporterKindToResultMap := map[constants.ExporterKind]bool{ constants.Datadog: false, @@ -32,12 +31,30 @@ func TestLoadExporter(t *testing.T) { for kind, result := range exporterKindToResultMap { // Wipe and create a new ctx per run - ctx := context.Background() - ctx = LoadExporter(ctx, kind, map[string]interface{}{ - "url": "localhost:8125", + m.ctx = context.Background() + m.ctx = config.InjectSettingsIntoContext(m.ctx, &config.Settings{ + Config: &config.Config{ + Telemetry: struct { + Metrics struct { + Provider constants.ExporterKind `yaml:"provider"` + Settings map[string]interface{} `yaml:"settings,omitempty"` + } + }{ + Metrics: struct { + Provider constants.ExporterKind `yaml:"provider"` + Settings map[string]interface{} `yaml:"settings,omitempty"` + }{ + Provider: kind, + Settings: map[string]interface{}{ + "url": "localhost:8125", + }, + }, + }, + }, }) - _, isOk := FromContext(ctx).(NullMetricsProvider) - assert.Equal(t, result, isOk) + m.ctx = LoadExporter(m.ctx) + _, isOk := FromContext(m.ctx).(NullMetricsProvider) + assert.Equal(m.T(), result, isOk) } } diff --git a/main.go b/main.go index 2849d5d1c..1b4d216d3 100644 --- a/main.go +++ b/main.go @@ -17,13 +17,12 @@ import ( func main() { // Parse args into settings. - config.ParseArgs(os.Args, true) - ctx := logger.InjectLoggerIntoCtx(logger.NewLogger(config.GetSettings()), context.Background()) + ctx := config.InitializeCfgIntoContext(context.Background(), os.Args, true) + ctx = logger.InjectLoggerIntoCtx(ctx) // Loading Telemetry - ctx = metrics.LoadExporter(ctx, config.GetSettings().Config.Telemetry.Metrics.Provider, - config.GetSettings().Config.Telemetry.Metrics.Settings) - ctx = utils.InjectDwhIntoCtx(utils.DataWarehouse(ctx), ctx) + ctx = metrics.LoadExporter(ctx) + ctx = utils.InjectDwhIntoCtx(utils.DataWarehouse(ctx, nil), ctx) models.LoadMemoryDB() @@ -35,11 +34,12 @@ func main() { pool.StartPool(ctx, constants.FlushTimeInterval, flushChan) }() + settings := config.FromContext(ctx) wg.Add(1) go func(ctx context.Context) { defer wg.Done() - switch config.GetSettings().Config.Queue { + switch settings.Config.Queue { case constants.Kafka: consumer.StartConsumer(ctx, flushChan) break @@ -47,7 +47,7 @@ func main() { consumer.StartSubscriber(ctx, flushChan) break default: - logger.FromContext(ctx).Fatalf("message queue: %s not supported", config.GetSettings().Config.Queue) + logger.FromContext(ctx).Fatalf("message queue: %s not supported", settings.Config.Queue) } }(ctx) diff --git a/models/flush/flush_suite_test.go b/models/flush/flush_suite_test.go index adbf31bf4..fbd861604 100644 --- a/models/flush/flush_suite_test.go +++ b/models/flush/flush_suite_test.go @@ -2,7 +2,7 @@ package flush import ( "context" - "github.com/artie-labs/transfer/clients/snowflake" + "github.com/artie-labs/transfer/lib/config" "github.com/artie-labs/transfer/lib/db" "github.com/artie-labs/transfer/lib/dwh/utils" "github.com/artie-labs/transfer/lib/kafkalib" @@ -25,11 +25,16 @@ func (f *FlushTestSuite) SetupTest() { f.fakeStore = &mocks.FakeStore{} store := db.Store(f.fakeStore) - ctx := context.Background() + f.ctx = context.Background() - // Not using LoadDataWarehouse here because config.GetSettings() is not initialized in this test - // TODO: Address ^ - f.ctx = utils.InjectDwhIntoCtx(snowflake.LoadSnowflake(ctx, &store), ctx) + f.ctx = config.InjectSettingsIntoContext(f.ctx, &config.Settings{ + Config: &config.Config{ + Output: "snowflake", + }, + VerboseLogging: false, + }) + + f.ctx = utils.InjectDwhIntoCtx(utils.DataWarehouse(f.ctx, &store), f.ctx) models.LoadMemoryDB() diff --git a/processes/consumer/kafka.go b/processes/consumer/kafka.go index c8cdf1cdb..d943a7946 100644 --- a/processes/consumer/kafka.go +++ b/processes/consumer/kafka.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "github.com/artie-labs/transfer/lib/artie" awsCfg "github.com/aws/aws-sdk-go-v2/config" - "github.com/segmentio/kafka-go/sasl" "github.com/segmentio/kafka-go/sasl/aws_msk_iam_v2" "github.com/segmentio/kafka-go/sasl/plain" "sync" @@ -27,43 +26,41 @@ func SetKafkaConsumer(_topicToConsumer map[string]kafkalib.Consumer) { func StartConsumer(ctx context.Context, flushChan chan bool) { log := logger.FromContext(ctx) - log.Info("Starting Kafka consumer...", config.GetSettings().Config.Kafka) + settings := config.FromContext(ctx) + log.Info("Starting Kafka consumer...", settings.Config.Kafka) dialer := &kafka.Dialer{ Timeout: 10 * time.Second, DualStack: true, } - var mech sasl.Mechanism - // If using AWS MSK IAM, we expect this to be set in the ENV VAR // (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, or the AWS Profile should be called default.) - if config.GetSettings().Config.Kafka.EnableAWSMSKIAM { + if settings.Config.Kafka.EnableAWSMSKIAM { cfg, err := awsCfg.LoadDefaultConfig(ctx) if err != nil { log.WithError(err).Fatal("failed to load aws configuration") } - mech = aws_msk_iam_v2.NewMechanism(cfg) - + dialer.SASLMechanism = aws_msk_iam_v2.NewMechanism(cfg) + dialer.TLS = &tls.Config{} } // If username or password is set, then let's enable PLAIN. // By default, we will support no auth (local testing) and PLAIN SASL. - if config.GetSettings().Config.Kafka.Username != "" { - mech = plain.Mechanism{ - Username: config.GetSettings().Config.Kafka.Username, - Password: config.GetSettings().Config.Kafka.Password, + if settings.Config.Kafka.Username != "" { + dialer.SASLMechanism = plain.Mechanism{ + Username: settings.Config.Kafka.Username, + Password: settings.Config.Kafka.Password, } + dialer.TLS = &tls.Config{} } - dialer.SASLMechanism = mech - dialer.TLS = &tls.Config{} topicToConfigFmtMap := make(map[string]TopicConfigFormatter) topicToConsumer = make(map[string]kafkalib.Consumer) var topics []string - for _, topicConfig := range config.GetSettings().Config.Kafka.TopicConfigs { + for _, topicConfig := range settings.Config.Kafka.TopicConfigs { topicToConfigFmtMap[topicConfig.Topic] = TopicConfigFormatter{ tc: topicConfig, Format: format.GetFormatParser(ctx, topicConfig.CDCFormat), @@ -77,8 +74,8 @@ func StartConsumer(ctx context.Context, flushChan chan bool) { go func(topic string) { defer wg.Done() kafkaConsumer := kafka.NewReader(kafka.ReaderConfig{ - Brokers: []string{config.GetSettings().Config.Kafka.BootstrapServer}, - GroupID: config.GetSettings().Config.Kafka.GroupID, + Brokers: []string{settings.Config.Kafka.BootstrapServer}, + GroupID: settings.Config.Kafka.GroupID, Dialer: dialer, Topic: topic, }) diff --git a/processes/consumer/pubsub.go b/processes/consumer/pubsub.go index 7d4013317..c2b01bffc 100644 --- a/processes/consumer/pubsub.go +++ b/processes/consumer/pubsub.go @@ -49,15 +49,16 @@ func findOrCreateSubscription(ctx context.Context, client *gcp_pubsub.Client, to func StartSubscriber(ctx context.Context, flushChan chan bool) { log := logger.FromContext(ctx) - client, clientErr := gcp_pubsub.NewClient(ctx, config.GetSettings().Config.Pubsub.ProjectID, - option.WithCredentialsFile(config.GetSettings().Config.Pubsub.PathToCredentials)) + settings := config.FromContext(ctx) + client, clientErr := gcp_pubsub.NewClient(ctx, settings.Config.Pubsub.ProjectID, + option.WithCredentialsFile(settings.Config.Pubsub.PathToCredentials)) if clientErr != nil { log.Fatalf("failed to create a pubsub client, err: %v", clientErr) } topicToConfigFmtMap := make(map[string]TopicConfigFormatter) var topics []string - for _, topicConfig := range config.GetSettings().Config.Pubsub.TopicConfigs { + for _, topicConfig := range settings.Config.Pubsub.TopicConfigs { topicToConfigFmtMap[topicConfig.Topic] = TopicConfigFormatter{ tc: topicConfig, Format: format.GetFormatParser(ctx, topicConfig.CDCFormat), @@ -66,7 +67,7 @@ func StartSubscriber(ctx context.Context, flushChan chan bool) { } var wg sync.WaitGroup - for _, topicConfig := range config.GetSettings().Config.Pubsub.TopicConfigs { + for _, topicConfig := range settings.Config.Pubsub.TopicConfigs { wg.Add(1) go func(ctx context.Context, client *gcp_pubsub.Client, topic string) { defer wg.Done()