diff --git a/component/http/observability.go b/component/http/observability.go index 1d30c05cc..16da648ee 100644 --- a/component/http/observability.go +++ b/component/http/observability.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/http/pprof" + + "github.com/beatlabs/patron/observability/log" ) func ProfilingRoutes(enableExpVar bool) []*Route { @@ -47,3 +49,24 @@ func expVars(w http.ResponseWriter, _ *http.Request) { }) _, _ = fmt.Fprintf(w, "\n}\n") } + +// LoggingRoutes returns a routes relates to logs. +func LoggingRoutes() []*Route { + handler := func(w http.ResponseWriter, r *http.Request) { + lvl := r.PathValue("level") + if lvl == "" { + http.Error(w, "missing log level", http.StatusBadRequest) + return + } + + err := log.SetLevel(lvl) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + } + + route, _ := NewRoute("POST /debug/log/{level}", handler) + return []*Route{route} +} diff --git a/component/http/observability_test.go b/component/http/observability_test.go index 0fd297403..5825ca110 100644 --- a/component/http/observability_test.go +++ b/component/http/observability_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "testing" + "github.com/beatlabs/patron/observability/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,7 +19,7 @@ type profilingTestCase struct { func TestProfilingRoutes(t *testing.T) { t.Run("without vars", func(t *testing.T) { - server := createProfilingServer(false) + server := createServer(false) defer server.Close() for name, tt := range createProfilingTestCases(false) { @@ -34,7 +35,7 @@ func TestProfilingRoutes(t *testing.T) { }) t.Run("with vars", func(t *testing.T) { - server := createProfilingServer(true) + server := createServer(true) defer server.Close() for name, tt := range createProfilingTestCases(true) { @@ -50,11 +51,14 @@ func TestProfilingRoutes(t *testing.T) { }) } -func createProfilingServer(enableExpVar bool) *httptest.Server { +func createServer(enableExpVar bool) *httptest.Server { mux := http.NewServeMux() for _, route := range ProfilingRoutes(enableExpVar) { mux.HandleFunc(route.path, route.handler) } + for _, route := range LoggingRoutes() { + mux.HandleFunc(route.path, route.handler) + } return httptest.NewServer(mux) } @@ -80,3 +84,39 @@ func createProfilingTestCases(enableExpVar bool) map[string]profilingTestCase { "vars": {"/debug/vars/", expVarWant}, } } + +func TestLoggingRoutes(t *testing.T) { + log.Setup(&log.Config{ + IsJSON: true, + Level: "info", + }) + server := createServer(true) + defer server.Close() + + t.Run("change log level to debug", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, fmt.Sprintf("%s/debug/log/debug", server.URL), nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + }) + + t.Run("wrong log level", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, fmt.Sprintf("%s/debug/log/xxx", server.URL), nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + }) + + t.Run("empty log level", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, fmt.Sprintf("%s/debug/log/", server.URL), nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + }) +} diff --git a/observability/log/log.go b/observability/log/log.go index 84daca233..76f0e5149 100644 --- a/observability/log/log.go +++ b/observability/log/log.go @@ -19,21 +19,26 @@ type ctxKey struct{} var logCfg *Config // Setup sets up the logger with the given configuration. -func Setup(cfg *Config) { +func Setup(cfg *Config) error { logCfg = cfg - setDefaultLogger(cfg) + return setDefaultLogger(cfg) } // SetLevel sets the logger level. -func SetLevel(lvl string) { +func SetLevel(lvl string) error { logCfg.Level = lvl - setDefaultLogger(logCfg) + return setDefaultLogger(logCfg) } -func setDefaultLogger(cfg *Config) { +func setDefaultLogger(cfg *Config) error { + lvl, err := level(cfg.Level) + if err != nil { + return err + } + ho := &slog.HandlerOptions{ AddSource: true, - Level: level(cfg.Level), + Level: lvl, } var hnd slog.Handler @@ -45,15 +50,16 @@ func setDefaultLogger(cfg *Config) { } slog.SetDefault(slog.New(hnd.WithAttrs(cfg.Attributes))) + return nil } -func level(lvl string) slog.Level { +func level(lvl string) (slog.Level, error) { lv := slog.LevelVar{} if err := lv.UnmarshalText([]byte(lvl)); err != nil { - return slog.LevelInfo + return slog.LevelInfo, err } - return lv.Level() + return lv.Level(), nil } // FromContext returns the logger, if it exists in the context, or nil. diff --git a/observability/log/log_test.go b/observability/log/log_test.go index 0fc1089b0..a42e32b62 100644 --- a/observability/log/log_test.go +++ b/observability/log/log_test.go @@ -16,7 +16,7 @@ func TestSetup(t *testing.T) { IsJSON: true, Level: "debug", } - Setup(cfg) + assert.NoError(t, Setup(cfg)) assert.NotNil(t, slog.Default()) }) @@ -26,7 +26,7 @@ func TestSetup(t *testing.T) { IsJSON: false, Level: "debug", } - Setup(cfg) + assert.NoError(t, Setup(cfg)) assert.NotNil(t, slog.Default()) }) } @@ -55,7 +55,7 @@ func TestSetLevelAndCheckEnable(t *testing.T) { assert.True(t, Enabled(slog.LevelInfo)) assert.False(t, Enabled(slog.LevelDebug)) - SetLevel("debug") + assert.NoError(t, SetLevel("debug")) assert.True(t, Enabled(slog.LevelDebug)) } diff --git a/observability/observability.go b/observability/observability.go index 4e316febf..436329988 100644 --- a/observability/observability.go +++ b/observability/observability.go @@ -60,7 +60,10 @@ type Config struct { // It creates a resource with the given name and version, sets up the metric and trace providers, // and returns a Provider containing the initialized providers. func Setup(ctx context.Context, cfg Config) (*Provider, error) { - log.Setup(&cfg.LogConfig) + err := log.Setup(&cfg.LogConfig) + if err != nil { + return nil, err + } res, err := createResource(cfg.Name, cfg.Version) if err != nil {