From 33f2b33eb087036e17327ceed81389e210788f20 Mon Sep 17 00:00:00 2001 From: Brandur Date: Tue, 6 Aug 2024 17:41:58 -0700 Subject: [PATCH] API-side query caching for job counts This one's related to trying to find a solution for #106. After messing around with query plans a lot on the "count by state" query, I came to the conclusion in the end that Postgres might actually be doing the right thing by falling back to a sequential scan, or at least only the minimally wrong thing. Even forcing the count query to run against a well-used index is a fairly slow operation when there are many jobs in the database. It's hard to provide specifics because caching affects the result so much (so running the same query twice in a row can produce vastly different timings), but I've seen the index version take _longer_ than the seq scan in some cases. So here, I'm proposing a radically different solution in which we add some infrastructure to the River UI API server that lets it run slow queries periodically in the background, then have API endpoints take advantage of those cached results instead of having to run each operation themselves, thereby making their responses ~instant. I've written it such that this caching only kicks in when we know we're working with a very large data set where it actually matters (currently defined as > 1M rows), with the idea being that for smaller databases we'll continue to run queries in-band so that results look as fresh and real-time as possible. To support this, I've had to make some changes to the River UI API server/handler so that it has a `Start` function that can be invoked to start background utilities like the query cache. 
It's a considerable change, but I think it leaves us in a more sustainable place API-wise because we may want to add other background utilities later on, and returning an `http.Handler` isn't enough because even if you were to start goroutines from `NewHandler`, it's very, very not ideal that there's no way to stop those goroutines again (problematic for anything that wants to check for leaks with goleak). I'm also going to propose that we increase the default API endpoint timeout from 5 seconds to 10 seconds. When I load in 3 to 5 million job rows, I see count queries taking right around that 3 to 5 seconds range. Since the original number of 5 seconds was a little arbitrary anyway, it can't hurt to give those queries a little more leeway. A problem that could still occur even with my proposal here is that if a user starts River UI and then immediately hits the UI, there won't be a cached result yet, and therefore the count query will go to the database directly, and that may still cause a timeout at 5 seconds. I've only applied caching to the count query so far, but I've written the `QueryCacher` code such that it can cleanly support other queries if we care to add them. 
--- cmd/riverui/main.go | 19 ++- go.mod | 17 ++- go.sum | 32 ++-- handler.go | 101 ++++++++++-- api_handler.go => handler_api_endpoint.go | 108 +++++++++++-- ...er_test.go => handler_api_endpoint_test.go | 124 ++++++++++----- handler_test.go | 4 +- internal/apiendpoint/api_endpoint.go | 28 ++-- internal/dbsqlc/query.sql | 12 +- internal/dbsqlc/query.sql.go | 12 +- internal/querycacher/query_cacher.go | 144 ++++++++++++++++++ internal/querycacher/query_cacher_test.go | 125 +++++++++++++++ 12 files changed, 595 insertions(+), 131 deletions(-) rename api_handler.go => handler_api_endpoint.go (88%) rename api_handler_test.go => handler_api_endpoint_test.go (83%) create mode 100644 internal/querycacher/query_cacher.go create mode 100644 internal/querycacher/query_cacher_test.go diff --git a/cmd/riverui/main.go b/cmd/riverui/main.go index fadd8bb..2c2f16e 100644 --- a/cmd/riverui/main.go +++ b/cmd/riverui/main.go @@ -29,15 +29,17 @@ func main() { if err := godotenv.Load(); err != nil { fmt.Printf("No .env file detected, using environment variables\n") } - logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + + if os.Getenv("RIVER_DEBUG") == "1" || os.Getenv("RIVER_DEBUG") == "true" { + logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + } else { + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + } os.Exit(initAndServe(ctx)) } func initAndServe(ctx context.Context) int { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - var pathPrefix string flag.StringVar(&pathPrefix, "prefix", "/", "path prefix to use for the API and UI HTTP requests") flag.Parse() @@ -82,13 +84,18 @@ func initAndServe(ctx context.Context) int { Prefix: pathPrefix, } - handler, err := riverui.NewHandler(handlerOpts) + server, err := riverui.NewServer(handlerOpts) if err != nil { logger.ErrorContext(ctx, "error creating handler", slog.String("error", err.Error())) return 1 } - logHandler := sloghttp.Recovery(handler) + if err := 
server.Start(ctx); err != nil { + logger.ErrorContext(ctx, "error starting UI server", slog.String("error", err.Error())) + return 1 + } + + logHandler := sloghttp.Recovery(server.Handler()) config := sloghttp.Config{ WithSpanID: otelEnabled, WithTraceID: otelEnabled, diff --git a/go.mod b/go.mod index daef98a..c882c9a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/riverqueue/riverui -go 1.22.5 +go 1.21 + +toolchain go1.22.5 require ( github.com/go-playground/validator/v10 v10.22.0 @@ -8,10 +10,11 @@ require ( github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 github.com/jackc/pgx/v5 v5.6.0 github.com/joho/godotenv v1.5.1 - github.com/riverqueue/river v0.8.0 - github.com/riverqueue/river/riverdriver v0.8.0 - github.com/riverqueue/river/riverdriver/riverpgxv5 v0.8.0 - github.com/riverqueue/river/rivertype v0.8.0 + github.com/riverqueue/river v0.11.2 + github.com/riverqueue/river/riverdriver v0.11.2 + github.com/riverqueue/river/riverdriver/riverpgxv5 v0.11.2 + github.com/riverqueue/river/rivershared v0.11.2 + github.com/riverqueue/river/rivertype v0.11.2 github.com/rs/cors v1.11.0 github.com/samber/slog-http v1.0.0 github.com/stretchr/testify v1.9.0 @@ -27,12 +30,12 @@ require ( github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/riverqueue/river/rivershared v0.11.1 // indirect go.opentelemetry.io/otel v1.19.0 // indirect go.opentelemetry.io/otel/trace v1.19.0 // indirect + go.uber.org/goleak v1.3.0 // indirect golang.org/x/crypto v0.22.0 // indirect golang.org/x/net v0.23.0 // indirect - golang.org/x/sync v0.7.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.19.0 // indirect golang.org/x/text v0.16.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 0a16865..455ce18 100644 --- a/go.sum +++ b/go.sum @@ -37,22 +37,22 @@ github.com/lib/pq v1.10.9 
h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/riverqueue/river v0.8.0 h1:IBUIP9eZX/dkLQ3T+XNNk0Zi7iyUksZd4aHxQIFChOQ= -github.com/riverqueue/river v0.8.0/go.mod h1:EHRbhqVXDpXQizFh4lndwswu53N0txITrLM2y3vOIF4= -github.com/riverqueue/river/riverdriver v0.8.0 h1:vSeIvf2Z+/hHH4QF1NK/rvzuZJeZZ+voHz55ZPf9efA= -github.com/riverqueue/river/riverdriver v0.8.0/go.mod h1:YZUVae96RsQJaAem0o0EpgD7fDNPdl/qJiuUFh/vkVE= -github.com/riverqueue/river/riverdriver/riverdatabasesql v0.8.0 h1:eH6kkU8qstq1Rj7d0PBYmptaZy6vPsea0WzhBf7/SL4= -github.com/riverqueue/river/riverdriver/riverdatabasesql v0.8.0/go.mod h1:4jXPB30TNOWSeOvNvk1Mdov4XIMTBCnIzysrdAXizzs= -github.com/riverqueue/river/riverdriver/riverpgxv5 v0.8.0 h1:9lF2GQIU0Z5gynaY6kevJwW5ycy/VbH9S/iYu0+Lf7U= -github.com/riverqueue/river/riverdriver/riverpgxv5 v0.8.0/go.mod h1:rPTUHOdsrQIEyeEesEaBzNyj0Hs4VtXGUHHPC4JwgZ0= -github.com/riverqueue/river/rivershared v0.11.1 h1:5HDZ5fPrHf68lrE2CTTTUfRfdCmfW1G6P/v0zCvor7I= -github.com/riverqueue/river/rivershared v0.11.1/go.mod h1:2egnQ7czNcW8IXKXMRjko0aEMrQzF4V3k3jddmYiihE= -github.com/riverqueue/river/rivertype v0.8.0 h1:Ys49e1AECeIOTxRquXC446uIEPXiXLMNVKD4KwexJPM= -github.com/riverqueue/river/rivertype v0.8.0/go.mod h1:nDd50b/mIdxR/ezQzGS/JiAhBPERA7tUIne21GdfspQ= +github.com/riverqueue/river v0.11.2 h1:U1f0xZ+B3qdOJSHJ8A2c93CEsFQGGkbG4ZN8blUas5g= +github.com/riverqueue/river v0.11.2/go.mod h1:0MCkMUIjwAjkKAmcWEbHP1IKWiXq+Z3iNVK5dsYVQYY= +github.com/riverqueue/river/riverdriver v0.11.2 h1:2xC+R0Y+CFEOSDWKyeFef0wqQLuvhk3PsLkos7MLa1w= +github.com/riverqueue/river/riverdriver v0.11.2/go.mod h1:RhMuAjEtNGexwOFnz445G1iFNZVOnYQ90HDYxHMI+jM= +github.com/riverqueue/river/riverdriver/riverdatabasesql v0.11.2 
h1:I4ye1YEa35kqB6Jd3xVPNxbGDL6S1gpSTkZu25qffhc= +github.com/riverqueue/river/riverdriver/riverdatabasesql v0.11.2/go.mod h1:+cOcD4U+8ugUeRZVTGqVhtScy0FS7LPyp+ZsoPIeoMI= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.11.2 h1:yxFi09ECN02iAr2uO0n7QhFKAyyGZ+Rn9fzKTt2TGhk= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.11.2/go.mod h1:ajPqIw7OgYBfR24MqH3VGI/SiYVgq0DkvdM7wrs+uDA= +github.com/riverqueue/river/rivershared v0.11.2 h1:VbuLE6zm68R24xBi1elfnerhLBBn6X7DUxR9j4mcTR4= +github.com/riverqueue/river/rivershared v0.11.2/go.mod h1:J4U3qm8MbjHY1o5OlRNiWaminYagec1o8sHYX4ZQ4S4= +github.com/riverqueue/river/rivertype v0.11.2 h1:YREWOGxDMDe1DTdvttwr2DVq/ql65u6e4jkw3VxuNyU= +github.com/riverqueue/river/rivertype v0.11.2/go.mod h1:bm5EMOGAEWhtXKqo27POWnViqSD5nHMZDP/jsrJc530= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po= github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/samber/slog-http v1.0.0 h1:KjxyJm2lOsuWBt904A04qvrp+0ZvOfwDnk6jI8h7/5c= @@ -72,8 +72,8 @@ golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync 
v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= diff --git a/handler.go b/handler.go index 71763d6..6a46642 100644 --- a/handler.go +++ b/handler.go @@ -14,6 +14,9 @@ import ( "github.com/jackc/pgx/v5" "github.com/riverqueue/river" + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/startstop" + "github.com/riverqueue/river/rivershared/util/randutil" "github.com/riverqueue/riverui/internal/apiendpoint" "github.com/riverqueue/riverui/internal/apimiddleware" "github.com/riverqueue/riverui/internal/dbsqlc" @@ -62,8 +65,14 @@ func normalizePathPrefix(prefix string) string { return prefix } -// NewHandler creates a new http.Handler that serves the River UI and API. -func NewHandler(opts *HandlerOpts) (http.Handler, error) { +type Server struct { + baseStartStop startstop.BaseStartStop + handler http.Handler + services []startstop.Service +} + +// NewServer creates a new http.Handler that serves the River UI and API. +func NewServer(opts *HandlerOpts) (*Server, error) { if opts == nil { return nil, errors.New("opts is required") } @@ -80,6 +89,12 @@ func NewHandler(opts *HandlerOpts) (http.Handler, error) { serveIndex := serveFileContents("index.html", httpFS) apiBundle := apiBundle{ + // TODO: Switch to baseservice.NewArchetype when available. 
+ archetype: &baseservice.Archetype{ + Logger: opts.Logger, + Rand: randutil.NewCryptoSeededConcurrentSafeRand(), + Time: &baseservice.UnStubbableTimeGenerator{}, + }, client: opts.Client, dbPool: opts.DBPool, logger: opts.Logger, @@ -88,19 +103,35 @@ func NewHandler(opts *HandlerOpts) (http.Handler, error) { prefix := opts.Prefix mux := http.NewServeMux() - apiendpoint.Mount(mux, opts.Logger, &healthCheckGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobCancelEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobDeleteEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobListEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobRetryEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &jobGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &queueGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &queueListEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &queuePauseEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &queueResumeEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &stateAndCountGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &workflowGetEndpoint{apiBundle: apiBundle}) - apiendpoint.Mount(mux, opts.Logger, &workflowListEndpoint{apiBundle: apiBundle}) + + endpoints := []apiendpoint.EndpointInterface{ + apiendpoint.Mount(mux, opts.Logger, newHealthCheckGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newJobCancelEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newJobDeleteEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newJobListEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newJobRetryEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newJobGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newQueueGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, 
opts.Logger, newQueueListEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newQueuePauseEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newQueueResumeEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newStateAndCountGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newWorkflowGetEndpoint(apiBundle)), + apiendpoint.Mount(mux, opts.Logger, newWorkflowListEndpoint(apiBundle)), + } + + var services []startstop.Service + + type WithSubServices interface { + SubServices() []startstop.Service + } + + // If any endpoints are start/stop services, start them up. + for _, endpoint := range endpoints { + if withSubServices, ok := endpoint.(WithSubServices); ok { + services = append(services, withSubServices.SubServices()...) + } + } if err := mountStaticFiles(opts.Logger, mux); err != nil { return nil, err @@ -115,7 +146,45 @@ func NewHandler(opts *HandlerOpts) (http.Handler, error) { middlewareStack.Use(&stripPrefixMiddleware{prefix}) } - return middlewareStack.Mount(mux), nil + server := &Server{ + handler: middlewareStack.Mount(mux), + services: services, + } + + return server, nil +} + +// Handler returns an http.Handler that can be mounted to serve HTTP requests. +func (s *Server) Handler() http.Handler { return s.handler } + +// Start starts the server's background services. Notably, this does _not_ cause +// the server to start listening for HTTP in any way. To do that, call Handler +// and mount or run it using Go's built in `net/http`. +func (s *Server) Start(ctx context.Context) error { + ctx, shouldStart, started, stopped := s.baseStartStop.StartInit(ctx) + if !shouldStart { + return nil + } + + for _, service := range s.services { + if err := service.Start(ctx); err != nil { + return err + } + } + + go func() { + // Wait for all subservices to start up before signaling our own start. + startstop.WaitAllStarted(s.services...) 
+ + started() + defer stopped() // this defer should come first so it's last out + + <-ctx.Done() + + startstop.StopAllParallel(s.services...) + }() + + return nil } //go:embed public diff --git a/api_handler.go b/handler_api_endpoint.go similarity index 88% rename from api_handler.go rename to handler_api_endpoint.go index 087c7c2..d7bc55b 100644 --- a/api_handler.go +++ b/handler_api_endpoint.go @@ -14,20 +14,24 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/riverqueue/river" + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/startstop" "github.com/riverqueue/river/rivershared/util/ptrutil" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" "github.com/riverqueue/riverui/internal/apiendpoint" "github.com/riverqueue/riverui/internal/apierror" "github.com/riverqueue/riverui/internal/dbsqlc" + "github.com/riverqueue/riverui/internal/querycacher" "github.com/riverqueue/riverui/internal/util/pgxutil" ) // A bundle of common utilities needed for many API endpoints. type apiBundle struct { - client *river.Client[pgx.Tx] - dbPool DBTXWithBegin - logger *slog.Logger + archetype *baseservice.Archetype + client *river.Client[pgx.Tx] + dbPool DBTXWithBegin + logger *slog.Logger } // SetBundle sets all values to the same as the given bundle. @@ -35,14 +39,6 @@ func (a *apiBundle) SetBundle(bundle *apiBundle) { *a = *bundle } -// withSetBundle is an interface that's automatically implemented by types that -// embed apiBundle. It lets places like tests generically set bundle values on -// any general endpoint type. -type withSetBundle interface { - // SetBundle sets all values to the same as the given bundle. 
- SetBundle(bundle *apiBundle) -} - type listResponse[T any] struct { Data []*T `json:"data"` } @@ -66,6 +62,10 @@ type healthCheckGetEndpoint struct { apiendpoint.Endpoint[healthCheckGetRequest, statusResponse] } +func newHealthCheckGetEndpoint(apiBundle apiBundle) *healthCheckGetEndpoint { + return &healthCheckGetEndpoint{apiBundle: apiBundle} +} + func (*healthCheckGetEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/health-checks/{name}", @@ -118,6 +118,10 @@ type jobCancelEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, statusResponse] } +func newJobCancelEndpoint(apiBundle apiBundle) *jobCancelEndpoint { + return &jobCancelEndpoint{apiBundle: apiBundle} +} + func (*jobCancelEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "POST /api/jobs/cancel", @@ -158,6 +162,10 @@ type jobDeleteEndpoint struct { apiendpoint.Endpoint[jobDeleteRequest, statusResponse] } +func newJobDeleteEndpoint(apiBundle apiBundle) *jobDeleteEndpoint { + return &jobDeleteEndpoint{apiBundle: apiBundle} +} + func (*jobDeleteEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "POST /api/jobs/delete", @@ -198,6 +206,10 @@ type jobGetEndpoint struct { apiendpoint.Endpoint[jobGetRequest, RiverJob] } +func newJobGetEndpoint(apiBundle apiBundle) *jobGetEndpoint { + return &jobGetEndpoint{apiBundle: apiBundle} +} + func (*jobGetEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/jobs/{job_id}", @@ -243,6 +255,10 @@ type jobListEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, listResponse[RiverJob]] } +func newJobListEndpoint(apiBundle apiBundle) *jobListEndpoint { + return &jobListEndpoint{apiBundle: apiBundle} +} + func (*jobListEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/jobs", @@ -305,6 +321,10 @@ type jobRetryEndpoint struct { 
apiendpoint.Endpoint[jobRetryRequest, statusResponse] } +func newJobRetryEndpoint(apiBundle apiBundle) *jobRetryEndpoint { + return &jobRetryEndpoint{apiBundle: apiBundle} +} + func (*jobRetryEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "POST /api/jobs/retry", @@ -342,6 +362,10 @@ type queueGetEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, RiverQueue] } +func newQueueGetEndpoint(apiBundle apiBundle) *queueGetEndpoint { + return &queueGetEndpoint{apiBundle: apiBundle} +} + func (*queueGetEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/queues/{name}", @@ -386,6 +410,10 @@ type queueListEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, listResponse[RiverQueue]] } +func newQueueListEndpoint(apiBundle apiBundle) *queueListEndpoint { + return &queueListEndpoint{apiBundle: apiBundle} +} + func (*queueListEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/queues", @@ -437,6 +465,10 @@ type queuePauseEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, statusResponse] } +func newQueuePauseEndpoint(apiBundle apiBundle) *queuePauseEndpoint { + return &queuePauseEndpoint{apiBundle: apiBundle} +} + func (*queuePauseEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "PUT /api/queues/{name}/pause", @@ -475,6 +507,10 @@ type queueResumeEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, statusResponse] } +func newQueueResumeEndpoint(apiBundle apiBundle) *queueResumeEndpoint { + return &queueResumeEndpoint{apiBundle: apiBundle} +} + func (*queueResumeEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "PUT /api/queues/{name}/resume", @@ -511,6 +547,17 @@ func (a *queueResumeEndpoint) Execute(ctx context.Context, req *queueResumeReque type stateAndCountGetEndpoint struct { apiBundle apiendpoint.Endpoint[jobCancelRequest, stateAndCountGetResponse] + 
+ queryCacheSkipThreshold int // constant normally, but settable for testing + queryCacher *querycacher.QueryCacher[[]*dbsqlc.JobCountByStateRow] +} + +func newStateAndCountGetEndpoint(apiBundle apiBundle) *stateAndCountGetEndpoint { + return &stateAndCountGetEndpoint{ + apiBundle: apiBundle, + queryCacheSkipThreshold: 1_000_000, + queryCacher: querycacher.NewQueryCacher(apiBundle.archetype, apiBundle.dbPool, dbsqlc.New().JobCountByState), + } } func (*stateAndCountGetEndpoint) Meta() *apiendpoint.EndpointMeta { @@ -520,6 +567,10 @@ func (*stateAndCountGetEndpoint) Meta() *apiendpoint.EndpointMeta { } } +func (a *stateAndCountGetEndpoint) SubServices() []startstop.Service { + return []startstop.Service{a.queryCacher} +} + type stateAndCountGetRequest struct{} type stateAndCountGetResponse struct { @@ -534,12 +585,31 @@ type stateAndCountGetResponse struct { } func (a *stateAndCountGetEndpoint) Execute(ctx context.Context, _ *stateAndCountGetRequest) (*stateAndCountGetResponse, error) { - stateAndCount, err := dbsqlc.New().JobCountByState(ctx, a.dbPool) - if err != nil { - return nil, fmt.Errorf("error getting states and counts: %w", err) + // Counts the total number of jobs in a state and count result. + totalJobs := func(stateAndCountRes []*dbsqlc.JobCountByStateRow) int { + var totalJobs int + for _, stateAndCount := range stateAndCountRes { + totalJobs += int(stateAndCount.Count) + } + return totalJobs + } + + // Counting jobs can be an expensive operation given a large table, so in + // the presence of such, prefer to use a result that's cached periodically + // instead of querying inline with the API request. In case we don't have a + // cached result yet or there's a relatively small number of job rows, run + // the query directly (in the case of the latter so we present the freshest + // possible information). 
+ stateAndCountRes, ok := a.queryCacher.CachedRes() + if !ok || totalJobs(stateAndCountRes) < a.queryCacheSkipThreshold { + var err error + stateAndCountRes, err = dbsqlc.New().JobCountByState(ctx, a.dbPool) + if err != nil { + return nil, fmt.Errorf("error getting states and counts: %w", err) + } } - stateAndCountMap := sliceutil.KeyBy(stateAndCount, func(r *dbsqlc.JobCountByStateRow) (rivertype.JobState, int) { + stateAndCountMap := sliceutil.KeyBy(stateAndCountRes, func(r *dbsqlc.JobCountByStateRow) (rivertype.JobState, int) { return rivertype.JobState(r.State), int(r.Count) }) @@ -564,6 +634,10 @@ type workflowGetEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, workflowGetResponse] } +func newWorkflowGetEndpoint(apiBundle apiBundle) *workflowGetEndpoint { + return &workflowGetEndpoint{apiBundle: apiBundle} +} + func (*workflowGetEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/workflows/{id}", @@ -612,6 +686,10 @@ type workflowListEndpoint struct { apiendpoint.Endpoint[jobCancelRequest, listResponse[workflowListItem]] } +func newWorkflowListEndpoint(apiBundle apiBundle) *workflowListEndpoint { + return &workflowListEndpoint{apiBundle: apiBundle} +} + func (*workflowListEndpoint) Meta() *apiendpoint.EndpointMeta { return &apiendpoint.EndpointMeta{ Pattern: "GET /api/workflows", diff --git a/api_handler_test.go b/handler_api_endpoint_test.go similarity index 83% rename from api_handler_test.go rename to handler_api_endpoint_test.go index d5b6e6d..ebce95b 100644 --- a/api_handler_test.go +++ b/handler_api_endpoint_test.go @@ -2,6 +2,7 @@ package riverui import ( "context" + "log/slog" "testing" "time" @@ -11,6 +12,8 @@ import ( "github.com/riverqueue/river" "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/startstop" "github.com/riverqueue/river/rivershared/util/ptrutil" "github.com/riverqueue/river/rivertype" 
"github.com/riverqueue/riverui/internal/apierror" @@ -21,30 +24,35 @@ import ( type setupEndpointTestBundle struct { client *river.Client[pgx.Tx] exec riverdriver.ExecutorTx + logger *slog.Logger tx pgx.Tx } -func setupEndpoint[TEndpoint any](ctx context.Context, t *testing.T) (*TEndpoint, *setupEndpointTestBundle) { +func setupEndpoint[TEndpoint any](ctx context.Context, t *testing.T, initFunc func(apiBundle apiBundle) *TEndpoint) (*TEndpoint, *setupEndpointTestBundle) { t.Helper() var ( - endpoint TEndpoint logger = riverinternaltest.Logger(t) client, driver = insertOnlyClient(t, logger) tx = riverinternaltest.TestTx(ctx, t) ) - if withSetBundle, ok := any(&endpoint).(withSetBundle); ok { - withSetBundle.SetBundle(&apiBundle{ - client: client, - dbPool: tx, - logger: logger, - }) + endpoint := initFunc(apiBundle{ + archetype: riversharedtest.BaseServiceArchetype(t), + client: client, + dbPool: tx, + logger: logger, + }) + + if service, ok := any(endpoint).(startstop.Service); ok { + require.NoError(t, service.Start(ctx)) + t.Cleanup(service.Stop) } - return &endpoint, &setupEndpointTestBundle{ + return endpoint, &setupEndpointTestBundle{ client: client, exec: driver.UnwrapExecutor(tx), + logger: logger, tx: tx, } } @@ -57,7 +65,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { t.Run("CompleteSuccess", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[healthCheckGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) resp, err := endpoint.Execute(ctx, &healthCheckGetRequest{Name: healthCheckNameComplete}) require.NoError(t, err) @@ -67,7 +75,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { t.Run("CompleteDatabaseError", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[healthCheckGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) // Roll back prematurely so we get a database error. 
require.NoError(t, bundle.tx.Rollback(ctx)) @@ -82,7 +90,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { t.Run("Minimal", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[healthCheckGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) resp, err := endpoint.Execute(ctx, &healthCheckGetRequest{Name: healthCheckNameMinimal}) require.NoError(t, err) @@ -92,7 +100,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[healthCheckGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) _, err := endpoint.Execute(ctx, &healthCheckGetRequest{Name: "other"}) requireAPIError(t, apierror.NewNotFound("Health check %q not found. Use either `complete` or `minimal`.", "other"), err) @@ -107,7 +115,7 @@ func TestJobCancelEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobCancelEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobCancelEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -128,7 +136,7 @@ func TestJobCancelEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[jobCancelEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newJobCancelEndpoint) _, err := endpoint.Execute(ctx, &jobCancelRequest{JobIDs: []int64String{123}}) requireAPIError(t, apierror.NewNotFoundJob(123), err) @@ -143,7 +151,7 @@ func TestJobDeleteEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobDeleteEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobDeleteEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -162,7 
+170,7 @@ func TestJobDeleteEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[jobDeleteEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newJobDeleteEndpoint) _, err := endpoint.Execute(ctx, &jobDeleteRequest{JobIDs: []int64String{123}}) requireAPIError(t, apierror.NewNotFoundJob(123), err) @@ -177,7 +185,7 @@ func TestJobGetEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobGetEndpoint) job := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -189,7 +197,7 @@ func TestJobGetEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[jobGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newJobGetEndpoint) _, err := endpoint.Execute(ctx, &jobGetRequest{JobID: 123}) requireAPIError(t, apierror.NewNotFoundJob(123), err) @@ -204,7 +212,7 @@ func TestAPIHandlerJobList(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRunning)}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRunning)}) @@ -227,7 +235,7 @@ func TestAPIHandlerJobList(t *testing.T) { t.Run("Limit", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateRunning)}) _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -241,7 +249,7 @@ func TestAPIHandlerJobList(t *testing.T) { t.Run("FiltersFinalizedStatesAndOrdersDescending", 
func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateCompleted), FinalizedAt: ptrutil.Ptr(time.Now())}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateCompleted), FinalizedAt: ptrutil.Ptr(time.Now())}) @@ -259,7 +267,7 @@ func TestAPIHandlerJobList(t *testing.T) { t.Run("FiltersNonFinalizedStates", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) job2 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) @@ -283,7 +291,7 @@ func TestJobRetryEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[jobRetryEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newJobRetryEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ FinalizedAt: ptrutil.Ptr(time.Now()), @@ -310,7 +318,7 @@ func TestJobRetryEndpoint(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[jobRetryEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newJobRetryEndpoint) _, err := endpoint.Execute(ctx, &jobRetryRequest{JobIDs: []int64String{123}}) requireAPIError(t, apierror.NewNotFoundJob(123), err) @@ -325,7 +333,7 @@ func TestAPIHandlerQueueGet(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[queueGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueueGetEndpoint) queue := testfactory.Queue(ctx, t, bundle.exec, nil) @@ -341,7 +349,7 @@ func TestAPIHandlerQueueGet(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := 
setupEndpoint[queueGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newQueueGetEndpoint) _, err := endpoint.Execute(ctx, &queueGetRequest{Name: "does_not_exist"}) requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) @@ -356,7 +364,7 @@ func TestAPIHandlerQueueList(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[queueListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueueListEndpoint) queue1 := testfactory.Queue(ctx, t, bundle.exec, nil) queue2 := testfactory.Queue(ctx, t, bundle.exec, nil) @@ -376,7 +384,7 @@ func TestAPIHandlerQueueList(t *testing.T) { t.Run("Limit", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[queueListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueueListEndpoint) queue1 := testfactory.Queue(ctx, t, bundle.exec, nil) _ = testfactory.Queue(ctx, t, bundle.exec, nil) @@ -396,7 +404,7 @@ func TestAPIHandlerQueuePause(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[queuePauseEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueuePauseEndpoint) queue := testfactory.Queue(ctx, t, bundle.exec, nil) @@ -408,7 +416,7 @@ func TestAPIHandlerQueuePause(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[queuePauseEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newQueuePauseEndpoint) _, err := endpoint.Execute(ctx, &queuePauseRequest{Name: "does_not_exist"}) requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) @@ -423,7 +431,7 @@ func TestAPIHandlerQueueResume(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[queueResumeEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newQueueResumeEndpoint) queue := testfactory.Queue(ctx, t, bundle.exec, &testfactory.QueueOpts{ PausedAt: ptrutil.Ptr(time.Now()), @@ -437,7 
+445,7 @@ func TestAPIHandlerQueueResume(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[queueResumeEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newQueueResumeEndpoint) _, err := endpoint.Execute(ctx, &queueResumeRequest{Name: "does_not_exist"}) requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) @@ -452,7 +460,7 @@ func TestStateAndCountGetEndpoint(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[stateAndCountGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newStateAndCountGetEndpoint) _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateAvailable)}) @@ -497,6 +505,46 @@ func TestStateAndCountGetEndpoint(t *testing.T) { Scheduled: 8, }, resp) }) + + t.Run("WithCachedQueryAboveSkipThreshold", func(t *testing.T) { + t.Parallel() + + endpoint, bundle := setupEndpoint(ctx, t, newStateAndCountGetEndpoint) + + const queryCacheSkipThreshold = 3 + for range queryCacheSkipThreshold + 1 { + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateAvailable)}) + } + + _, err := endpoint.queryCacher.RunQuery(ctx) + require.NoError(t, err) + + resp, err := endpoint.Execute(ctx, &stateAndCountGetRequest{}) + require.NoError(t, err) + require.Equal(t, &stateAndCountGetResponse{ + Available: queryCacheSkipThreshold + 1, + }, resp) + }) + + t.Run("WithCachedQueryBelowSkipThreshold", func(t *testing.T) { + t.Parallel() + + endpoint, bundle := setupEndpoint(ctx, t, newStateAndCountGetEndpoint) + + const queryCacheSkipThreshold = 3 + for range queryCacheSkipThreshold - 1 { + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{State: ptrutil.Ptr(rivertype.JobStateAvailable)}) + } + + _, err := endpoint.queryCacher.RunQuery(ctx) + require.NoError(t, err) + + resp, err := endpoint.Execute(ctx, &stateAndCountGetRequest{}) + require.NoError(t, 
err) + require.Equal(t, &stateAndCountGetResponse{ + Available: queryCacheSkipThreshold - 1, + }, resp) + }) } func TestAPIHandlerWorkflowGet(t *testing.T) { @@ -507,7 +555,7 @@ func TestAPIHandlerWorkflowGet(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[workflowGetEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newWorkflowGetEndpoint) workflowID := uuid.New() job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: mustMarshalJSON(t, map[string]uuid.UUID{"workflow_id": workflowID})}) @@ -523,7 +571,7 @@ func TestAPIHandlerWorkflowGet(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - endpoint, _ := setupEndpoint[workflowGetEndpoint](ctx, t) + endpoint, _ := setupEndpoint(ctx, t, newWorkflowGetEndpoint) workflowID := uuid.New() @@ -540,7 +588,7 @@ func TestAPIHandlerWorkflowList(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[workflowListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newWorkflowListEndpoint) job1 := testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ Metadata: []byte(`{"workflow_id":"1", "workflow_name":"first_wf", "task":"a"}`), @@ -610,7 +658,7 @@ func TestAPIHandlerWorkflowList(t *testing.T) { t.Run("Limit", func(t *testing.T) { t.Parallel() - endpoint, bundle := setupEndpoint[workflowListEndpoint](ctx, t) + endpoint, bundle := setupEndpoint(ctx, t, newWorkflowListEndpoint) _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ Metadata: []byte(`{"workflow_id":"1", "workflow_name":"first_wf", "task":"a"}`), diff --git a/handler_test.go b/handler_test.go index 303e877..0f7cb74 100644 --- a/handler_test.go +++ b/handler_test.go @@ -45,7 +45,7 @@ func TestNewHandlerIntegration(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { tx.Rollback(ctx) }) - handler, err := NewHandler(&HandlerOpts{ + server, err := NewServer(&HandlerOpts{ Client: client, 
DBPool: tx, Logger: logger, @@ -62,7 +62,7 @@ func TestNewHandlerIntegration(t *testing.T) { t.Logf("--> %s %s", method, path) - handler.ServeHTTP(recorder, req) + server.Handler().ServeHTTP(recorder, req) status := recorder.Result().StatusCode //nolint:bodyclose t.Logf("Response status: %d", status) diff --git a/internal/apiendpoint/api_endpoint.go b/internal/apiendpoint/api_endpoint.go index 8f7c753..f440d6d 100644 --- a/internal/apiendpoint/api_endpoint.go +++ b/internal/apiendpoint/api_endpoint.go @@ -33,15 +33,7 @@ type Endpoint[TReq any, TResp any] struct { func (e *Endpoint[TReq, TResp]) SetLogger(logger *slog.Logger) { e.logger = logger } func (e *Endpoint[TReq, TResp]) SetMeta(meta *EndpointMeta) { e.meta = meta } -// EndpointInterface is an interface to an API endpoint. Some of it is -// implemented by an embedded Endpoint struct, and some of it should be -// implemented by the endpoint itself. -type EndpointInterface[TReq any, TResp any] interface { - // Execute executes the API endpoint. - // - // This should be implemented by each specific API endpoint. - Execute(ctx context.Context, req *TReq) (*TResp, error) - +type EndpointInterface interface { // Meta returns metadata about an API endpoint, like the path it should be // mounted at, and the status code it returns on success. // @@ -60,6 +52,18 @@ type EndpointInterface[TReq any, TResp any] interface { SetMeta(meta *EndpointMeta) } +// EndpointExecuteInterface is an interface to an API endpoint. Some of it is +// implemented by an embedded Endpoint struct, and some of it should be +// implemented by the endpoint itself. +type EndpointExecuteInterface[TReq any, TResp any] interface { + EndpointInterface + + // Execute executes the API endpoint. + // + // This should be implemented by each specific API endpoint. + Execute(ctx context.Context, req *TReq) (*TResp, error) +} + // EndpointMeta is metadata about an API endpoint. 
type EndpointMeta struct { // Pattern is the API endpoint's HTTP method and path where it should be @@ -84,7 +88,7 @@ func (m *EndpointMeta) validate() { // Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log // information about endpoint execution. -func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointInterface[TReq, TResp]) { +func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointExecuteInterface[TReq, TResp]) EndpointInterface { apiEndpoint.SetLogger(logger) meta := apiEndpoint.Meta() @@ -94,10 +98,12 @@ func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndp mux.HandleFunc(meta.Pattern, func(w http.ResponseWriter, r *http.Request) { executeAPIEndpoint(w, r, logger, meta, apiEndpoint.Execute) }) + + return apiEndpoint } func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, execute func(ctx context.Context, req *TReq) (*TResp, error)) { - ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) defer cancel() // Run as much code as we can in a sub-function that can return an error. 
diff --git a/internal/dbsqlc/query.sql b/internal/dbsqlc/query.sql index bba9af3..9392812 100644 --- a/internal/dbsqlc/query.sql +++ b/internal/dbsqlc/query.sql @@ -1,14 +1,6 @@ -- name: JobCountByState :many -SELECT - state, - count(*) -FROM - river_job -WHERE - queue IS NOT NULL AND - priority > 0 AND - scheduled_at IS NOT NULL AND - id IS NOT NULL +SELECT state, count(*) +FROM river_job GROUP BY state; -- name: JobListWorkflow :many diff --git a/internal/dbsqlc/query.sql.go b/internal/dbsqlc/query.sql.go index 7349a23..a63a459 100644 --- a/internal/dbsqlc/query.sql.go +++ b/internal/dbsqlc/query.sql.go @@ -77,16 +77,8 @@ func (q *Queries) JobCountByQueueAndState(ctx context.Context, db DBTX, queueNam } const jobCountByState = `-- name: JobCountByState :many -SELECT - state, - count(*) -FROM - river_job -WHERE - queue IS NOT NULL AND - priority > 0 AND - scheduled_at IS NOT NULL AND - id IS NOT NULL +SELECT state, count(*) +FROM river_job GROUP BY state ` diff --git a/internal/querycacher/query_cacher.go b/internal/querycacher/query_cacher.go new file mode 100644 index 0000000..812aaba --- /dev/null +++ b/internal/querycacher/query_cacher.go @@ -0,0 +1,144 @@ +package querycacher + +import ( + "context" + "regexp" + "sync" + "time" + + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/startstop" + "github.com/riverqueue/riverui/internal/dbsqlc" +) + +// QueryCacher executes a database query periodically and caches the result. The +// basic premise is that given large River databases even simple queries like +// counting job rows can get quite slow. This construct operates as a background +// service that runs a query, stores the result, making it available to API +// endpoints so that they don't need to execute the query in band. 
+type QueryCacher[TRes any] struct { + baseservice.BaseService + startstop.BaseStartStop + + cachedRes TRes + cachedResSet bool + db dbsqlc.DBTX + mu sync.RWMutex + runQuery func(ctx context.Context, dbtx dbsqlc.DBTX) (TRes, error) + runQueryTestChan chan struct{} // closed when query is run; for testing + tickPeriod time.Duration // constant normally, but settable for testing +} + +func NewQueryCacher[TRes any](archetype *baseservice.Archetype, db dbsqlc.DBTX, runQuery func(ctx context.Context, db dbsqlc.DBTX) (TRes, error)) *QueryCacher[TRes] { + // +/- 1s random variance to ticker interval. Makes sure that given multiple + // query caches running simultaneously, they all start and are scheduled a + // little differently to make a thundering herd problem less likely. + randomTickVariance := time.Duration(archetype.Rand.Float64()*float64(2*time.Second)) - 1*time.Second + + queryCacher := baseservice.Init(archetype, &QueryCacher[TRes]{ + db: db, + runQuery: runQuery, + tickPeriod: 10*time.Second + randomTickVariance, + }) + + // TODO(brandur): Push this up into baseservice. + queryCacher.Name = simplifyArchetypeLogName(queryCacher.Name) + + return queryCacher +} + +// CachedRes returns cached results, if there are any, and an "ok" boolean +// indicating whether cached results were available (true if so, and false +// otherwise). +func (s *QueryCacher[TRes]) CachedRes() (TRes, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + if !s.cachedResSet { + var emptyRes TRes + return emptyRes, false + } + + return s.cachedRes, true +} + +// RunQuery runs the internal query function and caches the result. It's not +// usually necessary to call this function explicitly since Start will do it +// periodically, but is made available for use in places like helping with +// testing. 
+func (s *QueryCacher[TRes]) RunQuery(ctx context.Context) (TRes, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + start := time.Now() + + res, err := s.runQuery(ctx, s.db) + if err != nil { + var emptyRes TRes + return emptyRes, err + } + + s.Logger.DebugContext(ctx, s.Name+": Ran query and cached result", "duration", time.Since(start), "tick_period", s.tickPeriod) + + s.mu.Lock() + s.cachedRes = res + s.cachedResSet = true + s.mu.Unlock() + + // Tells a test that it can wake up and handle a result. + if s.runQueryTestChan != nil { + close(s.runQueryTestChan) + s.runQueryTestChan = nil + } + + return res, nil +} + +// Start starts the service, causing it to periodically run its query and cache +// the result. It stops when Stop is called or if its context is cancelled. +func (s *QueryCacher[TRes]) Start(ctx context.Context) error { + ctx, shouldStart, started, stopped := s.StartInit(ctx) + if !shouldStart { + return nil + } + + go func() { + started() + defer stopped() + + // In case a query runs long and exceeds tickPeriod, time.Ticker will + // drop ticks to compensate. + ticker := time.NewTicker(s.tickPeriod) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + + case <-ticker.C: + if _, err := s.RunQuery(ctx); err != nil { + s.Logger.ErrorContext(ctx, s.Name+": Error running query", "err", err) + } + } + } + }() + + return nil +} + +// Simplifies the name of a Go type that uses generics for cleaner logging output. +// +// So this: +// +// QueryCacher[[]*github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow] +// +// Becomes this: +// +// QueryCacher[[]*dbsqlc.JobCountByStateRow] +// +// TODO(brandur): Push this up into baseservice. 
+func simplifyArchetypeLogName(name string) string { + re := regexp.MustCompile(`\[([\[\]\*]*).*/([^/]+)\]`) + return re.ReplaceAllString(name, `[$1$2]`) +} diff --git a/internal/querycacher/query_cacher_test.go b/internal/querycacher/query_cacher_test.go new file mode 100644 index 0000000..d7459df --- /dev/null +++ b/internal/querycacher/query_cacher_test.go @@ -0,0 +1,125 @@ +package querycacher + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/startstoptest" + "github.com/riverqueue/riverui/internal/dbsqlc" + "github.com/riverqueue/riverui/internal/riverinternaltest" + "github.com/riverqueue/riverui/internal/riverinternaltest/testfactory" +) + +func TestQueryCacher(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + exec riverdriver.ExecutorTx + } + + setup := func(ctx context.Context, t *testing.T) (*QueryCacher[[]*dbsqlc.JobCountByStateRow], *testBundle) { + t.Helper() + + var ( + archetype = riversharedtest.BaseServiceArchetype(t) + driver = riverpgxv5.New(nil) + tx = riverinternaltest.TestTx(ctx, t) + queryCacher = NewQueryCacher(archetype, tx, dbsqlc.New().JobCountByState) + ) + + return queryCacher, &testBundle{ + exec: driver.UnwrapExecutor(tx), + } + } + + start := func(ctx context.Context, t *testing.T, queryCacher *QueryCacher[[]*dbsqlc.JobCountByStateRow]) { + t.Helper() + + require.NoError(t, queryCacher.Start(ctx)) + t.Cleanup(queryCacher.Stop) + } + + t.Run("NoCachedResult", func(t *testing.T) { + t.Parallel() + + queryCacher, _ := setup(ctx, t) + + _, ok := queryCacher.CachedRes() + require.False(t, ok) + }) + + t.Run("WithCachedResult", func(t *testing.T) { + t.Parallel() + + queryCacher, bundle := setup(ctx, t) + + _ = testfactory.Job(ctx, t, bundle.exec, 
&testfactory.JobOpts{}) + + _, err := queryCacher.RunQuery(ctx) + require.NoError(t, err) + + res, ok := queryCacher.CachedRes() + require.True(t, ok) + require.Equal(t, []*dbsqlc.JobCountByStateRow{ + {State: dbsqlc.RiverJobStateAvailable, Count: 1}, + }, res) + }) + + t.Run("RunsPeriodically", func(t *testing.T) { + t.Parallel() + + queryCacher, bundle := setup(ctx, t) + + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{}) + + runQueryTestChan := make(chan struct{}) + queryCacher.runQueryTestChan = runQueryTestChan + + // Dramatically reduce tick period so we don't have to wait the full time. + queryCacher.tickPeriod = 1 * time.Millisecond + + start(ctx, t, queryCacher) + + riversharedtest.WaitOrTimeout(t, runQueryTestChan) + + res, ok := queryCacher.CachedRes() + require.True(t, ok) + require.Equal(t, []*dbsqlc.JobCountByStateRow{ + {State: dbsqlc.RiverJobStateAvailable, Count: 1}, + }, res) + }) + + t.Run("StartStopStress", func(t *testing.T) { + t.Parallel() + + queryCacher, _ := setup(ctx, t) + startstoptest.Stress(ctx, t, queryCacher) + }) +} + +func TestSimplifyArchetypeLogName(t *testing.T) { + t.Parallel() + + require.Equal(t, "NotGeneric", simplifyArchetypeLogName("NotGeneric")) + + // Simplified for use during debugging. Real generics will tend to have + // fully qualified paths and not look like this. + require.Equal(t, "Simple[int]", simplifyArchetypeLogName("Simple[int]")) + require.Equal(t, "Simple[*int]", simplifyArchetypeLogName("Simple[*int]")) + require.Equal(t, "Simple[[]int]", simplifyArchetypeLogName("Simple[[]int]")) + require.Equal(t, "Simple[[]*int]", simplifyArchetypeLogName("Simple[[]*int]")) + + // More realistic examples. 
+ require.Equal(t, "QueryCacher[dbsqlc.JobCountByStateRow]", simplifyArchetypeLogName("QueryCacher[github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow]")) + require.Equal(t, "QueryCacher[*dbsqlc.JobCountByStateRow]", simplifyArchetypeLogName("QueryCacher[*github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow]")) + require.Equal(t, "QueryCacher[[]dbsqlc.JobCountByStateRow]", simplifyArchetypeLogName("QueryCacher[[]github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow]")) + require.Equal(t, "QueryCacher[[]*dbsqlc.JobCountByStateRow]", simplifyArchetypeLogName("QueryCacher[[]*github.com/riverqueue/riverui/internal/dbsqlc.JobCountByStateRow]")) +}