diff --git a/routes/index.go b/routes/index.go index 58e085611..9f17d1a73 100644 --- a/routes/index.go +++ b/routes/index.go @@ -231,6 +231,7 @@ func initChi() *chi.Mux { r.Use(middleware.RequestID) r.Use(middleware.Logger) r.Use(middleware.Recoverer) + r.Use(utils.RouteBasedUUIDMiddleware) r.Use(internalServerErrorHandler) cors := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, diff --git a/utils/logger.go b/utils/logger.go index 46616f016..a0ba76c0a 100644 --- a/utils/logger.go +++ b/utils/logger.go @@ -1,10 +1,12 @@ package utils import ( + "github.com/google/uuid" + "github.com/stakwork/sphinx-tribes/config" "log" + "net/http" "os" - - "github.com/stakwork/sphinx-tribes/config" + "sync" ) type Logger struct { @@ -13,6 +15,8 @@ type Logger struct { errorLogger *log.Logger debugLogger *log.Logger machineLogger *log.Logger + mu sync.Mutex + requestUUID string } var Log = Logger{ @@ -23,32 +27,68 @@ var Log = Logger{ machineLogger: log.New(os.Stdout, "MACHINE: ", log.Ldate|log.Ltime|log.Lshortfile), } +func (l *Logger) SetRequestUUID(uuidString string) { + l.mu.Lock() + defer l.mu.Unlock() + l.requestUUID = uuidString +} + +func (l *Logger) ClearRequestUUID() { + l.mu.Lock() + defer l.mu.Unlock() + l.requestUUID = "" +} + +func RouteBasedUUIDMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + uuid := uuid.NewString() + Log.SetRequestUUID(uuid) + + defer Log.ClearRequestUUID() + + next.ServeHTTP(w, r) + }) +} + +func (l *Logger) logWithPrefix(logger *log.Logger, format string, v ...interface{}) { + l.mu.Lock() + + requestUUID := l.requestUUID + l.mu.Unlock() + + if requestUUID == "" { + logger.Printf(format, v...) + } else { + logger.Printf("["+requestUUID+"] "+format, v...) + } +} + func (l *Logger) Machine(format string, v ...interface{}) { if config.LogLevel == "MACHINE" { - l.machineLogger.Printf(format, v...) + l.logWithPrefix(l.machineLogger, format, v...) } } func (l *Logger) Debug(format string, v ...interface{}) { if config.LogLevel == "MACHINE" || config.LogLevel == "DEBUG" { - l.debugLogger.Printf(format, v...) + l.logWithPrefix(l.debugLogger, format, v...) } } func (l *Logger) Info(format string, v ...interface{}) { if config.LogLevel == "MACHINE" || config.LogLevel == "DEBUG" || config.LogLevel == "INFO" { - l.infoLogger.Printf(format, v...) + l.logWithPrefix(l.infoLogger, format, v...) } } func (l *Logger) Warning(format string, v ...interface{}) { if config.LogLevel == "MACHINE" || config.LogLevel == "DEBUG" || config.LogLevel == "INFO" || config.LogLevel == "WARNING" { - l.warningLogger.Printf(format, v...) + l.logWithPrefix(l.warningLogger, format, v...) } } func (l *Logger) Error(format string, v ...interface{}) { if config.LogLevel == "MACHINE" || config.LogLevel == "DEBUG" || config.LogLevel == "INFO" || config.LogLevel == "WARNING" || config.LogLevel == "ERROR" { - l.errorLogger.Printf(format, v...) + l.logWithPrefix(l.errorLogger, format, v...) } }