diff --git a/cmd/riverui/main.go b/cmd/riverui/main.go index fadd8bb..2c2f16e 100644 --- a/cmd/riverui/main.go +++ b/cmd/riverui/main.go @@ -29,15 +29,17 @@ func main() { if err := godotenv.Load(); err != nil { fmt.Printf("No .env file detected, using environment variables\n") } - logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + + if os.Getenv("RIVER_DEBUG") == "1" || os.Getenv("RIVER_DEBUG") == "true" { + logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + } else { + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + } os.Exit(initAndServe(ctx)) } func initAndServe(ctx context.Context) int { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - var pathPrefix string flag.StringVar(&pathPrefix, "prefix", "/", "path prefix to use for the API and UI HTTP requests") flag.Parse() @@ -82,13 +84,18 @@ func initAndServe(ctx context.Context) int { Prefix: pathPrefix, } - handler, err := riverui.NewHandler(handlerOpts) + server, err := riverui.NewServer(handlerOpts) if err != nil { logger.ErrorContext(ctx, "error creating handler", slog.String("error", err.Error())) return 1 } - logHandler := sloghttp.Recovery(handler) + if err := server.Start(ctx); err != nil { + logger.ErrorContext(ctx, "error starting UI server", slog.String("error", err.Error())) + return 1 + } + + logHandler := sloghttp.Recovery(server.Handler()) config := sloghttp.Config{ WithSpanID: otelEnabled, WithTraceID: otelEnabled, diff --git a/go.mod b/go.mod index daef98a..c882c9a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/riverqueue/riverui -go 1.22.5 +go 1.21 + +toolchain go1.22.5 require ( github.com/go-playground/validator/v10 v10.22.0 @@ -8,10 +10,11 @@ require ( github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 github.com/jackc/pgx/v5 v5.6.0 github.com/joho/godotenv v1.5.1 - github.com/riverqueue/river v0.8.0 - github.com/riverqueue/river/riverdriver v0.8.0 - github.com/riverqueue/river/riverdriver/riverpgxv5 v0.8.0 - github.com/riverqueue/river/rivertype v0.8.0 + github.com/riverqueue/river v0.11.2 + github.com/riverqueue/river/riverdriver v0.11.2 + github.com/riverqueue/river/riverdriver/riverpgxv5 v0.11.2 + github.com/riverqueue/river/rivershared v0.11.2 + github.com/riverqueue/river/rivertype v0.11.2 github.com/rs/cors v1.11.0 github.com/samber/slog-http v1.0.0 github.com/stretchr/testify v1.9.0 @@ -27,12 +30,12 @@ require ( github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/riverqueue/river/rivershared v0.11.1 // indirect go.opentelemetry.io/otel v1.19.0 // indirect go.opentelemetry.io/otel/trace v1.19.0 // indirect + go.uber.org/goleak v1.3.0 // indirect golang.org/x/crypto v0.22.0 // indirect golang.org/x/net v0.23.0 // indirect - golang.org/x/sync v0.7.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.19.0 // indirect golang.org/x/text v0.16.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 0a16865..455ce18 100644 --- a/go.sum +++ b/go.sum @@ -37,22 +37,22 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/riverqueue/river v0.8.0 
h1:IBUIP9eZX/dkLQ3T+XNNk0Zi7iyUksZd4aHxQIFChOQ= -github.com/riverqueue/river v0.8.0/go.mod h1:EHRbhqVXDpXQizFh4lndwswu53N0txITrLM2y3vOIF4= -github.com/riverqueue/river/riverdriver v0.8.0 h1:vSeIvf2Z+/hHH4QF1NK/rvzuZJeZZ+voHz55ZPf9efA= -github.com/riverqueue/river/riverdriver v0.8.0/go.mod h1:YZUVae96RsQJaAem0o0EpgD7fDNPdl/qJiuUFh/vkVE= -github.com/riverqueue/river/riverdriver/riverdatabasesql v0.8.0 h1:eH6kkU8qstq1Rj7d0PBYmptaZy6vPsea0WzhBf7/SL4= -github.com/riverqueue/river/riverdriver/riverdatabasesql v0.8.0/go.mod h1:4jXPB30TNOWSeOvNvk1Mdov4XIMTBCnIzysrdAXizzs= -github.com/riverqueue/river/riverdriver/riverpgxv5 v0.8.0 h1:9lF2GQIU0Z5gynaY6kevJwW5ycy/VbH9S/iYu0+Lf7U= -github.com/riverqueue/river/riverdriver/riverpgxv5 v0.8.0/go.mod h1:rPTUHOdsrQIEyeEesEaBzNyj0Hs4VtXGUHHPC4JwgZ0= -github.com/riverqueue/river/rivershared v0.11.1 h1:5HDZ5fPrHf68lrE2CTTTUfRfdCmfW1G6P/v0zCvor7I= -github.com/riverqueue/river/rivershared v0.11.1/go.mod h1:2egnQ7czNcW8IXKXMRjko0aEMrQzF4V3k3jddmYiihE= -github.com/riverqueue/river/rivertype v0.8.0 h1:Ys49e1AECeIOTxRquXC446uIEPXiXLMNVKD4KwexJPM= -github.com/riverqueue/river/rivertype v0.8.0/go.mod h1:nDd50b/mIdxR/ezQzGS/JiAhBPERA7tUIne21GdfspQ= +github.com/riverqueue/river v0.11.2 h1:U1f0xZ+B3qdOJSHJ8A2c93CEsFQGGkbG4ZN8blUas5g= +github.com/riverqueue/river v0.11.2/go.mod h1:0MCkMUIjwAjkKAmcWEbHP1IKWiXq+Z3iNVK5dsYVQYY= +github.com/riverqueue/river/riverdriver v0.11.2 h1:2xC+R0Y+CFEOSDWKyeFef0wqQLuvhk3PsLkos7MLa1w= +github.com/riverqueue/river/riverdriver v0.11.2/go.mod h1:RhMuAjEtNGexwOFnz445G1iFNZVOnYQ90HDYxHMI+jM= +github.com/riverqueue/river/riverdriver/riverdatabasesql v0.11.2 h1:I4ye1YEa35kqB6Jd3xVPNxbGDL6S1gpSTkZu25qffhc= +github.com/riverqueue/river/riverdriver/riverdatabasesql v0.11.2/go.mod h1:+cOcD4U+8ugUeRZVTGqVhtScy0FS7LPyp+ZsoPIeoMI= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.11.2 h1:yxFi09ECN02iAr2uO0n7QhFKAyyGZ+Rn9fzKTt2TGhk= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.11.2/go.mod h1:ajPqIw7OgYBfR24MqH3VGI/SiYVgq0DkvdM7wrs+uDA= +github.com/riverqueue/river/rivershared v0.11.2 h1:VbuLE6zm68R24xBi1elfnerhLBBn6X7DUxR9j4mcTR4= +github.com/riverqueue/river/rivershared v0.11.2/go.mod h1:J4U3qm8MbjHY1o5OlRNiWaminYagec1o8sHYX4ZQ4S4= +github.com/riverqueue/river/rivertype v0.11.2 h1:YREWOGxDMDe1DTdvttwr2DVq/ql65u6e4jkw3VxuNyU= +github.com/riverqueue/river/rivertype v0.11.2/go.mod h1:bm5EMOGAEWhtXKqo27POWnViqSD5nHMZDP/jsrJc530= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po= github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/samber/slog-http v1.0.0 h1:KjxyJm2lOsuWBt904A04qvrp+0ZvOfwDnk6jI8h7/5c= @@ -72,8 +72,8 @@ golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/sync 
v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= diff --git a/handler.go b/handler.go index 71763d6..6a46642 100644 --- a/handler.go +++ b/handler.go @@ -14,6 +14,9 @@ import ( "github.com/jackc/pgx/v5" "github.com/riverqueue/river" + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/startstop" + "github.com/riverqueue/river/rivershared/util/randutil" "github.com/riverqueue/riverui/internal/apiendpoint" "github.com/riverqueue/riverui/internal/apimiddleware" "github.com/riverqueue/riverui/internal/dbsqlc" @@ -62,8 +65,14 @@ func normalizePathPrefix(prefix string) string { return prefix } -// NewHandler creates a new http.Handler that serves the River UI and API. -func NewHandler(opts *HandlerOpts) (http.Handler, error) { +type Server struct { + baseStartStop startstop.BaseStartStop + handler http.Handler + services []startstop.Service +} + +// NewServer creates a new http.Handler that serves the River UI and API. +func NewServer(opts *HandlerOpts) (*Server, error) { if opts == nil { return nil, errors.New("opts is required") } @@ -80,6 +89,12 @@ func NewHandler(opts *HandlerOpts) (http.Handler, error) { serveIndex := serveFileContents("index.html", httpFS) apiBundle := apiBundle{ + // TODO: Switch to baseservice.NewArchetype when available. + archetype: &baseservice.Archetype{ + Logger: opts.Logger, + Rand: randutil.NewCryptoSeededConcurrentSafeRand(), + Time: &baseservice.UnStubbableTimeGenerator{}, + }, client: opts.Client, dbPool: opts.DBPool, logger: opts.Logger, @@ -88,19 +103,35 @@ func NewHandler(opts *HandlerOpts) (http.Handler, error) { prefix := opts.Prefix mux := http.NewServeMux() - apiendpoint.Mount(mux, opts.Logger, &healthCheckGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobCancelEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobDeleteEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobListEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobRetryEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &queueGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &queueListEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &queuePauseEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &queueResumeEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &stateAndCountGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &workflowGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &workflowListEndpoint{apiBundle: apiBundle}) + + endpoints := []apiendpoint.EndpointInterface{ + apiendpoint.Mount(mux, opts.Logger, newHealthCheckGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newJobCancelEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newJobDeleteEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newJobListEndpoint(apiBundle)), + apiendpoint.Mount(mux, 
opts.Logger, newJobRetryEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newJobGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newQueueGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newQueueListEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newQueuePauseEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newQueueResumeEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newStateAndCountGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newWorkflowGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newWorkflowListEndpoint(apiBundle)), + } + + var services []startstop.Service + + type WithSubServices interface { + SubServices() []startstop.Service + } + + // If any endpoints are start/stop services, start them up. + for _, endpoint := range endpoints { + if withSubServices, ok := endpoint.(WithSubServices); ok { + services = append(services, withSubServices.SubServices()...) + } + } if err := mountStaticFiles(opts.Logger, mux); err != nil { return nil, err @@ -115,7 +146,45 @@ func NewHandler(opts *HandlerOpts) (http.Handler, error) { middlewareStack.Use(&stripPrefixMiddleware{prefix}) } - return middlewareStack.Mount(mux), nil + server := &Server{ + handler: middlewareStack.Mount(mux), + services: services, + } + + return server, nil +} + +// Handler returns an http.Handler that can be mounted to serve HTTP requests. +func (s *Server) Handler() http.Handler { return s.handler } + +// Start starts the server's background services. Notably, this does _not_ cause +// the server to start listening for HTTP in any way. To do that, call Handler +// and mount or run it using Go's built in `net/http`. +func (s *Server) Start(ctx context.Context) error { + ctx, shouldStart, started, stopped := s.baseStartStop.StartInit(ctx) + if !shouldStart { + return nil + } + + for _, service := range s.services { + if err := service.Start(ctx); err != nil { + return err + } + } + + go func() { + // Wait for all subservices to start up before signaling our own start. + startstop.WaitAllStarted(s.services...) + + started() + defer stopped() // this defer should come first so it's last out + + <-ctx.Done() + + startstop.StopAllParallel(s.services...) + }() + + return nil } //go:embed public diff --git a/api_handler.go b/handler_api_endpoint.go similarity index 88% rename from api_handler.go rename to handler_api_endpoint.go index 087c7c2..d7bc55b 100644 --- a/api_handler.go +++ b/handler_api_endpoint.go @@ -14,20 +14,24 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/riverqueue/river" + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/startstop" "github.com/riverqueue/river/rivershared/util/ptrutil" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" "github.com/riverqueue/riverui/internal/apiendpoint" "github.com/riverqueue/riverui/internal/apierror" "github.com/riverqueue/riverui/internal/dbsqlc" + "github.com/riverqueue/riverui/internal/querycacher" "github.com/riverqueue/riverui/internal/util/pgxutil" ) // A bundle of common utilities needed for many API endpoints. type apiBundle struct { - client *river.Client[pgx.Tx] - dbPool DBTXWithBegin - logger *slog.Logger + archetype *baseservice.Archetype + client *river.Client[pgx.Tx] + dbPool DBTXWithBegin + logger *slog.Logger } // SetBundle sets all values to the same as the given bundle. 
@@ -35,14 +39,6 @@ func (a *apiBundle) SetBundle(bundle *apiBundle) { *a = *bundle } -// withSetBundle is an interface that's automatically implemented by types that -// embed apiBundle. It lets places like tests generically set bundle values on -// any general endpoint type. -type withSetBundle interface { - // SetBundle sets all values to the same as the given bundle. - SetBundle(bundle *apiBundle) -} - type listResponse[T any] struct { Data []*T `json:"data"` } @@ -66,6 +62,10 @@ type healthCheckGetEndpoint struct { apiendpoint.Endpoint[healthCheckGetRequest, statusResponse] } +func newHealthCheckGetEndpoint(apiBundle apiBundle) *healthCheckGetEndpoint { + return &healthCheckGetEndpoint{apiBundle: apiBundle} +} + func (*healthCheckGetEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/health-checks/{name}", @@ -118,6 +118,10 @@ type jobCancelEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, statusResponse] } +func newJobCancelEndpoint(apiBundle apiBundle) *jobCancelEndpoint { + return &jobCancelEndpoint{apiBundle: apiBundle} +} + func (*jobCancelEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "POST /api/jobs/cancel", @@ -158,6 +162,10 @@ type jobDeleteEndpoint struct { apiendpoint.Endpoint[jobDeleteRequest, statusResponse] } +func newJobDeleteEndpoint(apiBundle apiBundle) *jobDeleteEndpoint { + return &jobDeleteEndpoint{apiBundle: apiBundle} +} + func (*jobDeleteEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "POST /api/jobs/delete", @@ -198,6 +206,10 @@ type jobGetEndpoint struct { apiendpoint.Endpoint[jobGetRequest, RiverJob] } +func newJobGetEndpoint(apiBundle apiBundle) *jobGetEndpoint { + return &jobGetEndpoint{apiBundle: apiBundle} +} + func (*jobGetEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/jobs/{job_id}", @@ -243,6 +255,10 @@ type jobListEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, listResponse[RiverJob]] } +func newJobListEndpoint(apiBundle apiBundle) *jobListEndpoint { + return &jobListEndpoint{apiBundle: apiBundle} +} + func (*jobListEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/jobs", @@ -305,6 +321,10 @@ type jobRetryEndpoint struct { apiendpoint.Endpoint[jobRetryRequest, statusResponse] } +func newJobRetryEndpoint(apiBundle apiBundle) *jobRetryEndpoint { + return &jobRetryEndpoint{apiBundle: apiBundle} +} + func (*jobRetryEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "POST /api/jobs/retry", @@ -342,6 +362,10 @@ type queueGetEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, RiverQueue] } +func newQueueGetEndpoint(apiBundle apiBundle) *queueGetEndpoint { + return &queueGetEndpoint{apiBundle: apiBundle} +} + func (*queueGetEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/queues/{name}", @@ -386,6 +410,10 @@ type queueListEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, listResponse[RiverQueue]] } +func newQueueListEndpoint(apiBundle apiBundle) *queueListEndpoint { + return &queueListEndpoint{apiBundle: apiBundle} +} + func (*queueListEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/queues", @@ -437,6 +465,10 @@ type queuePauseEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, statusResponse] } +func newQueuePauseEndpoint(apiBundle apiBundle) *queuePauseEndpoint { + 
return &queuePauseEndpoint{apiBundle: apiBundle} +} + func (*queuePauseEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "PUT /api/queues/{name}/pause", @@ -475,6 +507,10 @@ type queueResumeEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, statusResponse] } +func newQueueResumeEndpoint(apiBundle apiBundle) *queueResumeEndpoint { + return &queueResumeEndpoint{apiBundle: apiBundle} +} + func (*queueResumeEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "PUT /api/queues/{name}/resume", @@ -511,6 +547,17 @@ func (a *queueResumeEndpoint) Execute(ctx context.Context, req *queueResumeReque type stateAndCountGetEndpoint struct { apiBundle apiendpoint.Endpoint[jobCancelRequest, stateAndCountGetResponse] + + queryCacheSkipThreshold int // constant normally, but settable for testing + queryCacher *querycacher.QueryCacher[[]*dbsqlc.JobCountByStateRow] +} + +func newStateAndCountGetEndpoint(apiBundle apiBundle) *stateAndCountGetEndpoint { + return &stateAndCountGetEndpoint{ + apiBundle: apiBundle, + queryCacheSkipThreshold: 1_000_000, + queryCacher: querycacher.NewQueryCacher(apiBundle.archetype, apiBundle.dbPool, dbsqlc.New().JobCountByState), + } } func (*stateAndCountGetEndpoint) Meta() *apiendpoint.EndpointMeta { @@ -520,6 +567,10 @@ func (*stateAndCountGetEndpoint) Meta() *apiendpoint.EndpointMeta { } } +func (a *stateAndCountGetEndpoint) SubServices() []startstop.Service { + return []startstop.Service{a.queryCacher} +} + type stateAndCountGetRequest struct{} type stateAndCountGetResponse struct { @@ -534,12 +585,31 @@ type stateAndCountGetResponse struct { } func (a *stateAndCountGetEndpoint) Execute(ctx context.Context, _ *stateAndCountGetRequest) (*stateAndCountGetResponse, error) { - stateAndCount, err := dbsqlc.New().JobCountByState(ctx, a.dbPool) - if err != nil { - return nil, fmt.Errorf("error getting states and counts: %w", err) + // Counts the total number of jobs in a state and count result. + totalJobs := func(stateAndCountRes []*dbsqlc.JobCountByStateRow) int { + var totalJobs int + for _, stateAndCount := range stateAndCountRes { + totalJobs += int(stateAndCount.Count) + } + return totalJobs + } + + // Counting jobs can be an expensive operation given a large table, so in + // the presence of such, prefer to use a result that's cached periodically + // instead of querying inline with the API request. In case we don't have a + // cached result yet or there's a relatively small number of job rows, run + // the query directly (in the case of the latter so we present the freshest + // possible information). 
+ stateAndCountRes, ok := a.queryCacher.CachedRes() + if !ok || totalJobs(stateAndCountRes) < a.queryCacheSkipThreshold { + var err error + stateAndCountRes, err = dbsqlc.New().JobCountByState(ctx, a.dbPool) + if err != nil { + return nil, fmt.Errorf("error getting states and counts: %w", err) + } } - stateAndCountMap := sliceutil.KeyBy(stateAndCount, func(r *dbsqlc.JobCountByStateRow) (rivertype.JobState, int) { + stateAndCountMap := sliceutil.KeyBy(stateAndCountRes, func(r *dbsqlc.JobCountByStateRow) (rivertype.JobState, int) { return rivertype.JobState(r.State), int(r.Count) }) @@ -564,6 +634,10 @@ type workflowGetEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, workflowGetResponse] } +func newWorkflowGetEndpoint(apiBundle apiBundle) *workflowGetEndpoint { + return &workflowGetEndpoint{apiBundle: apiBundle} +} + func (*workflowGetEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/workflows/{id}", @@ -612,6 +686,10 @@ type workflowListEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, listResponse[workflowListItem]] } +func newWorkflowListEndpoint(apiBundle apiBundle) *workflowListEndpoint { + return &workflowListEndpoint{apiBundle: apiBundle} +} + func (*workflowListEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/workflows", diff --git a/api_handler_test.go b/handler_api_endpoint_test.go similarity index 82% rename from api_handler_test.go rename to handler_api_endpoint_test.go index d5b6e6d..d633f16 100644 --- a/api_handler_test.go +++ b/handler_api_endpoint_test.go @@ -2,6 +2,7 @@ package riverui import ( "context" + "log/slog" "testing" "time" @@ -11,6 +12,8 @@ import ( "github.com/riverqueue/river" "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/startstop" "github.com/riverqueue/river/rivershared/util/ptrutil" "github.com/riverqueue/river/rivertype" "github.com/riverqueue/riverui/internal/apierror" @@ -21,30 +24,35 @@ import ( type setupEndpointTestBundle struct { client *river.Client[pgx.Tx] exec riverdriver.ExecutorTx + logger *slog.Logger tx pgx.Tx } -func setupEndpoint[TEndpoint any](ctx context.Context, t *testing.T) (*TEndpoint, *setupEndpointTestBundle) { +func setupEndpoint[TEndpoint any](ctx context.Context, t *testing.T, initFunc func(apiBundle apiBundle) *TEndpoint) (*TEndpoint, *setupEndpointTestBundle) { t.Helper() var ( - endpoint TEndpoint logger = riverinternaltest.Logger(t) client, driver = insertOnlyClient(t, logger) tx = riverinternaltest.TestTx(ctx, t) ) - if withSetBundle, ok := any(&endpoint).(withSetBundle); ok { - withSetBundle.SetBundle(&apiBundle{ - client: client, - dbPool: tx, - logger: logger, - }) + endpoint := initFunc(apiBundle{ + archetype: riversharedtest.BaseServiceArchetype(t), + client: client, + dbPool: tx, + logger: logger, + }) + + if service, ok := any(endpoint).(startstop.Service); ok { + require.NoError(t, service.Start(ctx)) + t.Cleanup(service.Stop) } - return &endpoint, &setupEndpointTestBundle{ + return endpoint, &setupEndpointTestBundle{ client: client, exec: driver.UnwrapExecutor(tx), + logger: logger, tx: tx, } } @@ -57,7 +65,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { t.Run("CompleteSuccess", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[healthCheckGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) resp, err := endpoint.Execute(ctx, 
&healthCheckGetRequest{Name: healthCheckNameComplete}) require.NoError(t, err) @@ -67,7 +75,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { t.Run("CompleteDatabaseError", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[healthCheckGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) // Roll back prematurely so we get a database error. require.NoError(t, bundle.tx.Rollback(ctx)) @@ -82,7 +90,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { t.Run("Minimal", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[healthCheckGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) resp, err := endpoint.Execute(ctx, &healthCheckGetRequest{Name: healthCheckNameMinimal}) require.NoError(t, err) @@ -92,7 +100,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[healthCheckGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) _, err := endpoint.Execute(ctx, &healthCheckGetRequest{Name: "other"}) requireAPIError(t, apierror.NewNotFound("Health check %q not found. Use either `complete` or `minimal`.", "other"), err) @@ -107,7 +115,7 @@ func TestJobCancelEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobCancelEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobCancelEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -128,7 +136,7 @@ func TestJobCancelEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[jobCancelEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newJobCancelEndpoint) _, err := endpoint.Execute(ctx, &jobCancelRequest{JobIDs: []int64String{123}}) requireAPIError(t, apierror.NewNotFoundJob(123), err) @@ -143,7 +151,7 @@ func TestJobDeleteEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobDeleteEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobDeleteEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -162,7 +170,7 @@ func TestJobDeleteEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[jobDeleteEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newJobDeleteEndpoint) _, err := endpoint.Execute(ctx, &jobDeleteRequest{JobIDs: []int64String{123}}) requireAPIError(t, apierror.NewNotFoundJob(123), err) @@ -177,7 +185,7 @@ func TestJobGetEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobGetEndpoint) job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -189,7 +197,7 @@ func TestJobGetEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[jobGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newJobGetEndpoint) _, err := endpoint.Execute(ctx, &jobGetRequest{JobID: 123}) requireAPIError(t, apierror.NewNotFoundJob(123), err) @@ -204,7 +212,7 @@ func TestAPIHandlerJobList(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := 
setupEndpoint[jobListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRunning)}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRunning)}) @@ -227,7 +235,7 @@ func TestAPIHandlerJobList(t *testing.T) { t.Run("Limit", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRunning)}) _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -241,7 +249,7 @@ func TestAPIHandlerJobList(t *testing.T) { t.Run("FiltersFinalizedStatesAndOrdersDescending", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateCompleted), FinalizedAt: ptrutil.Ptr(time.Now())}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateCompleted), FinalizedAt: ptrutil.Ptr(time.Now())}) @@ -259,7 +267,7 @@ func TestAPIHandlerJobList(t *testing.T) { t.Run("FiltersNonFinalizedStates", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -283,7 +291,7 @@ func TestJobRetryEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobRetryEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobRetryEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ FinalizedAt: ptrutil.Ptr(time.Now()), @@ -310,7 +318,7 @@ func TestJobRetryEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[jobRetryEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newJobRetryEndpoint) _, err := endpoint.Execute(ctx, &jobRetryRequest{JobIDs: []int64String{123}}) requireAPIError(t, apierror.NewNotFoundJob(123), err) @@ -325,7 +333,7 @@ func TestAPIHandlerQueueGet(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[queueGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueueGetEndpoint) queue := testfactory.Queue(ctx, t, bundle.exec, nil) @@ -341,7 +349,7 @@ func TestAPIHandlerQueueGet(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[queueGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newQueueGetEndpoint) _, err := endpoint.Execute(ctx, &queueGetRequest{Name: "does_not_exist"}) requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) @@ -356,7 +364,7 @@ func TestAPIHandlerQueueList(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[queueListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueueListEndpoint) queue1 := testfactory.Queue(ctx, t, bundle.exec, nil) queue2 := testfactory.Queue(ctx, t, bundle.exec, nil) @@ -376,7 +384,7 @@ func TestAPIHandlerQueueList(t *testing.T) { t.Run("Limit", func(t *testing.T) 
{ t.Parallel() - endpoint, bundle := setupEndpoint[queueListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueueListEndpoint) queue1 := testfactory.Queue(ctx, t, bundle.exec, nil) _ = testfactory.Queue(ctx, t, bundle.exec, nil) @@ -396,7 +404,7 @@ func TestAPIHandlerQueuePause(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[queuePauseEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueuePauseEndpoint) queue := testfactory.Queue(ctx, t, bundle.exec, nil) @@ -408,7 +416,7 @@ func TestAPIHandlerQueuePause(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[queuePauseEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newQueuePauseEndpoint) _, err := endpoint.Execute(ctx, &queuePauseRequest{Name: "does_not_exist"}) requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) @@ -423,7 +431,7 @@ func TestAPIHandlerQueueResume(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[queueResumeEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueueResumeEndpoint) queue := testfactory.Queue(ctx, t, bundle.exec, &testfactory.QueueOpts{ PausedAt: ptrutil.Ptr(time.Now()), @@ -437,7 +445,7 @@ func TestAPIHandlerQueueResume(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[queueResumeEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newQueueResumeEndpoint) _, err := endpoint.Execute(ctx, &queueResumeRequest{Name: "does_not_exist"}) requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) @@ -452,35 +460,35 @@ func TestStateAndCountGetEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[stateAndCountGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newStateAndCountGetEndpoint) _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateAvailable)}) - for range 2 { + for i := 0; i < 2; i++ { _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateCancelled), FinalizedAt: ptrutil.Ptr(time.Now())}) } - for range 3 { + for i := 0; i < 3; i++ { _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateCompleted), FinalizedAt: ptrutil.Ptr(time.Now())}) } - for range 4 { + for i := 0; i < 4; i++ { _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateDiscarded), FinalizedAt: ptrutil.Ptr(time.Now())}) } - for range 5 { + for i := 0; i < 5; i++ { _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStatePending)}) } - for range 6 { + for i := 0; i < 6; i++ { _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRetryable)}) } - for range 7 { + for i := 0; i < 7; i++ { _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRunning)}) } - for range 8 { + for i := 0; i < 8; i++ { _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateScheduled)}) } @@ -497,6 +505,46 @@ func TestStateAndCountGetEndpoint(t *testing.T) { Scheduled: 8, }, resp) }) + + t.Run("WithCachedQueryAboveSkipThreshold", func(t *testing.T) { + t.Parallel() + + endpoint, bundle := setupEndpoint(ctx, t, newStateAndCountGetEndpoint) + + const queryCacheSkipThreshold = 
3 + for i := 0; i < queryCacheSkipThreshold+1; i++ { + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateAvailable)}) + } + + _, err := endpoint.queryCacher.RunQuery(ctx) + require.NoError(t, err) + + resp, err := endpoint.Execute(ctx, &stateAndCountGetRequest{}) + require.NoError(t, err) + require.Equal(t, &stateAndCountGetResponse{ + Available: queryCacheSkipThreshold + 1, + }, resp) + }) + + t.Run("WithCachedQueryBelowSkipThreshold", func(t *testing.T) { + t.Parallel() + + endpoint, bundle := setupEndpoint(ctx, t, newStateAndCountGetEndpoint) + + const queryCacheSkipThreshold = 3 + for i := 0; i < queryCacheSkipThreshold-1; i++ { + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateAvailable)}) + } + + _, err := endpoint.queryCacher.RunQuery(ctx) + require.NoError(t, err) + + resp, err := endpoint.Execute(ctx, &stateAndCountGetRequest{}) + require.NoError(t, err) + require.Equal(t, &stateAndCountGetResponse{ + Available: queryCacheSkipThreshold - 1, + }, resp) + }) } func TestAPIHandlerWorkflowGet(t *testing.T) { @@ -507,7 +555,7 @@ func TestAPIHandlerWorkflowGet(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[workflowGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newWorkflowGetEndpoint) workflowID := uuid.New() job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: mustMarshalJSON(t, map[string]uuid.UUID{"workflow_id": workflowID})}) @@ -523,7 +571,7 @@ func TestAPIHandlerWorkflowGet(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[workflowGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newWorkflowGetEndpoint) workflowID := uuid.New() @@ -540,7 +588,7 @@ func TestAPIHandlerWorkflowList(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[workflowListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newWorkflowListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ Metadata: []byte(`{"workflow_id":"1", "workflow_name":"first_wf", "task":"a"}`), @@ -610,7 +658,7 @@ func TestAPIHandlerWorkflowList(t *testing.T) { t.Run("Limit", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[workflowListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newWorkflowListEndpoint) _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ Metadata: []byte(`{"workflow_id":"1", "workflow_name":"first_wf", "task":"a"}`), diff --git a/handler_test.go b/handler_test.go index 303e877..0f7cb74 100644 --- a/handler_test.go +++ b/handler_test.go @@ -45,7 +45,7 @@ func TestNewHandlerIntegration(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { tx.Rollback(ctx) }) - handler, err := NewHandler(&HandlerOpts{ + server, err := NewServer(&HandlerOpts{ Client: client, DBPool: tx, Logger: logger, @@ -62,7 +62,7 @@ func TestNewHandlerIntegration(t *testing.T) { t.Logf("--> %s %s", method, path) - handler.ServeHTTP(recorder, req) + server.Handler().ServeHTTP(recorder, req) status := recorder.Result().StatusCode //nolint:bodyclose t.Logf("Response status: %d", status) diff --git a/internal/apiendpoint/api_endpoint.go b/internal/apiendpoint/api_endpoint.go index 8f7c753..f440d6d 100644 --- a/internal/apiendpoint/api_endpoint.go +++ b/internal/apiendpoint/api_endpoint.go @@ -33,15 +33,7 @@ type Endpoint[TReq any, TResp any] struct { func (e 
*Endpoint[TReq, TResp]) SetLogger(logger *slog.Logger) { e.logger = logger } func (e *Endpoint[TReq, TResp]) SetMeta(meta *EndpointMeta) { e.meta = meta } -// EndpointInterface is an interface to an API endpoint. Some of it is -// implemented by an embedded Endpoint struct, and some of it should be -// implemented by the endpoint itself. -type EndpointInterface[TReq any, TResp any] interface { - // Execute executes the API endpoint. - // - // This should be implemented by each specific API endpoint. - Execute(ctx context.Context, req *TReq) (*TResp, error) - +type EndpointInterface interface { // Meta returns metadata about an API endpoint, like the path it should be // mounted at, and the status code it returns on success. // @@ -60,6 +52,18 @@ type EndpointInterface[TReq any, TResp any] interface { SetMeta(meta *EndpointMeta) } +// EndpointExecuteInterface is an interface to an API endpoint. Some of it is +// implemented by an embedded Endpoint struct, and some of it should be +// implemented by the endpoint itself. +type EndpointExecuteInterface[TReq any, TResp any] interface { + EndpointInterface + + // Execute executes the API endpoint. + // + // This should be implemented by each specific API endpoint. + Execute(ctx context.Context, req *TReq) (*TResp, error) +} + // EndpointMeta is metadata about an API endpoint. type EndpointMeta struct { // Pattern is the API endpoint's HTTP method and path where it should be @@ -84,7 +88,7 @@ func (m *EndpointMeta) validate() { // Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log // information about endpoint execution. -func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointInterface[TReq, TResp]) { +func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointExecuteInterface[TReq, TResp]) EndpointInterface { apiEndpoint.SetLogger(logger) meta := apiEndpoint.Meta() @@ -94,10 +98,12 @@ func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndp mux.HandleFunc(meta.Pattern, func(w http.ResponseWriter, r *http.Request) { executeAPIEndpoint(w, r, logger, meta, apiEndpoint.Execute) }) + + return apiEndpoint } func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, execute func(ctx context.Context, req *TReq) (*TResp, error)) { - ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) defer cancel() // Run as much code as we can in a sub-function that can return an error. 
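For embedders, the net effect of the NewHandler → NewServer change mirrors the cmd/riverui/main.go hunk above: construct the server, start its background services, then mount its handler with net/http. The following is a minimal sketch only, not part of the diff; it assumes the host application already has a pgx pool and River client (startUI, dbPool, client, and the ":8080" address are illustrative names):

    package main

    import (
        "context"
        "log/slog"
        "net/http"
        "os"

        "github.com/jackc/pgx/v5"
        "github.com/jackc/pgx/v5/pgxpool"
        "github.com/riverqueue/river"
        "github.com/riverqueue/riverui"
    )

    // startUI wires the new Server API the same way cmd/riverui/main.go does:
    // NewServer, then Start for background services, then Handler for net/http.
    func startUI(ctx context.Context, dbPool *pgxpool.Pool, client *river.Client[pgx.Tx]) error {
        logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

        server, err := riverui.NewServer(&riverui.HandlerOpts{
            Client: client,
            DBPool: dbPool,
            Logger: logger,
            Prefix: "/",
        })
        if err != nil {
            return err
        }

        // Start launches background services (e.g. the periodic count cache);
        // it does not listen for HTTP itself. Cancelling ctx stops them again.
        if err := server.Start(ctx); err != nil {
            return err
        }

        return http.ListenAndServe(":8080", server.Handler())
    }
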
diff --git a/internal/dbsqlc/query.sql b/internal/dbsqlc/query.sql index bba9af3..9392812 100644 --- a/internal/dbsqlc/query.sql +++ b/internal/dbsqlc/query.sql @@ -1,14 +1,6 @@ -- name: JobCountByState :many -SELECT - state, - count(*) -FROM - river_job -WHERE - queue IS NOT NULL AND - priority > 0 AND - scheduled_at IS NOT NULL AND - id IS NOT NULL +SELECT state, count(*) +FROM river_job GROUP BY state; -- name: JobListWorkflow :many diff --git a/internal/dbsqlc/query.sql.go b/internal/dbsqlc/query.sql.go index 7349a23..a63a459 100644 --- a/internal/dbsqlc/query.sql.go +++ b/internal/dbsqlc/query.sql.go @@ -77,16 +77,8 @@ func (q *Queries) JobCountByQueueAndState(ctx context.Context, db DBTX, queueNam } const jobCountByState = `-- name: JobCountByState :many -SELECT - state, - count(*) -FROM - river_job -WHERE - queue IS NOT NULL AND - priority > 0 AND - scheduled_at IS NOT NULL AND - id IS NOT NULL +SELECT state, count(*) +FROM river_job GROUP BY state ` diff --git a/internal/querycacher/query_cacher.go b/internal/querycacher/query_cacher.go new file mode 100644 index 0000000..812aaba --- /dev/null +++ b/internal/querycacher/query_cacher.go @@ -0,0 +1,144 @@ +package querycacher + +import ( + "context" + "regexp" + "sync" + "time" + + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/startstop" + "github.com/riverqueue/riverui/internal/dbsqlc" +) + +// QueryCacher executes a database query periodically and caches the result. The +// basic premise is that given large River databases even simple queries like +// counting job rows can get quite slow. This construct operates as a background +// service that runs a query, stores the result, making it available to API +// endpoints so that they don't need to execute the query in band. +type QueryCacher[TRes any] struct { + baseservice.BaseService + startstop.BaseStartStop + + cachedRes TRes + cachedResSet bool + db dbsqlc.DBTX + mu sync.RWMutex + runQuery func(ctx context.Context, dbtx dbsqlc.DBTX) (TRes, error) + runQueryTestChan chan struct{} // closed when query is run; for testing + tickPeriod time.Duration // constant normally, but settable for testing +} + +func NewQueryCacher[TRes any](archetype *baseservice.Archetype, db dbsqlc.DBTX, runQuery func(ctx context.Context, db dbsqlc.DBTX) (TRes, error)) *QueryCacher[TRes] { + // +/- 1s random variance to ticker interval. Makes sure that given multiple + // query caches running simultaneously, they all start and are scheduled a + // little differently to make a thundering herd problem less likely. + randomTickVariance := time.Duration(archetype.Rand.Float64()*float64(2*time.Second)) - 1*time.Second + + queryCacher := baseservice.Init(archetype, &QueryCacher[TRes]{ + db: db, + runQuery: runQuery, + tickPeriod: 10*time.Second + randomTickVariance, + }) + + // TODO(brandur): Push this up into baseservice. + queryCacher.Name = simplifyArchetypeLogName(queryCacher.Name) + + return queryCacher +} + +// CachedRes returns cached results, if there are any, and an "ok" boolean +// indicating whether cached results were available (true if so, and false +// otherwise). +func (s *QueryCacher[TRes]) CachedRes() (TRes, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + if !s.cachedResSet { + var emptyRes TRes + return emptyRes, false + } + + return s.cachedRes, true +} + +// RunQuery runs the internal query function and caches the result. 
It's not +// usually necessary to call this function explicitly since Start will do it +// periodically, but is made available for use in places like helping with +// testing. +func (s *QueryCacher[TRes]) RunQuery(ctx context.Context) (TRes, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + start := time.Now() + + res, err := s.runQuery(ctx, s.db) + if err != nil { + var emptyRes TRes + return emptyRes, err + } + + s.Logger.DebugContext(ctx, s.Name+": Ran query and cached result", "duration", time.Since(start), "tick_period", s.tickPeriod) + + s.mu.Lock() + s.cachedRes = res + s.cachedResSet = true + s.mu.Unlock() + + // Tells a test that it can wake up and handle a result. + if s.runQueryTestChan != nil { + close(s.runQueryTestChan) + s.runQueryTestChan = nil + } + + return res, nil +} + +// Start starts the service, causing it to periodically run its query and cache +// the result. It stops when Stop is called or if its context is cancelled. +func (s *QueryCacher[TRes]) Start(ctx context.Context) error { + ctx, shouldStart, started, stopped := s.StartInit(ctx) + if !shouldStart { + return nil + } + + go func() { + started() + defer stopped() + + // In case a query runs long and exceeds tickPeriod, time.Ticker will + // drop ticks to compensate. + ticker := time.NewTicker(s.tickPeriod) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + + case <-ticker.C: + if _, err := s.RunQuery(ctx); err != nil { + s.Logger.ErrorContext(ctx, s.Name+": Error running query", "err", err) + } + } + } + }() + + return nil +} + +// Simplifies the name of a Go type that uses generics for cleaner logging output. +// +// So this: +// +// QueryCacher[[]*github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow] +// +// Becomes this: +// +// QueryCacher[[]*dbsqlc.JobCountByStateRow] +// +// TODO(brandur): Push this up into baseservice. 
+func simplifyArchetypeLogName(name string) string { + re := regexp.MustCompile(`\[([\[\]\*]*).*/([^/]+)\]`) + return re.ReplaceAllString(name, `[$1$2]`) +} diff --git a/internal/querycacher/query_cacher_test.go b/internal/querycacher/query_cacher_test.go new file mode 100644 index 0000000..d7459df --- /dev/null +++ b/internal/querycacher/query_cacher_test.go @@ -0,0 +1,125 @@ +package querycacher + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/startstoptest" + "github.com/riverqueue/riverui/internal/dbsqlc" + "github.com/riverqueue/riverui/internal/riverinternaltest" + "github.com/riverqueue/riverui/internal/riverinternaltest/testfactory" +) + +func TestQueryCacher(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + exec riverdriver.ExecutorTx + } + + setup := func(ctx context.Context, t *testing.T) (*QueryCacher[[]*dbsqlc.JobCountByStateRow], *testBundle) { + t.Helper() + + var ( + archetype = riversharedtest.BaseServiceArchetype(t) + driver = riverpgxv5.New(nil) + tx = riverinternaltest.TestTx(ctx, t) + queryCacher = NewQueryCacher(archetype, tx, dbsqlc.New().JobCountByState) + ) + + return queryCacher, &testBundle{ + exec: driver.UnwrapExecutor(tx), + } + } + + start := func(ctx context.Context, t *testing.T, queryCacher *QueryCacher[[]*dbsqlc.JobCountByStateRow]) { + t.Helper() + + require.NoError(t, queryCacher.Start(ctx)) + t.Cleanup(queryCacher.Stop) + } + + t.Run("NoCachedResult", func(t *testing.T) { + t.Parallel() + + queryCacher, _ := setup(ctx, t) + + _, ok := queryCacher.CachedRes() + require.False(t, ok) + }) + + t.Run("WithCachedResult", func(t *testing.T) { + t.Parallel() + + queryCacher, bundle := setup(ctx, t) + + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) + + _, err := queryCacher.RunQuery(ctx) + require.NoError(t, err) + + res, ok := queryCacher.CachedRes() + require.True(t, ok) + require.Equal(t, []*dbsqlc.JobCountByStateRow{ + {State: dbsqlc.RiverJobStateAvailable, Count: 1}, + }, res) + }) + + t.Run("RunsPeriodically", func(t *testing.T) { + t.Parallel() + + queryCacher, bundle := setup(ctx, t) + + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) + + runQueryTestChan := make(chan struct{}) + queryCacher.runQueryTestChan = runQueryTestChan + + // Dramatically reduce tick period so we don't have to wait the full time. + queryCacher.tickPeriod = 1 * time.Millisecond + + start(ctx, t, queryCacher) + + riversharedtest.WaitOrTimeout(t, runQueryTestChan) + + res, ok := queryCacher.CachedRes() + require.True(t, ok) + require.Equal(t, []*dbsqlc.JobCountByStateRow{ + {State: dbsqlc.RiverJobStateAvailable, Count: 1}, + }, res) + }) + + t.Run("StartStopStress", func(t *testing.T) { + t.Parallel() + + queryCacher, _ := setup(ctx, t) + startstoptest.Stress(ctx, t, queryCacher) + }) +} + +func TestSimplifyArchetypeLogName(t *testing.T) { + t.Parallel() + + require.Equal(t, "NotGeneric", simplifyArchetypeLogName("NotGeneric")) + + // Simplified for use during debugging. Real generics will tend to have + // fully qualified paths and not look like this. 
+ require.Equal(t, "Simple[int]", simplifyArchetypeLogName("Simple[int]")) + require.Equal(t, "Simple[*int]", simplifyArchetypeLogName("Simple[*int]")) + require.Equal(t, "Simple[[]int]", simplifyArchetypeLogName("Simple[[]int]")) + require.Equal(t, "Simple[[]*int]", simplifyArchetypeLogName("Simple[[]*int]")) + + // More realistic examples. + require.Equal(t, "QueryCacher[dbsqlc.JobCountByStateRow]", simplifyArchetypeLogName("QueryCacher[github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow]")) + require.Equal(t, "QueryCacher[*dbsqlc.JobCountByStateRow]", simplifyArchetypeLogName("QueryCacher[*github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow]")) + require.Equal(t, "QueryCacher[[]dbsqlc.JobCountByStateRow]", simplifyArchetypeLogName("QueryCacher[[]github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow]")) + require.Equal(t, "QueryCacher[[]*dbsqlc.JobCountByStateRow]", simplifyArchetypeLogName("QueryCacher[[]*github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow]")) +}
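
As a reading aid, the new query cacher is consumed with the same pattern newStateAndCountGetEndpoint and its SubServices hook use above: construct, start, then read the cached result with a fallback to a direct query. A sketch under the assumption of an existing archetype and a pgx pool satisfying dbsqlc.DBTX (cachedCounts and the parameter names are illustrative, not part of the diff):

    package example

    import (
        "context"

        "github.com/riverqueue/river/rivershared/baseservice"
        "github.com/riverqueue/riverui/internal/dbsqlc"
        "github.com/riverqueue/riverui/internal/querycacher"
    )

    func cachedCounts(ctx context.Context, archetype *baseservice.Archetype, dbPool dbsqlc.DBTX) error {
        // One cacher per query; JobCountByState is the query the
        // state-and-count endpoint caches in this diff.
        cacher := querycacher.NewQueryCacher(archetype, dbPool, dbsqlc.New().JobCountByState)

        // Start kicks off the ~10s (±1s jitter) refresh loop; Stop or context
        // cancellation shuts it down.
        if err := cacher.Start(ctx); err != nil {
            return err
        }
        defer cacher.Stop()

        // Callers use whatever was cached last, falling back to a direct
        // query (as the endpoint does) when nothing has been cached yet.
        if rows, ok := cacher.CachedRes(); ok {
            for _, row := range rows {
                _ = row.State // dbsqlc.RiverJobState
                _ = row.Count
            }
        }
        return nil
    }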