diff --git a/component/amqp/message_test.go b/component/amqp/message_test.go index e69a63c4b..cf2cedef9 100644 --- a/component/amqp/message_test.go +++ b/component/amqp/message_test.go @@ -27,7 +27,9 @@ var ( ) func TestMain(m *testing.M) { - os.Setenv("OTEL_BSP_SCHEDULE_DELAY", "100") + if err := os.Setenv("OTEL_BSP_SCHEDULE_DELAY", "100"); err != nil { + panic(err) + } tracePublisher = patrontrace.Setup("test", nil, traceExporter) diff --git a/component/http/observability.go b/component/http/observability.go index 1d30c05cc..16da648ee 100644 --- a/component/http/observability.go +++ b/component/http/observability.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/http/pprof" + + "github.com/beatlabs/patron/observability/log" ) func ProfilingRoutes(enableExpVar bool) []*Route { @@ -47,3 +49,24 @@ func expVars(w http.ResponseWriter, _ *http.Request) { }) _, _ = fmt.Fprintf(w, "\n}\n") } + +// LoggingRoutes returns a routes relates to logs. +func LoggingRoutes() []*Route { + handler := func(w http.ResponseWriter, r *http.Request) { + lvl := r.PathValue("level") + if lvl == "" { + http.Error(w, "missing log level", http.StatusBadRequest) + return + } + + err := log.SetLevel(lvl) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + } + + route, _ := NewRoute("POST /debug/log/{level}", handler) + return []*Route{route} +} diff --git a/component/http/observability_test.go b/component/http/observability_test.go index 0fd297403..c4daa09f0 100644 --- a/component/http/observability_test.go +++ b/component/http/observability_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "testing" + "github.com/beatlabs/patron/observability/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,7 +19,7 @@ type profilingTestCase struct { func TestProfilingRoutes(t *testing.T) { t.Run("without vars", func(t *testing.T) { - server := createProfilingServer(false) + server := createServer(false) defer server.Close() for name, tt := range createProfilingTestCases(false) { @@ -34,7 +35,7 @@ func TestProfilingRoutes(t *testing.T) { }) t.Run("with vars", func(t *testing.T) { - server := createProfilingServer(true) + server := createServer(true) defer server.Close() for name, tt := range createProfilingTestCases(true) { @@ -50,11 +51,14 @@ func TestProfilingRoutes(t *testing.T) { }) } -func createProfilingServer(enableExpVar bool) *httptest.Server { +func createServer(enableExpVar bool) *httptest.Server { mux := http.NewServeMux() for _, route := range ProfilingRoutes(enableExpVar) { mux.HandleFunc(route.path, route.handler) } + for _, route := range LoggingRoutes() { + mux.HandleFunc(route.path, route.handler) + } return httptest.NewServer(mux) } @@ -80,3 +84,39 @@ func createProfilingTestCases(enableExpVar bool) map[string]profilingTestCase { "vars": {"/debug/vars/", expVarWant}, } } + +func TestLoggingRoutes(t *testing.T) { + require.NoError(t, log.Setup(&log.Config{ + IsJSON: true, + Level: "info", + })) + server := createServer(true) + defer server.Close() + + t.Run("change log level to debug", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL+"/debug/log/debug", nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + }) + + t.Run("wrong log level", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL+"/debug/log/xxx", nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + }) + + t.Run("empty log level", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL+"/debug/log/", nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + }) +} diff --git a/component/kafka/component_test.go b/component/kafka/component_test.go index 00af1330f..a5cf9cdca 100644 --- a/component/kafka/component_test.go +++ b/component/kafka/component_test.go @@ -26,7 +26,9 @@ var ( ) func TestMain(m *testing.M) { - os.Setenv("OTEL_BSP_SCHEDULE_DELAY", "100") + if err := os.Setenv("OTEL_BSP_SCHEDULE_DELAY", "100"); err != nil { + panic(err) + } tracePublisher = patrontrace.Setup("test", nil, traceExporter) os.Exit(m.Run()) @@ -39,7 +41,7 @@ func TestNew(t *testing.T) { // consumer will commit every batch in a blocking operation saramaCfg.Consumer.Offsets.AutoCommit.Enable = false saramaCfg.Consumer.Offsets.Initial = sarama.OffsetOldest - saramaCfg.Consumer.Group.Rebalance.Strategy = sarama.BalanceStrategySticky + saramaCfg.Consumer.Group.Rebalance.GroupStrategies = append(saramaCfg.Consumer.Group.Rebalance.GroupStrategies, sarama.NewBalanceStrategySticky()) saramaCfg.Net.DialTimeout = 15 * time.Second saramaCfg.Version = sarama.V2_6_0_0 diff --git a/component/sqs/component_test.go b/component/sqs/component_test.go index 7a7eedb2b..cc4642cd9 100644 --- a/component/sqs/component_test.go +++ b/component/sqs/component_test.go @@ -22,7 +22,9 @@ var ( ) func TestMain(m *testing.M) { - os.Setenv("OTEL_BSP_SCHEDULE_DELAY", "100") + if err := os.Setenv("OTEL_BSP_SCHEDULE_DELAY", "100"); err != nil { + panic(err) + } tracePublisher = patrontrace.Setup("test", nil, traceExporter) diff --git a/component/sqs/metric.go b/component/sqs/metric.go index 138a83e19..b284a8efa 100644 --- a/component/sqs/metric.go +++ b/component/sqs/metric.go @@ -30,8 +30,8 @@ func init() { messageQueueSizeGauge = patronmetric.Float64Gauge(packageName, "sqs.queue.size", "SQS message queue size.", "1") } -func observerMessageAge(ctx context.Context, queue string, attributes map[string]string) { - attribute, ok := attributes[sqsAttributeSentTimestamp] +func observerMessageAge(ctx context.Context, queue string, attrs map[string]string) { + attribute, ok := attrs[sqsAttributeSentTimestamp] if !ok || len(strings.TrimSpace(attribute)) == 0 { return } diff --git a/examples/service/http.go b/examples/service/http.go index 04db86fc3..97b82d937 100644 --- a/examples/service/http.go +++ b/examples/service/http.go @@ -32,10 +32,10 @@ func createHttpRouter() (patron.Component, error) { return nil, fmt.Errorf("failed to create routes: %w", err) } - router, err := router.New(router.WithRoutes(rr...)) + rt, err := router.New(router.WithRoutes(rr...)) if err != nil { return nil, fmt.Errorf("failed to create http router: %w", err) } - return patronhttp.New(router) + return patronhttp.New(rt) } diff --git a/observability/integration_test.go b/observability/integration_test.go index ad39de8c9..d36278718 100644 --- a/observability/integration_test.go +++ b/observability/integration_test.go @@ -6,6 +6,7 @@ import ( "context" "testing" + "github.com/beatlabs/patron/observability/log" "github.com/stretchr/testify/require" ) @@ -13,7 +14,11 @@ func TestSetup(t *testing.T) { t.Setenv("OTEL_EXPORTER_OTLP_INSECURE", "true") ctx := context.Background() - got, err := Setup(ctx, "test", "1.2.3") + got, err := Setup(ctx, Config{ + LogConfig: log.Config{ + Level: "debug", + }, + }) require.NoError(t, err) require.NoError(t, got.Shutdown(ctx)) diff --git a/observability/log/log.go b/observability/log/log.go index a388dd60d..76f0e5149 100644 --- a/observability/log/log.go +++ b/observability/log/log.go @@ -4,10 +4,64 @@ package log import ( "context" "log/slog" + "os" ) +// Config represents the configuration for setting up the logger. +type Config struct { + Attributes []slog.Attr + IsJSON bool + Level string +} + type ctxKey struct{} +var logCfg *Config + +// Setup sets up the logger with the given configuration. +func Setup(cfg *Config) error { + logCfg = cfg + return setDefaultLogger(cfg) +} + +// SetLevel sets the logger level. +func SetLevel(lvl string) error { + logCfg.Level = lvl + return setDefaultLogger(logCfg) +} + +func setDefaultLogger(cfg *Config) error { + lvl, err := level(cfg.Level) + if err != nil { + return err + } + + ho := &slog.HandlerOptions{ + AddSource: true, + Level: lvl, + } + + var hnd slog.Handler + + if cfg.IsJSON { + hnd = slog.NewJSONHandler(os.Stderr, ho) + } else { + hnd = slog.NewTextHandler(os.Stderr, ho) + } + + slog.SetDefault(slog.New(hnd.WithAttrs(cfg.Attributes))) + return nil +} + +func level(lvl string) (slog.Level, error) { + lv := slog.LevelVar{} + if err := lv.UnmarshalText([]byte(lvl)); err != nil { + return slog.LevelInfo, err + } + + return lv.Level(), nil +} + // FromContext returns the logger, if it exists in the context, or nil. func FromContext(ctx context.Context) *slog.Logger { if l, ok := ctx.Value(ctxKey{}).(*slog.Logger); ok { diff --git a/observability/log/log_test.go b/observability/log/log_test.go index 76d4a0d15..3b0ed5d53 100644 --- a/observability/log/log_test.go +++ b/observability/log/log_test.go @@ -7,8 +7,31 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func TestSetup(t *testing.T) { + t.Run("JSON", func(t *testing.T) { + cfg := &Config{ + Attributes: []slog.Attr{}, + IsJSON: true, + Level: "debug", + } + require.NoError(t, Setup(cfg)) + assert.NotNil(t, slog.Default()) + }) + + t.Run("Text", func(t *testing.T) { + cfg := &Config{ + Attributes: []slog.Attr{}, + IsJSON: false, + Level: "debug", + } + require.NoError(t, Setup(cfg)) + assert.NotNil(t, slog.Default()) + }) +} + func TestContext(t *testing.T) { l := slog.Default() @@ -23,22 +46,19 @@ func TestContext(t *testing.T) { }) } -func TestEnabled(t *testing.T) { - type args struct { - l slog.Level - } - tests := map[string]struct { - args args - want bool - }{ - "Disabled": {args{slog.LevelDebug}, false}, - "Enabled": {args{slog.LevelInfo}, true}, - } - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - assert.Equal(t, tt.want, Enabled(tt.args.l)) - }) - } +func TestSetLevelAndCheckEnable(t *testing.T) { + require.NoError(t, Setup(&Config{ + Attributes: []slog.Attr{}, + IsJSON: true, + Level: "info", + })) + + assert.True(t, Enabled(slog.LevelInfo)) + assert.False(t, Enabled(slog.LevelDebug)) + + require.NoError(t, SetLevel("debug")) + + assert.True(t, Enabled(slog.LevelDebug)) } func TestErrorAttr(t *testing.T) { diff --git a/observability/observability.go b/observability/observability.go index dc8644850..436329988 100644 --- a/observability/observability.go +++ b/observability/observability.go @@ -49,11 +49,23 @@ func StatusAttribute(err error) attribute.KeyValue { return SucceededAttribute } +// Config represents the configuration for setting up traces, metrics and logs. +type Config struct { + Name string + Version string + LogConfig log.Config +} + // Setup initializes OpenTelemetry's traces and metrics. // It creates a resource with the given name and version, sets up the metric and trace providers, // and returns a Provider containing the initialized providers. -func Setup(ctx context.Context, name, version string) (*Provider, error) { - res, err := createResource(name, version) +func Setup(ctx context.Context, cfg Config) (*Provider, error) { + err := log.Setup(&cfg.LogConfig) + if err != nil { + return nil, err + } + + res, err := createResource(cfg.Name, cfg.Version) if err != nil { return nil, err } @@ -64,7 +76,7 @@ func Setup(ctx context.Context, name, version string) (*Provider, error) { if err != nil { return nil, err } - traceProvider, err := patrontrace.SetupGRPC(ctx, name, res) + traceProvider, err := patrontrace.SetupGRPC(ctx, cfg.Name, res) if err != nil { return nil, err } diff --git a/options.go b/options.go index 36f591988..ebf43f3e5 100644 --- a/options.go +++ b/options.go @@ -34,7 +34,7 @@ func WithLogFields(attrs ...slog.Attr) OptionFunc { continue } - svc.logConfig.attrs = append(svc.logConfig.attrs, attr) + svc.observabilityCfg.LogConfig.Attributes = append(svc.observabilityCfg.LogConfig.Attributes, attr) } return nil @@ -44,7 +44,7 @@ func WithLogFields(attrs ...slog.Attr) OptionFunc { // WithJSONLogger to use Go's slog package. func WithJSONLogger() OptionFunc { return func(svc *Service) error { - svc.logConfig.json = true + svc.observabilityCfg.LogConfig.IsJSON = true return nil } } diff --git a/options_test.go b/options_test.go index b1d250a24..6741cd79b 100644 --- a/options_test.go +++ b/options_test.go @@ -5,39 +5,46 @@ import ( "log/slog" "testing" + "github.com/beatlabs/patron/observability" + "github.com/beatlabs/patron/observability/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLogFields(t *testing.T) { - defaultAttrs := defaultLogAttrs("test", "1.0") attrs := []slog.Attr{slog.String("key", "value")} - attrs1 := defaultLogAttrs("name1", "version1") + attrs1 := []slog.Attr{slog.String("name1", "version1")} + + expectedSuccess := observability.Config{LogConfig: log.Config{ + Attributes: attrs, + }} + expectedNoOverwrite := observability.Config{LogConfig: log.Config{ + Attributes: []slog.Attr{slog.String("name1", "version2")}, + }} + type args struct { fields []slog.Attr } tests := map[string]struct { args args - want logConfig + want observability.Config expectedErr string }{ "empty attributes": {args: args{fields: nil}, expectedErr: "attributes are empty"}, - "success": {args: args{fields: attrs}, want: logConfig{attrs: append(defaultAttrs, attrs...)}}, - "no overwrite": {args: args{fields: attrs1}, want: logConfig{attrs: defaultAttrs}}, + "success": {args: args{fields: attrs}, want: expectedSuccess}, + "no overwrite": {args: args{fields: attrs1}, want: expectedNoOverwrite}, } for name, tt := range tests { t.Run(name, func(t *testing.T) { svc := &Service{ - logConfig: logConfig{ - attrs: defaultAttrs, - }, + observabilityCfg: observability.Config{}, } err := WithLogFields(tt.args.fields...)(svc) if tt.expectedErr == "" { require.NoError(t, err) - assert.Equal(t, tt.want, svc.logConfig) + assert.Equal(t, tt.want, svc.observabilityCfg) } else { require.EqualError(t, err, tt.expectedErr) } diff --git a/service.go b/service.go index 48c64755a..57b1a76ec 100644 --- a/service.go +++ b/service.go @@ -32,10 +32,11 @@ type Service struct { version string termSig chan os.Signal sighupHandler func() - logConfig logConfig + observabilityCfg observability.Config observabilityProvider *observability.Provider } +// New creates a new Service instance. func New(name, version string, options ...OptionFunc) (*Service, error) { if name == "" { return nil, errors.New("name is required") @@ -46,7 +47,10 @@ func New(name, version string, options ...OptionFunc) (*Service, error) { var err error ctx := context.Background() - observabilityProvider, err := observability.Setup(ctx, name, version) + + cfg := observabilityConfig(name, version) + + observabilityProvider, err := observability.Setup(ctx, cfg) if err != nil { return nil, err } @@ -58,10 +62,7 @@ func New(name, version string, options ...OptionFunc) (*Service, error) { sighupHandler: func() { slog.Debug("sighup received: nothing setup") }, - logConfig: logConfig{ - attrs: defaultLogAttrs(name, version), - json: false, - }, + observabilityCfg: cfg, observabilityProvider: observabilityProvider, } @@ -78,12 +79,12 @@ func New(name, version string, options ...OptionFunc) (*Service, error) { return nil, errors.Join(optionErrors...) } - setupLogging(s.logConfig) s.setupOSSignal() return s, nil } +// Run starts the service with the provided components. func (s *Service) Run(ctx context.Context, components ...Component) error { if len(components) == 0 || components[0] == nil { return errors.New("components are empty or nil") @@ -149,51 +150,31 @@ func (s *Service) waitTermination(chErr <-chan error) error { } } -type logConfig struct { - attrs []slog.Attr - json bool -} - -func getLogLevel() slog.Level { +func observabilityConfig(name, version string) observability.Config { + var lvl string lvl, ok := os.LookupEnv("PATRON_LOG_LEVEL") if !ok { - return slog.LevelInfo - } - - lv := slog.LevelVar{} - if err := lv.UnmarshalText([]byte(lvl)); err != nil { - return slog.LevelInfo + lvl = "info" } - return lv.Level() -} - -func defaultLogAttrs(name, version string) []slog.Attr { hostname, err := os.Hostname() if err != nil { hostname = host } - return []slog.Attr{ + attrs := []slog.Attr{ slog.String(srv, name), slog.String(ver, version), slog.String(host, hostname), } -} - -func setupLogging(lc logConfig) { - ho := &slog.HandlerOptions{ - AddSource: true, - Level: getLogLevel(), - } - var hnd slog.Handler - - if lc.json { - hnd = slog.NewJSONHandler(os.Stderr, ho) - } else { - hnd = slog.NewTextHandler(os.Stderr, ho) + return observability.Config{ + Name: name, + Version: version, + LogConfig: log.Config{ + Attributes: attrs, + IsJSON: false, + Level: lvl, + }, } - - slog.New(hnd.WithAttrs(lc.attrs)) } diff --git a/service_test.go b/service_test.go index 5ade10161..d99510b4a 100644 --- a/service_test.go +++ b/service_test.go @@ -58,22 +58,3 @@ func TestNew(t *testing.T) { }) } } - -func Test_getLogLevel(t *testing.T) { - tests := map[string]struct { - lvl string - want slog.Level - }{ - "debug": {lvl: "debug", want: slog.LevelDebug}, - "info": {lvl: "info", want: slog.LevelInfo}, - "warn": {lvl: "warn", want: slog.LevelWarn}, - "error": {lvl: "error", want: slog.LevelError}, - "invalid level": {lvl: "invalid", want: slog.LevelInfo}, - } - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - t.Setenv("PATRON_LOG_LEVEL", tt.lvl) - assert.Equal(t, tt.want, getLogLevel()) - }) - } -}