From 109c3637a0ebd1face96e85c60755698c45f328d Mon Sep 17 00:00:00 2001 From: Brandur Date: Tue, 24 Dec 2024 10:48:28 -0700 Subject: [PATCH] Allow `PG*` env vars as an alternative to `DATABASE_URL` Accept the standard set of `PG*` env vars as an alternative to database configuration (e.g. `PGHOST`, `PGDATABASE`, etc.). This is mostly driven by having done something similar for the CLI in [1], but was also requested in #249. This turns out to be quite easy to do because pgx does all the heavy lifting. As noted in [1], a bonus of this is that it adds some additional configuration options that aren't very easily doable right now, for example around the use of an SSL certificate to connect to Postgres. We get automatic support for these vars: * `PGSSLCERT` * `PGSSLKEY` * `PGSSLROOTCERT` * `PGSSLPASSWORD` As part of this I also ended up rearranging some things in `main.go`. Not strongly married to this design, but the idea is to get it into a place where we can write tests for it, which previously wasn't possible. Fixes #249. [1] https://github.com/riverqueue/river/pull/702 --- .github/workflows/ci.yaml | 3 - cmd/riverui/logger.go | 34 --------- cmd/riverui/main.go | 147 ++++++++++++++++++++++---------------- cmd/riverui/main_test.go | 65 +++++++++++++++++ 4 files changed, 150 insertions(+), 99 deletions(-) delete mode 100644 cmd/riverui/logger.go create mode 100644 cmd/riverui/main_test.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dac1a1a..38cae2e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,9 +1,6 @@ name: CI env: - # A suitable URL for the test database. - DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_dev?sslmode=disable - # Test database. TEST_DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_test?sslmode=disable diff --git a/cmd/riverui/logger.go b/cmd/riverui/logger.go deleted file mode 100644 index 5e4caa8..0000000 --- a/cmd/riverui/logger.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -import ( - "log/slog" - "os" - "strings" -) - -var logger *slog.Logger //nolint:gochecknoglobals - -func initLogger() { - options := &slog.HandlerOptions{Level: getLogLevel()} - logger = slog.New(slog.NewTextHandler(os.Stdout, options)) -} - -func getLogLevel() slog.Level { - debugEnv := os.Getenv("RIVER_DEBUG") - if debugEnv == "1" || debugEnv == "true" { - return slog.LevelDebug - } - - env := strings.ToLower(os.Getenv("RIVER_LOG_LEVEL")) - - switch env { - case "debug": - return slog.LevelDebug - case "warn": - return slog.LevelWarn - case "error": - return slog.LevelError - default: - return slog.LevelInfo - } -} diff --git a/cmd/riverui/main.go b/cmd/riverui/main.go index 7c4d06a..bf0cc51 100644 --- a/cmd/riverui/main.go +++ b/cmd/riverui/main.go @@ -6,7 +6,6 @@ import ( "errors" "flag" "fmt" - "log" "log/slog" "net/http" "os" @@ -26,25 +25,59 @@ import ( func main() { ctx := context.Background() - initLogger() - os.Exit(initAndServe(ctx)) -} -func initAndServe(ctx context.Context) int { - var ( - devMode bool - liveFS bool - pathPrefix string - ) - _, liveFS = os.LookupEnv("LIVE_FS") - _, devMode = os.LookupEnv("DEV") + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: getLogLevel(), + })) + var pathPrefix string flag.StringVar(&pathPrefix, "prefix", "/", "path prefix to use for the API and UI HTTP requests") flag.Parse() + initRes, err := initServer(ctx, logger, pathPrefix) + if err != nil { + logger.ErrorContext(ctx, "Error initializing server", slog.String("error", err.Error())) + os.Exit(1) + } + + if err := startAndListen(ctx, logger, initRes); err != nil { + logger.ErrorContext(ctx, "Error starting server", slog.String("error", err.Error())) + os.Exit(1) + } +} + +// Translates either a "1" or "true" from env to a Go boolean. +func envBooleanTrue(val string) bool { + return val == "1" || val == "true" +} + +func getLogLevel() slog.Level { + if envBooleanTrue(os.Getenv("RIVER_DEBUG")) { + return slog.LevelDebug + } + + switch strings.ToLower(os.Getenv("RIVER_LOG_LEVEL")) { + case "debug": + return slog.LevelDebug + case "warn": + return slog.LevelWarn + case "error": + return slog.LevelError + default: + return slog.LevelInfo + } +} + +type initServerResult struct { + dbPool *pgxpool.Pool // database pool; close must be deferred by caller! + httpServer *http.Server // HTTP server wrapping the UI server + logger *slog.Logger // application logger (also internalized in UI server) + uiServer *riverui.Server // River UI server +} + +func initServer(ctx context.Context, logger *slog.Logger, pathPrefix string) (*initServerResult, error) { if !strings.HasPrefix(pathPrefix, "/") || pathPrefix == "" { - logger.ErrorContext(ctx, "invalid path prefix", slog.String("prefix", pathPrefix)) - return 1 + return nil, fmt.Errorf("invalid path prefix: %s", pathPrefix) } pathPrefix = riverui.NormalizePathPrefix(pathPrefix) @@ -52,43 +85,43 @@ func initAndServe(ctx context.Context) int { basicAuthUsername = os.Getenv("RIVER_BASIC_AUTH_USER") basicAuthPassword = os.Getenv("RIVER_BASIC_AUTH_PASS") corsOrigins = strings.Split(os.Getenv("CORS_ORIGINS"), ",") - dbURL = mustEnv("DATABASE_URL") + databaseURL = os.Getenv("DATABASE_URL") + devMode = envBooleanTrue(os.Getenv("DEV")) host = os.Getenv("RIVER_HOST") // may be left empty to bind to all local interfaces - otelEnabled = os.Getenv("OTEL_ENABLED") == "true" + liveFS = envBooleanTrue(os.Getenv("LIVE_FS")) + otelEnabled = envBooleanTrue(os.Getenv("OTEL_ENABLED")) port = cmp.Or(os.Getenv("PORT"), "8080") ) - dbPool, err := getDBPool(ctx, dbURL) + if databaseURL == "" && os.Getenv("PGDATABASE") == "" { + return nil, errors.New("expect to have DATABASE_URL or database configuration in standard PG* env vars like PGDATABASE/PGHOST/PGPORT/PGUSER/PGPASSWORD") + } + + poolConfig, err := pgxpool.ParseConfig(databaseURL) + if err != nil { + return nil, fmt.Errorf("error parsing db config: %w", err) + } + + dbPool, err := pgxpool.NewWithConfig(ctx, poolConfig) if err != nil { - logger.ErrorContext(ctx, "error connecting to db", slog.String("error", err.Error())) - return 1 + return nil, fmt.Errorf("error connecting to db: %w", err) } - defer dbPool.Close() client, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{}) if err != nil { - logger.ErrorContext(ctx, "error creating river client", slog.String("error", err.Error())) - return 1 + return nil, err } - handlerOpts := &riverui.ServerOpts{ + uiServer, err := riverui.NewServer(&riverui.ServerOpts{ Client: client, DB: dbPool, DevMode: devMode, LiveFS: liveFS, Logger: logger, Prefix: pathPrefix, - } - - server, err := riverui.NewServer(handlerOpts) + }) if err != nil { - logger.ErrorContext(ctx, "error creating handler", slog.String("error", err.Error())) - return 1 - } - - if err = server.Start(ctx); err != nil { - logger.ErrorContext(ctx, "error starting UI server", slog.String("error", err.Error())) - return 1 + return nil, err } corsHandler := cors.New(cors.Options{ @@ -109,40 +142,30 @@ func initAndServe(ctx context.Context) int { middlewareStack.Use(&authMiddleware{username: basicAuthUsername, password: basicAuthPassword}) } - srv := &http.Server{ - Addr: host + ":" + port, - Handler: middlewareStack.Mount(server), - ReadHeaderTimeout: 5 * time.Second, - } + return &initServerResult{ + dbPool: dbPool, + httpServer: &http.Server{ + Addr: host + ":" + port, + Handler: middlewareStack.Mount(uiServer), + ReadHeaderTimeout: 5 * time.Second, + }, + logger: logger, + uiServer: uiServer, + }, nil +} - log.Printf("starting server on %s", srv.Addr) +func startAndListen(ctx context.Context, logger *slog.Logger, initRes *initServerResult) error { + defer initRes.dbPool.Close() - if err = srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - logger.ErrorContext(ctx, "error from ListenAndServe", slog.String("error", err.Error())) - return 1 + if err := initRes.uiServer.Start(ctx); err != nil { + return err } - return 0 -} + logger.InfoContext(ctx, "Starting server", slog.String("addr", initRes.httpServer.Addr)) -func getDBPool(ctx context.Context, dbURL string) (*pgxpool.Pool, error) { - poolConfig, err := pgxpool.ParseConfig(dbURL) - if err != nil { - return nil, fmt.Errorf("error parsing db config: %w", err) + if err := initRes.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err } - dbPool, err := pgxpool.NewWithConfig(ctx, poolConfig) - if err != nil { - return nil, fmt.Errorf("error connecting to db: %w", err) - } - return dbPool, nil -} - -func mustEnv(name string) string { - val := os.Getenv(name) - if val == "" { - logger.Error("missing required env var", slog.String("name", name)) - os.Exit(1) - } - return val + return nil } diff --git a/cmd/riverui/main_test.go b/cmd/riverui/main_test.go new file mode 100644 index 0000000..979a722 --- /dev/null +++ b/cmd/riverui/main_test.go @@ -0,0 +1,65 @@ +package main + +import ( + "cmp" + "context" + "net/url" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/rivershared/riversharedtest" +) + +func TestInitServer(t *testing.T) { + var ( + ctx = context.Background() + databaseURL = cmp.Or(os.Getenv("TEST_DATABASE_URL"), "postgres://localhost/river_test") + ) + + t.Setenv("DEV", "true") + + type testBundle struct{} + + setup := func(t *testing.T) (*initServerResult, *testBundle) { + t.Helper() + + initRes, err := initServer(ctx, riversharedtest.Logger(t), "/") + require.NoError(t, err) + t.Cleanup(initRes.dbPool.Close) + + return initRes, &testBundle{} + } + + t.Run("WithDatabaseURL", func(t *testing.T) { + t.Setenv("DATABASE_URL", databaseURL) + + initRes, _ := setup(t) + + _, err := initRes.dbPool.Exec(ctx, "SELECT 1") + require.NoError(t, err) + }) + + t.Run("WithPGEnvVars", func(t *testing.T) { + // Verify that DATABASE_URL is indeed not set to be sure we're taking + // the configuration branch we expect to be taking. + require.Empty(t, os.Getenv("DATABASE_URL")) + + parsedURL, err := url.Parse(databaseURL) + require.NoError(t, err) + + t.Setenv("PGDATABASE", parsedURL.Path[1:]) + t.Setenv("PGHOST", parsedURL.Hostname()) + pass, _ := parsedURL.User.Password() + t.Setenv("PGPASSWORD", pass) + t.Setenv("PGPORT", cmp.Or(parsedURL.Port(), "5432")) + t.Setenv("PGSSLMODE", parsedURL.Query().Get("sslmode")) + t.Setenv("PGUSER", parsedURL.User.Username()) + + initRes, _ := setup(t) + + _, err = initRes.dbPool.Exec(ctx, "SELECT 1") + require.NoError(t, err) + }) +}