diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a9a2f897..1f4c59b2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -132,6 +132,19 @@ jobs: run: ./river validate --database-url $DATABASE_URL working-directory: ./cmd/river + - name: river bench + run: | + ( sleep 10 && killall -SIGTERM river ) & + ./river bench --database-url $DATABASE_URL + working-directory: ./cmd/river + + # Bench again in fixed number of jobs mode. + - name: river bench + run: | + ( sleep 10 && killall -SIGTERM river ) & + ./river bench --database-url $DATABASE_URL --num-total-jobs 1_234 + working-directory: ./cmd/river + - name: river migrate-down run: ./river migrate-down --database-url $DATABASE_URL --max-steps 100 working-directory: ./cmd/river diff --git a/CHANGELOG.md b/CHANGELOG.md index cee7678f..c01cd288 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- The River CLI now supports `river bench` to benchmark River's job throughput against a database. [PR #254](https://github.com/riverqueue/river/pull/254). + ### Changed - Changed default client IDs to be a combination of hostname and the time which the client started. This can still be changed by specifying `Config.ID`. [PR #255](https://github.com/riverqueue/river/pull/255). diff --git a/cmd/river/go.mod b/cmd/river/go.mod index efea3d74..061eece0 100644 --- a/cmd/river/go.mod +++ b/cmd/river/go.mod @@ -13,6 +13,7 @@ go 1.21.4 require ( github.com/jackc/pgx/v5 v5.5.2 github.com/riverqueue/river v0.0.17 + github.com/riverqueue/river/riverdriver v0.0.17 github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.17 github.com/spf13/cobra v1.8.0 ) @@ -22,9 +23,10 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect - github.com/riverqueue/river/riverdriver v0.0.17 // indirect + github.com/oklog/ulid/v2 v2.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/crypto v0.17.0 // indirect + golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect golang.org/x/sync v0.6.0 // indirect golang.org/x/text v0.14.0 // indirect ) diff --git a/cmd/river/go.sum b/cmd/river/go.sum index 9298993d..a570f9b1 100644 --- a/cmd/river/go.sum +++ b/cmd/river/go.sum @@ -16,6 +16,9 @@ github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU= +github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= +github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/riverqueue/river v0.0.17 h1:7beHZxo1WMzhN48y1Jt7CKkkmsw+TjuLd6qCEaznm7s= @@ -26,6 +29,8 @@ github.com/riverqueue/river/riverdriver/riverdatabasesql v0.0.17 h1:xPmTpQNBicTZ github.com/riverqueue/river/riverdriver/riverdatabasesql v0.0.17/go.mod h1:zlZKXZ6XHcbwYniSKWX2+GwFlXHTnG9pJtE/BkxK0Xc= 
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.17 h1:iuruCNT7nkC7Z4Qzb79jcvAVniGyK+Kstsy7fKJagUU= github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.17/go.mod h1:kL59NW3LoPbQxPz9DQoUtDYq3Zkcpdt5CIowgeBZwFw= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= diff --git a/cmd/river/main.go b/cmd/river/main.go index 9b2a5308..2aa0ae27 100644 --- a/cmd/river/main.go +++ b/cmd/river/main.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log/slog" "os" "strconv" "time" @@ -11,6 +12,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/spf13/cobra" + "github.com/riverqueue/river/cmd/river/riverbench" "github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/rivermigrate" ) @@ -39,7 +41,7 @@ Provides command line facilities for the River job queue. } } - mustMarkFlagRequired := func(cmd *cobra.Command, name string) { + mustMarkFlagRequired := func(cmd *cobra.Command, name string) { //nolint:unparam // We just panic here because this will never happen outside of an error // in development. if err := cmd.MarkFlagRequired(name); err != nil { @@ -47,6 +49,38 @@ Provides command line facilities for the River job queue. } } + // bench + { + var opts benchOpts + + cmd := &cobra.Command{ + Use: "bench", + Short: "Run River benchmark", + Long: ` +Run a River benchmark which inserts and works jobs continually, giving a rough +idea of jobs per second and time to work a single job. + +By default, the benchmark will continuously insert and work jobs in perpetuity +until interrupted by SIGINT (Ctrl^C). It can alternatively take a maximum run +duration with --duration, which takes a Go-style duration string like 1m. +Lastly, it can take --num-total-jobs, which inserts the given number of jobs +before starting the client, and works until all jobs are finished. + +The database in --database-url will have its jobs table truncated, so make sure +to use a development database only. + `, + Run: func(cmd *cobra.Command, args []string) { + execHandlingError(func() (bool, error) { return bench(ctx, &opts) }) + }, + } + cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to benchmark (should look like `postgres://...`") + cmd.Flags().DurationVar(&opts.Duration, "duration", 0, "duration after which to stop benchmark, accepting Go-style durations like 1m, 5m30s") + cmd.Flags().IntVarP(&opts.NumTotalJobs, "num-total-jobs", "n", 0, "number of jobs to insert before starting and which are worked down until finish") + cmd.Flags().BoolVarP(&opts.Verbose, "verbose", "v", false, "output additional logging verbosity") + mustMarkFlagRequired(cmd, "database-url") + rootCmd.AddCommand(cmd) + } + // migrate-down { var opts migrateDownOpts @@ -65,8 +99,8 @@ Defaults to running a single down migration. 
This behavior can be changed with }, } cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to migrate (should look like `postgres://...`") - cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 1, "Maximum number of steps to migrate") - cmd.Flags().IntVar(&opts.TargetVersion, "target-version", 0, "Target version to migrate to (final state includes this version, but none after it)") + cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 1, "maximum number of steps to migrate") + cmd.Flags().IntVar(&opts.TargetVersion, "target-version", 0, "target version to migrate to (final state includes this version, but none after it)") mustMarkFlagRequired(cmd, "database-url") rootCmd.AddCommand(cmd) } @@ -89,8 +123,8 @@ restricted with --max-steps or --target-version. }, } cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to migrate (should look like `postgres://...`") - cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 0, "Maximum number of steps to migrate") - cmd.Flags().IntVar(&opts.TargetVersion, "target-version", 0, "Target version to migrate to (final state includes this version)") + cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 0, "maximum number of steps to migrate") + cmd.Flags().IntVar(&opts.TargetVersion, "target-version", 0, "target version to migrate to (final state includes this version)") mustMarkFlagRequired(cmd, "database-url") rootCmd.AddCommand(cmd) } @@ -151,6 +185,48 @@ func setParamIfUnset(runtimeParams map[string]string, name, val string) { runtimeParams[name] = val } +type benchOpts struct { + DatabaseURL string + Duration time.Duration + NumTotalJobs int + Verbose bool +} + +func (o *benchOpts) validate() error { + if o.DatabaseURL == "" { + return errors.New("database URL cannot be empty") + } + + return nil +} + +func bench(ctx context.Context, opts *benchOpts) (bool, error) { + if err := opts.validate(); err != nil { + return false, err + } + + dbPool, err := openDBPool(ctx, opts.DatabaseURL) + if err != nil { + return false, err + } + defer dbPool.Close() + + var logger *slog.Logger + if opts.Verbose { + logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) + } else { + logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn})) + } + + benchmarker := riverbench.NewBenchmarker(riverpgxv5.New(dbPool), logger, opts.Duration, opts.NumTotalJobs) + + if err := benchmarker.Run(ctx); err != nil { + return false, err + } + + return true, nil +} + type migrateDownOpts struct { DatabaseURL string MaxSteps int diff --git a/cmd/river/riverbench/river_bench.go b/cmd/river/riverbench/river_bench.go new file mode 100644 index 00000000..71750185 --- /dev/null +++ b/cmd/river/riverbench/river_bench.go @@ -0,0 +1,424 @@ +package riverbench + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "os/signal" + "sync/atomic" + "syscall" + "time" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdriver" +) + +type Benchmarker[TTx any] struct { + driver riverdriver.Driver[TTx] + duration time.Duration + logger *slog.Logger + name string + numTotalJobs int +} + +func NewBenchmarker[TTx any](driver riverdriver.Driver[TTx], logger *slog.Logger, duration time.Duration, numTotalJobs int) *Benchmarker[TTx] { + return &Benchmarker[TTx]{ + driver: driver, + duration: duration, + logger: logger, + name: "Benchmarker", + numTotalJobs: numTotalJobs, + } +} + +// Run starts the benchmarking loop. Stops upon receiving SIGINT/SIGTERM, or +// when reaching maximum configured run duration. 
+func (b *Benchmarker[TTx]) Run(ctx context.Context) error {
+    var (
+        numJobsInserted atomic.Int64
+        numJobsLeft     atomic.Int64
+        numJobsWorked   atomic.Int64
+        shutdown        = make(chan struct{})
+        shutdownClosed  bool
+    )
+
+    // Prevents double-close on shutdown channel.
+    closeShutdown := func() {
+        if !shutdownClosed {
+            b.logger.InfoContext(ctx, "Closing shutdown channel")
+            close(shutdown)
+        }
+        shutdownClosed = true
+    }
+
+    ctx, cancel := context.WithCancel(ctx)
+    defer cancel()
+
+    // Installing signals allows us to try to stop the client cleanly, and also
+    // to produce a final summary log line for the whole bench run (by default,
+    // Go will terminate programs abruptly and not even defers will run).
+    go func() {
+        signalChan := make(chan os.Signal, 1)
+        signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
+
+        select {
+        case <-ctx.Done():
+        case <-signalChan:
+            closeShutdown()
+
+            // Wait again since the client may take an absurd amount of time to
+            // shut down. If we receive another signal in the intervening
+            // period, cancel context, thereby forcing a hard shut down.
+            select {
+            case <-ctx.Done():
+            case <-signalChan:
+                fmt.Printf("second signal received; canceling context\n")
+                cancel()
+            }
+        }
+    }()
+
+    if err := b.resetJobsTable(ctx); err != nil {
+        return err
+    }
+
+    workers := river.NewWorkers()
+    river.AddWorker(workers, &BenchmarkWorker{})
+
+    client, err := river.NewClient(b.driver, &river.Config{
+        Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn})),
+        Queues: map[string]river.QueueConfig{
+            river.QueueDefault: {MaxWorkers: river.QueueNumWorkersMax},
+        },
+        Workers: workers,
+    })
+    if err != nil {
+        return err
+    }
+
+    // Notably, we use a subscribe channel to track how many jobs have been
+    // worked instead of using telemetry from the worker itself because the
+    // subscribe channel accounts for the job moving through the completer while
+    // the worker does not.
+    subscribeChan, subscribeCancel := client.Subscribe(
+        river.EventKindJobCancelled,
+        river.EventKindJobCompleted,
+        river.EventKindJobFailed,
+    )
+    defer subscribeCancel()
+
+    go func() {
+        for {
+            select {
+            case <-ctx.Done():
+                return
+
+            case <-shutdown:
+                return
+
+            case <-subscribeChan:
+                numJobsLeft.Add(-1)
+                numJobsWorked := numJobsWorked.Add(1)
+
+                const logBatchSize = 5_000
+                if numJobsWorked%logBatchSize == 0 {
+                    b.logger.InfoContext(ctx, b.name+": Worked job batch", "num_worked", logBatchSize)
+                }
+            }
+        }
+    }()
+
+    minJobsReady := make(chan struct{})
+
+    if b.numTotalJobs != 0 {
+        b.insertJobs(ctx, client, minJobsReady, &numJobsInserted, &numJobsLeft, shutdown)
+    } else {
+        insertJobsFinished := make(chan struct{})
+        defer func() { <-insertJobsFinished }()
+
+        go func() {
+            defer close(insertJobsFinished)
+            b.insertJobsContinuously(ctx, client, minJobsReady, &numJobsInserted, &numJobsLeft, shutdown)
+        }()
+    }
+
+    // Must appear after we wait for insert jobs to finish so that the defers
+    // run in the right order.
+    defer closeShutdown()
+
+    // Don't start measuring until the first batch of jobs is confirmed ready.
+    select {
+    case <-ctx.Done():
+        return ctx.Err()
+    case <-minJobsReady:
+        // okay
+    case <-shutdown:
+        return nil
+    case <-time.After(5 * time.Second):
+        return errors.New("timed out waiting for minimum starting jobs to be inserted")
+    }
+
+    b.logger.InfoContext(ctx, b.name+": Minimum jobs inserted; starting iteration")
+
+    b.logger.InfoContext(ctx, b.name+": Client starting")
+    if err := client.Start(ctx); err != nil {
+        return err
+    }
+
+    defer func() {
+        b.logger.InfoContext(ctx, b.name+": Client stopping")
+        if err := client.Stop(ctx); err != nil {
+            b.logger.ErrorContext(ctx, b.name+": Error stopping client", "err", err)
+        }
+        b.logger.InfoContext(ctx, b.name+": Client stopped")
+    }()
+
+    // Prints one last log line before exit summarizing all operations.
+    start := time.Now()
+    defer func() {
+        runPeriod := time.Since(start)
+        jobsPerSecond := float64(numJobsWorked.Load()) / runPeriod.Seconds()
+
+        fmt.Printf("bench: total jobs worked [ %10d ], total jobs inserted [ %10d ], overall job/sec [ %10.1f ], running %s\n",
+            numJobsWorked.Load(), numJobsInserted.Load(), jobsPerSecond, runPeriod)
+    }()
+
+    const iterationPeriod = 2 * time.Second
+
+    var (
+        firstRun            = true
+        numJobsInsertedLast int64
+        numJobsWorkedLast   int64
+        ticker              = time.NewTicker(iterationPeriod)
+    )
+    defer ticker.Stop()
+
+    for numIterations := 0; ; numIterations++ {
+        // Use iterations multiplied by period time instead of actual elapsed
+        // time to allow a precise, predictable run duration to be specified.
+        if b.duration != 0 && time.Duration(numIterations)*iterationPeriod >= b.duration {
+            return nil
+        }
+
+        var (
+            numJobsInsertedSinceLast = numJobsInserted.Load() - numJobsInsertedLast
+            numJobsWorkedSinceLast   = numJobsWorked.Load() - numJobsWorkedLast
+        )
+
+        jobsPerSecond := float64(numJobsWorkedSinceLast) / iterationPeriod.Seconds()
+
+        // On first run, show iteration period as 0s because no time was given
+        // for jobs to be worked.
+        period := iterationPeriod
+        if firstRun {
+            period = 0 * time.Second
+        }
+
+        fmt.Printf("bench: jobs worked [ %10d ], inserted [ %10d ], job/sec [ %10.1f ] [%s]\n",
+            numJobsWorkedSinceLast, numJobsInsertedSinceLast, jobsPerSecond, period)
+
+        firstRun = false
+        numJobsInsertedLast = numJobsInserted.Load()
+        numJobsWorkedLast = numJobsWorked.Load()
+
+        // If working in the mode where we're burning jobs down and there are no
+        // jobs left, end.
+        if b.numTotalJobs != 0 && numJobsLeft.Load() < 1 {
+            return nil
+        }
+
+        select {
+        case <-ctx.Done():
+            return nil
+
+        case <-shutdown:
+            return nil
+
+        case <-ticker.C:
+        }
+    }
+}
+
+const (
+    insertBatchSize = 2_000
+    minJobs         = 50_000
+)
+
+// Inserts `b.numTotalJobs` jobs in batches. This variant inserts the full set
+// of initial jobs up front and then returns, and is used when the
+// `-n`/`--num-total-jobs` flag is specified.
+func (b *Benchmarker[TTx]) insertJobs(
+    ctx context.Context,
+    client *river.Client[TTx],
+    minJobsReady chan struct{},
+    numJobsInserted *atomic.Int64,
+    numJobsLeft *atomic.Int64,
+    shutdown chan struct{},
+) {
+    defer close(minJobsReady)
+
+    var (
+        // We'll be reusing the same batch for all inserts because (1) we can
+        // get away with it, and (2) it avoids needless allocations.
+        insertParamsBatch = make([]river.InsertManyParams, insertBatchSize)
+        jobArgsBatch      = make([]BenchmarkArgs, insertBatchSize)
+
+        jobNum int
+    )
+
+    var numInsertedThisRound int
+
+    for {
+        for i := range jobArgsBatch {
+            jobNum++
+            jobArgsBatch[i].Num = jobNum
+        }
+
+        for i := range insertParamsBatch {
+            insertParamsBatch[i].Args = jobArgsBatch[i]
+        }
+
+        numLeft := b.numTotalJobs - numInsertedThisRound
+        if numLeft < insertBatchSize {
+            insertParamsBatch = insertParamsBatch[0:numLeft]
+        }
+
+        if _, err := client.InsertMany(ctx, insertParamsBatch); err != nil {
+            b.logger.ErrorContext(ctx, b.name+": Error inserting jobs", "err", err)
+        }
+
+        numJobsInserted.Add(int64(len(insertParamsBatch)))
+        numJobsLeft.Add(int64(len(insertParamsBatch)))
+        numInsertedThisRound += len(insertParamsBatch)
+
+        if numJobsLeft.Load() >= int64(b.numTotalJobs) {
+            b.logger.InfoContext(ctx, b.name+": Finished inserting jobs",
+                "num_inserted", numInsertedThisRound)
+            return
+        }
+
+        // Will be very unusual, but break early if done between batches.
+        select {
+        case <-ctx.Done():
+            return
+        case <-shutdown:
+            return
+        default:
+        }
+    }
+}
+
+// Inserts jobs continuously, but only if it notices that the number of jobs
+// left is below a minimum threshold. This has the effect of keeping enough job
+// slack in the pool to be worked, but keeping the total number of jobs being
+// inserted roughly matched with the rate at which the benchmark can work them.
+func (b *Benchmarker[TTx]) insertJobsContinuously(
+    ctx context.Context,
+    client *river.Client[TTx],
+    minJobsReady chan struct{},
+    numJobsInserted *atomic.Int64,
+    numJobsLeft *atomic.Int64,
+    shutdown chan struct{},
+) {
+    var (
+        // We'll be reusing the same batch for all inserts because (1) we can
+        // get away with it, and (2) it avoids needless allocations.
+        insertParamsBatch = make([]river.InsertManyParams, insertBatchSize)
+        jobArgsBatch      = make([]BenchmarkArgs, insertBatchSize)
+
+        jobNum int
+    )
+
+    for {
+        select {
+        case <-ctx.Done():
+            return
+
+        case <-shutdown:
+            return
+
+        case <-time.After(50 * time.Millisecond):
+        }
+
+        if numJobsLeft.Load() >= minJobs {
+            continue
+        }
+
+        var numInsertedThisRound int
+
+        for {
+            for i := range jobArgsBatch {
+                jobNum++
+                jobArgsBatch[i].Num = jobNum
+            }
+
+            for i := range insertParamsBatch {
+                insertParamsBatch[i].Args = jobArgsBatch[i]
+            }
+
+            if _, err := client.InsertMany(ctx, insertParamsBatch); err != nil {
+                b.logger.ErrorContext(ctx, b.name+": Error inserting jobs", "err", err)
+            }
+
+            numJobsInserted.Add(int64(len(insertParamsBatch)))
+            numJobsLeft.Add(int64(len(insertParamsBatch)))
+            numInsertedThisRound += len(insertParamsBatch)
+
+            if numJobsLeft.Load() >= minJobs {
+                b.logger.InfoContext(ctx, b.name+": Finished inserting batch of jobs",
+                    "num_inserted", numInsertedThisRound)
+                break // break inner loop to go back to sleep
+            }
+
+            // Will be very unusual, but break early if done between batches.
+            select {
+            case <-ctx.Done():
+                return
+            case <-shutdown:
+                return
+            default:
+            }
+        }
+
+        // Close the first time we insert a full batch to tell the main loop it
+        // can start benchmarking.
+        if minJobsReady != nil {
+            close(minJobsReady)
+            minJobsReady = nil
+        }
+    }
+}
+
+// Truncates and `VACUUM FULL`s the jobs table to guarantee as little
+// state-related job variance as possible.
+func (b *Benchmarker[TTx]) resetJobsTable(ctx context.Context) error { + b.logger.InfoContext(ctx, b.name+": Truncating and vacuuming jobs table") + + _, err := b.driver.GetExecutor().Exec(ctx, "TRUNCATE river_job") + if err != nil { + return err + } + _, err = b.driver.GetExecutor().Exec(ctx, "VACUUM FULL river_job") + if err != nil { + return err + } + + return nil +} + +type BenchmarkArgs struct { + Num int `json:"num"` +} + +func (BenchmarkArgs) Kind() string { return "benchmark" } + +// BenchmarkWorker is a job worker for counting the number of worked jobs. +type BenchmarkWorker struct { + river.WorkerDefaults[BenchmarkArgs] +} + +func (w *BenchmarkWorker) Work(ctx context.Context, j *river.Job[BenchmarkArgs]) error { + return nil +} diff --git a/internal/cmd/riverbench/main.go b/internal/cmd/riverbench/main.go deleted file mode 100644 index 3d22f658..00000000 --- a/internal/cmd/riverbench/main.go +++ /dev/null @@ -1,202 +0,0 @@ -package main - -import ( - "context" - "fmt" - "log/slog" - "os" - "sync/atomic" - "time" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" - - "github.com/riverqueue/river" - "github.com/riverqueue/river/internal/riverinternaltest" //nolint:depguard - "github.com/riverqueue/river/riverdriver/riverpgxv5" -) - -func main() { - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn})) - - if err := prepareAndRunBenchmark(context.Background(), logger); err != nil { - logger.Error("failed", "error", err.Error()) - } -} - -func prepareAndRunBenchmark(ctx context.Context, logger *slog.Logger) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const total = 1000000 - const workerCount = 500 - const insertBatchSize = 10000 - const poolLimit = 10 - const fetchInterval = 2 * time.Millisecond - - fmt.Printf( - "--- benchmarking throughput: jobs=%d workers=%d pool_limit=%d fetch_interval=%s\n", - total, - workerCount, - poolLimit, - fetchInterval, - ) - - dbPool := mustGetDBPool(ctx, poolLimit) - fmt.Println("--- truncating DB…") - if err := truncateDB(ctx, dbPool); err != nil { - return fmt.Errorf("failed to truncate DB: %w", err) - } - - counterWorker := &CounterWorker{} - workers := river.NewWorkers() - river.AddWorker(workers, counterWorker) - - client, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{ - FetchCooldown: fetchInterval, - FetchPollInterval: fetchInterval, - Logger: logger, - Queues: map[string]river.QueueConfig{ - river.QueueDefault: {MaxWorkers: workerCount}, - }, - Workers: workers, - }) - if err != nil { - return fmt.Errorf("failed to create river client: %w", err) - } - - // Insert jobs in batches until we hit the total: - fmt.Printf("--- inserting %d jobs in batches of %d…\n", total, insertBatchSize) - batch := make([]river.InsertManyParams, 0, insertBatchSize) - for i := 0; i < total; i += insertBatchSize { - batch = batch[:0] - for j := 0; j < insertBatchSize && i+j < total; j++ { - batch = append(batch, river.InsertManyParams{Args: CounterArgs{Number: i + j}}) - } - if err := insertBatch(ctx, client, batch); err != nil { - return fmt.Errorf("failed to insert batch: %w", err) - } - } - fmt.Printf("\n") - logger.Info("done inserting jobs, sleeping 5s") - time.Sleep(5 * time.Second) - logger.Info("starting client") - - startTime := time.Now() - - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(5 * time.Second): - jobsWorked := counterWorker.Counter.Load() - now := time.Now() - fmt.Printf("stats: %d jobs worked in %s (%.2f 
jobs/sec)\n", - counterWorker.Counter.Load(), - now.Sub(startTime), - float64(jobsWorked)/now.Sub(startTime).Seconds(), - ) - } - } - }() - - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(100 * time.Millisecond): - jobsWorked := counterWorker.Counter.Load() - if jobsWorked == total { - logger.Info("--- all jobs worked") - cancel() - return - } - } - } - }() - - if err := client.Start(ctx); err != nil { - return fmt.Errorf("failed to run river client: %w", err) - } - - logger.Info("client started") - - <-ctx.Done() - logger.Info("initiating shutdown") - shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 10*time.Second) - defer shutdownCancel() - - if err = client.Stop(shutdownCtx); err != nil { - return fmt.Errorf("error shutting down client: %w", err) - } - - logger.Info("client shutdown complete", "jobs_worked_count", counterWorker.Counter.Load()) - - endTime := time.Now() - fmt.Printf("final stats: %d jobs worked in %s (%.2f jobs/sec)\n", - counterWorker.Counter.Load(), - endTime.Sub(startTime), - float64(counterWorker.Counter.Load())/endTime.Sub(startTime).Seconds(), - ) - return nil -} - -func mustGetDBPool(ctx context.Context, connCount int32) *pgxpool.Pool { - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - - config := riverinternaltest.DatabaseConfig("river_testdb_example") - config.MaxConns = connCount - config.MinConns = connCount - dbPool, err := pgxpool.NewWithConfig(ctx, config) - if err != nil { - panic(err) - } - return dbPool -} - -func truncateDB(ctx context.Context, pool *pgxpool.Pool) error { - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - - _, err := pool.Exec(ctx, "TRUNCATE river_job") - return err -} - -func insertBatch(ctx context.Context, client *river.Client[pgx.Tx], params []river.InsertManyParams) error { - ctx, cancel := context.WithTimeout(ctx, 20*time.Second) - defer cancel() - - insertedCount, err := client.InsertMany(ctx, params) - if err != nil { - return err - } - - if insertedCount != int64(len(params)) { - return fmt.Errorf("inserted %d jobs, expected %d", insertedCount, len(params)) - } - - fmt.Printf(".") - return nil -} - -// CounterArgs are arguments for CounterWorker. -type CounterArgs struct { - // Number is the number of this job. - Number int `json:"number"` -} - -func (CounterArgs) Kind() string { return "counter_worker" } - -// CounterWorker is a job worker for counting the number of worked jobs. -type CounterWorker struct { - river.WorkerDefaults[CounterArgs] - Counter atomic.Uint64 -} - -func (w *CounterWorker) Work(ctx context.Context, j *river.Job[CounterArgs]) error { - w.Counter.Add(1) - return nil -}