diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dac1a1a..38cae2e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,9 +1,6 @@ name: CI env: - # A suitable URL for the test database. - DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_dev?sslmode=disable - # Test database. TEST_DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_test?sslmode=disable diff --git a/cmd/riverui/logger.go b/cmd/riverui/logger.go deleted file mode 100644 index 5e4caa8..0000000 --- a/cmd/riverui/logger.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -import ( - "log/slog" - "os" - "strings" -) - -var logger *slog.Logger //nolint:gochecknoglobals - -func initLogger() { - options := &slog.HandlerOptions{Level: getLogLevel()} - logger = slog.New(slog.NewTextHandler(os.Stdout, options)) -} - -func getLogLevel() slog.Level { - debugEnv := os.Getenv("RIVER_DEBUG") - if debugEnv == "1" || debugEnv == "true" { - return slog.LevelDebug - } - - env := strings.ToLower(os.Getenv("RIVER_LOG_LEVEL")) - - switch env { - case "debug": - return slog.LevelDebug - case "warn": - return slog.LevelWarn - case "error": - return slog.LevelError - default: - return slog.LevelInfo - } -} diff --git a/cmd/riverui/main.go b/cmd/riverui/main.go index 7c4d06a..bf0cc51 100644 --- a/cmd/riverui/main.go +++ b/cmd/riverui/main.go @@ -6,7 +6,6 @@ import ( "errors" "flag" "fmt" - "log" "log/slog" "net/http" "os" @@ -26,25 +25,59 @@ import ( func main() { ctx := context.Background() - initLogger() - os.Exit(initAndServe(ctx)) -} -func initAndServe(ctx context.Context) int { - var ( - devMode bool - liveFS bool - pathPrefix string - ) - _, liveFS = os.LookupEnv("LIVE_FS") - _, devMode = os.LookupEnv("DEV") + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: getLogLevel(), + })) + var pathPrefix string flag.StringVar(&pathPrefix, "prefix", "/", "path prefix to use for the API and UI HTTP requests") flag.Parse() + initRes, err := initServer(ctx, logger, pathPrefix) + if err != nil { + logger.ErrorContext(ctx, "Error initializing server", slog.String("error", err.Error())) + os.Exit(1) + } + + if err := startAndListen(ctx, logger, initRes); err != nil { + logger.ErrorContext(ctx, "Error starting server", slog.String("error", err.Error())) + os.Exit(1) + } +} + +// Translates either a "1" or "true" from env to a Go boolean. +func envBooleanTrue(val string) bool { + return val == "1" || val == "true" +} + +func getLogLevel() slog.Level { + if envBooleanTrue(os.Getenv("RIVER_DEBUG")) { + return slog.LevelDebug + } + + switch strings.ToLower(os.Getenv("RIVER_LOG_LEVEL")) { + case "debug": + return slog.LevelDebug + case "warn": + return slog.LevelWarn + case "error": + return slog.LevelError + default: + return slog.LevelInfo + } +} + +type initServerResult struct { + dbPool *pgxpool.Pool // database pool; close must be deferred by caller! + httpServer *http.Server // HTTP server wrapping the UI server + logger *slog.Logger // application logger (also internalized in UI server) + uiServer *riverui.Server // River UI server +} + +func initServer(ctx context.Context, logger *slog.Logger, pathPrefix string) (*initServerResult, error) { if !strings.HasPrefix(pathPrefix, "/") || pathPrefix == "" { - logger.ErrorContext(ctx, "invalid path prefix", slog.String("prefix", pathPrefix)) - return 1 + return nil, fmt.Errorf("invalid path prefix: %s", pathPrefix) } pathPrefix = riverui.NormalizePathPrefix(pathPrefix) @@ -52,43 +85,43 @@ func initAndServe(ctx context.Context) int { basicAuthUsername = os.Getenv("RIVER_BASIC_AUTH_USER") basicAuthPassword = os.Getenv("RIVER_BASIC_AUTH_PASS") corsOrigins = strings.Split(os.Getenv("CORS_ORIGINS"), ",") - dbURL = mustEnv("DATABASE_URL") + databaseURL = os.Getenv("DATABASE_URL") + devMode = envBooleanTrue(os.Getenv("DEV")) host = os.Getenv("RIVER_HOST") // may be left empty to bind to all local interfaces - otelEnabled = os.Getenv("OTEL_ENABLED") == "true" + liveFS = envBooleanTrue(os.Getenv("LIVE_FS")) + otelEnabled = envBooleanTrue(os.Getenv("OTEL_ENABLED")) port = cmp.Or(os.Getenv("PORT"), "8080") ) - dbPool, err := getDBPool(ctx, dbURL) + if databaseURL == "" && os.Getenv("PGDATABASE") == "" { + return nil, errors.New("expect to have DATABASE_URL or database configuration in standard PG* env vars like PGDATABASE/PGHOST/PGPORT/PGUSER/PGPASSWORD") + } + + poolConfig, err := pgxpool.ParseConfig(databaseURL) + if err != nil { + return nil, fmt.Errorf("error parsing db config: %w", err) + } + + dbPool, err := pgxpool.NewWithConfig(ctx, poolConfig) if err != nil { - logger.ErrorContext(ctx, "error connecting to db", slog.String("error", err.Error())) - return 1 + return nil, fmt.Errorf("error connecting to db: %w", err) } - defer dbPool.Close() client, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{}) if err != nil { - logger.ErrorContext(ctx, "error creating river client", slog.String("error", err.Error())) - return 1 + return nil, err } - handlerOpts := &riverui.ServerOpts{ + uiServer, err := riverui.NewServer(&riverui.ServerOpts{ Client: client, DB: dbPool, DevMode: devMode, LiveFS: liveFS, Logger: logger, Prefix: pathPrefix, - } - - server, err := riverui.NewServer(handlerOpts) + }) if err != nil { - logger.ErrorContext(ctx, "error creating handler", slog.String("error", err.Error())) - return 1 - } - - if err = server.Start(ctx); err != nil { - logger.ErrorContext(ctx, "error starting UI server", slog.String("error", err.Error())) - return 1 + return nil, err } corsHandler := cors.New(cors.Options{ @@ -109,40 +142,30 @@ func initAndServe(ctx context.Context) int { middlewareStack.Use(&authMiddleware{username: basicAuthUsername, password: basicAuthPassword}) } - srv := &http.Server{ - Addr: host + ":" + port, - Handler: middlewareStack.Mount(server), - ReadHeaderTimeout: 5 * time.Second, - } + return &initServerResult{ + dbPool: dbPool, + httpServer: &http.Server{ + Addr: host + ":" + port, + Handler: middlewareStack.Mount(uiServer), + ReadHeaderTimeout: 5 * time.Second, + }, + logger: logger, + uiServer: uiServer, + }, nil +} - log.Printf("starting server on %s", srv.Addr) +func startAndListen(ctx context.Context, logger *slog.Logger, initRes *initServerResult) error { + defer initRes.dbPool.Close() - if err = srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - logger.ErrorContext(ctx, "error from ListenAndServe", slog.String("error", err.Error())) - return 1 + if err := initRes.uiServer.Start(ctx); err != nil { + return err } - return 0 -} + logger.InfoContext(ctx, "Starting server", slog.String("addr", initRes.httpServer.Addr)) -func getDBPool(ctx context.Context, dbURL string) (*pgxpool.Pool, error) { - poolConfig, err := pgxpool.ParseConfig(dbURL) - if err != nil { - return nil, fmt.Errorf("error parsing db config: %w", err) + if err := initRes.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err } - dbPool, err := pgxpool.NewWithConfig(ctx, poolConfig) - if err != nil { - return nil, fmt.Errorf("error connecting to db: %w", err) - } - return dbPool, nil -} - -func mustEnv(name string) string { - val := os.Getenv(name) - if val == "" { - logger.Error("missing required env var", slog.String("name", name)) - os.Exit(1) - } - return val + return nil } diff --git a/cmd/riverui/main_test.go b/cmd/riverui/main_test.go new file mode 100644 index 0000000..4ced74a --- /dev/null +++ b/cmd/riverui/main_test.go @@ -0,0 +1,61 @@ +package main + +import ( + "cmp" + "context" + "net/url" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/rivershared/riversharedtest" +) + +func TestInitServer(t *testing.T) { + var ( + ctx = context.Background() + databaseURL = cmp.Or(os.Getenv("TEST_DATABASE_URL"), "postgres://localhost/river_test") + ) + + t.Setenv("DEV", "true") + + type testBundle struct{} + + setup := func(t *testing.T) (*initServerResult, *testBundle) { + t.Helper() + + initRes, err := initServer(ctx, riversharedtest.Logger(t), "/") + require.NoError(t, err) + t.Cleanup(initRes.dbPool.Close) + + return initRes, &testBundle{} + } + + t.Run("WithDatabaseURL", func(t *testing.T) { + t.Setenv("DATABASE_URL", databaseURL) + + initRes, _ := setup(t) + + _, err := initRes.dbPool.Exec(ctx, "SELECT 1") + require.NoError(t, err) + }) + + t.Run("WithPGEnvVars", func(t *testing.T) { + parsedURL, err := url.Parse(databaseURL) + require.NoError(t, err) + + t.Setenv("PGDATABASE", parsedURL.Path[1:]) + t.Setenv("PGHOST", parsedURL.Hostname()) + pass, _ := parsedURL.User.Password() + t.Setenv("PGPASSWORD", pass) + t.Setenv("PGPORT", cmp.Or(parsedURL.Port(), "5432")) + t.Setenv("PGSSLMODE", parsedURL.Query().Get("sslmode")) + t.Setenv("PGUSER", parsedURL.User.Username()) + + initRes, _ := setup(t) + + _, err = initRes.dbPool.Exec(ctx, "SELECT 1") + require.NoError(t, err) + }) +}