Skip to content

Commit

Permalink
Allow PG* env vars as an alternative to DATABASE_URL
Browse files Browse the repository at this point in the history
Accept the standard set of `PG*` env vars as an alternative to database
configuration (e.g. `PGHOST`, `PGDATABASE`, etc.). This is mostly driven
by having done something similar for the CLI in [1], but was also
requested in #249. This turns out to be quite easy to do because pgx
does all the heavy lifting.

As noted in [1], a bonus of this is that it adds some additional
configuration options that aren't very easily doable right now, for
example around the use of an SSL certificate to connect to Postgres. We
get automatic support for these vars:

* `PGSSLCERT`
* `PGSSLKEY`
* `PGSSLROOTCERT`
* `PGSSLPASSWORD`

As part of this I also ended up rearranging some things in `main.go`.
Not strongly married to this design, but the idea is to get it into a
place where we can write tests for it, which previously wasn't possible.

[1] riverqueue/river#702
  • Loading branch information
brandur committed Dec 24, 2024
1 parent ba06b39 commit 008ede2
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 96 deletions.
34 changes: 0 additions & 34 deletions cmd/riverui/logger.go

This file was deleted.

147 changes: 85 additions & 62 deletions cmd/riverui/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"flag"
"fmt"
"log"
"log/slog"
"net/http"
"os"
Expand All @@ -26,69 +25,103 @@ import (

func main() {
ctx := context.Background()
initLogger()
os.Exit(initAndServe(ctx))
}

func initAndServe(ctx context.Context) int {
var (
devMode bool
liveFS bool
pathPrefix string
)
_, liveFS = os.LookupEnv("LIVE_FS")
_, devMode = os.LookupEnv("DEV")
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: getLogLevel(),
}))

var pathPrefix string
flag.StringVar(&pathPrefix, "prefix", "/", "path prefix to use for the API and UI HTTP requests")
flag.Parse()

initRes, err := initServer(ctx, logger, pathPrefix)
if err != nil {
logger.ErrorContext(ctx, "Error initializing server", slog.String("error", err.Error()))
os.Exit(1)
}

if err := startAndListen(ctx, logger, initRes); err != nil {
logger.ErrorContext(ctx, "Error starting server", slog.String("error", err.Error()))
os.Exit(1)
}
}

// Translates either a "1" or "true" from env to a Go boolean.
func envBooleanTrue(val string) bool {
return val == "1" || val == "true"
}

func getLogLevel() slog.Level {
if envBooleanTrue(os.Getenv("RIVER_DEBUG")) {
return slog.LevelDebug
}

switch strings.ToLower(os.Getenv("RIVER_LOG_LEVEL")) {
case "debug":
return slog.LevelDebug
case "warn":
return slog.LevelWarn
case "error":
return slog.LevelError
default:
return slog.LevelInfo
}
}

type initServerResult struct {
dbPool *pgxpool.Pool // database pool; close must be deferred by caller!
httpServer *http.Server // HTTP server wrapping the UI server
logger *slog.Logger // application logger (also internalized in UI server)
uiServer *riverui.Server // River UI server
}

func initServer(ctx context.Context, logger *slog.Logger, pathPrefix string) (*initServerResult, error) {
if !strings.HasPrefix(pathPrefix, "/") || pathPrefix == "" {
logger.ErrorContext(ctx, "invalid path prefix", slog.String("prefix", pathPrefix))
return 1
return nil, fmt.Errorf("invalid path prefix: %s", pathPrefix)
}
pathPrefix = riverui.NormalizePathPrefix(pathPrefix)

var (
basicAuthUsername = os.Getenv("RIVER_BASIC_AUTH_USER")
basicAuthPassword = os.Getenv("RIVER_BASIC_AUTH_PASS")
corsOrigins = strings.Split(os.Getenv("CORS_ORIGINS"), ",")
dbURL = mustEnv("DATABASE_URL")
databaseURL = os.Getenv("DATABASE_URL")
devMode = envBooleanTrue(os.Getenv("DEV"))
host = os.Getenv("RIVER_HOST") // may be left empty to bind to all local interfaces
otelEnabled = os.Getenv("OTEL_ENABLED") == "true"
liveFS = envBooleanTrue(os.Getenv("LIVE_FS"))
otelEnabled = envBooleanTrue(os.Getenv("OTEL_ENABLED"))
port = cmp.Or(os.Getenv("PORT"), "8080")
)

dbPool, err := getDBPool(ctx, dbURL)
if databaseURL == "" && os.Getenv("PGDATABASE") == "" {
return nil, errors.New("expect to have DATABASE_URL or database configuration in standard PG* env vars like PGDATABASE/PGHOST/PGPORT/PGUSER/PGPASSWORD")
}

poolConfig, err := pgxpool.ParseConfig(databaseURL)
if err != nil {
return nil, fmt.Errorf("error parsing db config: %w", err)
}

dbPool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
logger.ErrorContext(ctx, "error connecting to db", slog.String("error", err.Error()))
return 1
return nil, fmt.Errorf("error connecting to db: %w", err)
}
defer dbPool.Close()

client, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{})
if err != nil {
logger.ErrorContext(ctx, "error creating river client", slog.String("error", err.Error()))
return 1
return nil, err
}

handlerOpts := &riverui.ServerOpts{
uiServer, err := riverui.NewServer(&riverui.ServerOpts{
Client: client,
DB: dbPool,
DevMode: devMode,
LiveFS: liveFS,
Logger: logger,
Prefix: pathPrefix,
}

server, err := riverui.NewServer(handlerOpts)
})
if err != nil {
logger.ErrorContext(ctx, "error creating handler", slog.String("error", err.Error()))
return 1
}

if err = server.Start(ctx); err != nil {
logger.ErrorContext(ctx, "error starting UI server", slog.String("error", err.Error()))
return 1
return nil, err
}

corsHandler := cors.New(cors.Options{
Expand All @@ -109,40 +142,30 @@ func initAndServe(ctx context.Context) int {
middlewareStack.Use(&authMiddleware{username: basicAuthUsername, password: basicAuthPassword})
}

srv := &http.Server{
Addr: host + ":" + port,
Handler: middlewareStack.Mount(server),
ReadHeaderTimeout: 5 * time.Second,
}
return &initServerResult{
dbPool: dbPool,
httpServer: &http.Server{
Addr: host + ":" + port,
Handler: middlewareStack.Mount(uiServer),
ReadHeaderTimeout: 5 * time.Second,
},
logger: logger,
uiServer: uiServer,
}, nil
}

log.Printf("starting server on %s", srv.Addr)
func startAndListen(ctx context.Context, logger *slog.Logger, initRes *initServerResult) error {
defer initRes.dbPool.Close()

if err = srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.ErrorContext(ctx, "error from ListenAndServe", slog.String("error", err.Error()))
return 1
if err := initRes.uiServer.Start(ctx); err != nil {
return err
}

return 0
}
logger.InfoContext(ctx, "Starting server", slog.String("addr", initRes.httpServer.Addr))

func getDBPool(ctx context.Context, dbURL string) (*pgxpool.Pool, error) {
poolConfig, err := pgxpool.ParseConfig(dbURL)
if err != nil {
return nil, fmt.Errorf("error parsing db config: %w", err)
if err := initRes.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}

dbPool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, fmt.Errorf("error connecting to db: %w", err)
}
return dbPool, nil
}

func mustEnv(name string) string {
val := os.Getenv(name)
if val == "" {
logger.Error("missing required env var", slog.String("name", name))
os.Exit(1)
}
return val
return nil
}
61 changes: 61 additions & 0 deletions cmd/riverui/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package main

import (
"cmp"
"context"
"net/url"
"os"
"testing"

"github.com/stretchr/testify/require"

"github.com/riverqueue/river/rivershared/riversharedtest"
)

func TestInitServer(t *testing.T) {
var (
ctx = context.Background()
databaseURL = cmp.Or(os.Getenv("TEST_DATABASE_URL"), "postgres://localhost/river_test")
)

t.Setenv("DEV", "true")

type testBundle struct{}

setup := func(t *testing.T) (*initServerResult, *testBundle) {
t.Helper()

initRes, err := initServer(ctx, riversharedtest.Logger(t), "/")
require.NoError(t, err)
t.Cleanup(initRes.dbPool.Close)

return initRes, &testBundle{}
}

t.Run("WithDatabaseURL", func(t *testing.T) {
t.Setenv("DATABASE_URL", databaseURL)

initRes, _ := setup(t)

_, err := initRes.dbPool.Exec(ctx, "SELECT 1")
require.NoError(t, err)
})

t.Run("WithPGEnvVars", func(t *testing.T) {
parsedURL, err := url.Parse(databaseURL)
require.NoError(t, err)

t.Setenv("PGDATABASE", parsedURL.Path[1:])
t.Setenv("PGHOST", parsedURL.Hostname())
pass, _ := parsedURL.User.Password()
t.Setenv("PGPASSWORD", pass)
t.Setenv("PGPORT", cmp.Or(parsedURL.Port(), "5432"))
t.Setenv("PGSSLMODE", parsedURL.Query().Get("sslmode"))
t.Setenv("PGUSER", parsedURL.User.Username())

initRes, _ := setup(t)

_, err = initRes.dbPool.Exec(ctx, "SELECT 1")
require.NoError(t, err)
})
}

0 comments on commit 008ede2

Please sign in to comment.