diff --git a/cmd/riverui/auth_middleware.go b/cmd/riverui/auth_middleware.go new file mode 100644 index 0000000..49ad668 --- /dev/null +++ b/cmd/riverui/auth_middleware.go @@ -0,0 +1,34 @@ +package main + +import ( + "crypto/subtle" + "net/http" + "os" +) + +func installAuthMiddleware(next http.Handler) http.Handler { + username := os.Getenv("RIVER_BASIC_AUTH_USER") + password := os.Getenv("RIVER_BASIC_AUTH_PASS") + + if username == "" || password == "" { + return next + } + + return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + if isReqAuthorized(req, username, password) { + next.ServeHTTP(res, req) + return + } + + res.Header().Set("WWW-Authenticate", "Basic realm=\"riverui\"") + http.Error(res, "Unauthorized", http.StatusUnauthorized) + }) +} + +func isReqAuthorized(req *http.Request, username, password string) bool { + reqUsername, reqPassword, ok := req.BasicAuth() + + return ok && + subtle.ConstantTimeCompare([]byte(reqUsername), []byte(username)) == 1 && + subtle.ConstantTimeCompare([]byte(reqPassword), []byte(password)) == 1 +} diff --git a/cmd/riverui/cors_middleware.go b/cmd/riverui/cors_middleware.go new file mode 100644 index 0000000..3c65c69 --- /dev/null +++ b/cmd/riverui/cors_middleware.go @@ -0,0 +1,20 @@ +package main + +import ( + "net/http" + "os" + "strings" + + "github.com/rs/cors" +) + +func installCorsMiddleware(next http.Handler) http.Handler { + origins := strings.Split(os.Getenv("CORS_ORIGINS"), ",") + + handler := cors.New(cors.Options{ + AllowedMethods: []string{"GET", "HEAD", "POST", "PUT"}, + AllowedOrigins: origins, + }) + + return handler.Handler(next) +} diff --git a/cmd/riverui/logger_middleware.go b/cmd/riverui/logger_middleware.go new file mode 100644 index 0000000..0df2438 --- /dev/null +++ b/cmd/riverui/logger_middleware.go @@ -0,0 +1,16 @@ +package main + +import ( + sloghttp "github.com/samber/slog-http" + "net/http" + "os" +) + +func installLoggerMiddleware(next http.Handler) http.Handler { + otelEnabled := os.Getenv("OTEL_ENABLED") == "true" + + return sloghttp.NewWithConfig(logger, sloghttp.Config{ + WithSpanID: otelEnabled, + WithTraceID: otelEnabled, + })(next) +} diff --git a/cmd/riverui/main.go b/cmd/riverui/main.go index 9d90e03..4b1d2c1 100644 --- a/cmd/riverui/main.go +++ b/cmd/riverui/main.go @@ -14,7 +14,6 @@ import ( "time" "github.com/jackc/pgx/v5/pgxpool" - "github.com/rs/cors" sloghttp "github.com/samber/slog-http" "github.com/riverqueue/river" @@ -48,11 +47,9 @@ func initAndServe(ctx context.Context) int { pathPrefix = riverui.NormalizePathPrefix(pathPrefix) var ( - corsOrigins = strings.Split(os.Getenv("CORS_ORIGINS"), ",") - dbURL = mustEnv("DATABASE_URL") - host = os.Getenv("RIVER_HOST") // may be left empty to bind to all local interfaces - otelEnabled = os.Getenv("OTEL_ENABLED") == "true" - port = cmp.Or(os.Getenv("PORT"), "8080") + dbURL = mustEnv("DATABASE_URL") + host = os.Getenv("RIVER_HOST") // may be left empty to bind to all local interfaces + port = cmp.Or(os.Getenv("PORT"), "8080") ) dbPool, err := getDBPool(ctx, dbURL) @@ -62,11 +59,6 @@ func initAndServe(ctx context.Context) int { } defer dbPool.Close() - corsHandler := cors.New(cors.Options{ - AllowedMethods: []string{"GET", "HEAD", "POST", "PUT"}, - AllowedOrigins: corsOrigins, - }) - client, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{}) if err != nil { logger.ErrorContext(ctx, "error creating river client", slog.String("error", err.Error())) @@ -88,21 +80,19 @@ func initAndServe(ctx context.Context) int { return 1 } - if err := server.Start(ctx); err != nil { + if err = server.Start(ctx); err != nil { logger.ErrorContext(ctx, "error starting UI server", slog.String("error", err.Error())) return 1 } - logHandler := sloghttp.Recovery(server) - config := sloghttp.Config{ - WithSpanID: otelEnabled, - WithTraceID: otelEnabled, - } - wrappedHandler := sloghttp.NewWithConfig(logger, config)(corsHandler.Handler(logHandler)) + srvHandler := sloghttp.Recovery(server) + srvHandler = installAuthMiddleware(srvHandler) + srvHandler = installCorsMiddleware(srvHandler) + srvHandler = installLoggerMiddleware(srvHandler) srv := &http.Server{ Addr: host + ":" + port, - Handler: wrappedHandler, + Handler: srvHandler, ReadHeaderTimeout: 5 * time.Second, } diff --git a/docs/README.md b/docs/README.md index 190804d..ea7e10b 100644 --- a/docs/README.md +++ b/docs/README.md @@ -55,6 +55,11 @@ The `riverui` command utilizes the `RIVER_LOG_LEVEL` environment variable to con * `warn` * `error` +### Basic HTTP Authentication + +The `riverui` supports basic HTTP authentication to protect access to the UI. +To enable it, set the `RIVER_BASIC_AUTH_USER` and `RIVER_BASIC_AUTH_PASS` environment variables. + ## Development See [developing River UI](./development.md).