Skip to content

Commit

Permalink
Add middleware system for jobs (riverqueue#632)
Browse files Browse the repository at this point in the history
Here, implement a middleware system that adds middleware functions to
job lifecycles, which results in them being invoked during specific
phases of a job like as it's being inserted or worked.

The most obvious unlock for this is telemetry (e.g. logging, metrics),
but it also acts as a building block for features like encrypted jobs.

There are two middleware interfaces added: `JobInsertMiddleware` and
`WorkerMiddleware`. A user could implement these on the same struct type
for something like telemetry where you want to inject a trace ID at job
insertion, and then resurrect it into the context at execution time. In
other cases, such as execution-specific logic, it may only make sense to
implement one of these. Each of these interfaces has a corresponding
`*Defaults` struct that can be embedded to provide future-proofing if
additions are made to them for future functionality.

The `Worker[T]` type has been extended with a `Middleware()` method that
allows each individual worker type to define its own added lists of
middleware which will be run _after_ the global `WorkerMiddleware` at
the client config level. There are also global `JobInsertMiddleware` on
the client config.

Co-authored-by: Brandur <[email protected]>
  • Loading branch information
2 people authored and tigrato committed Dec 18, 2024
1 parent b90c49e commit 2629666
Show file tree
Hide file tree
Showing 20 changed files with 378 additions and 50 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

⚠️ Version 0.13.0 removes the original advisory lock based unique jobs implementation that was deprecated in v0.12.0. See details in the note below or the v0.12.0 release notes.

### Added

- A middleware system was added for job insertion and execution, providing the ability to extract shared functionality across workers. Both `JobInsertMiddleware` and `WorkerMiddleware` can be configured globally on the `Client`, and `WorkerMiddleware` can also be added on a per-worker basis using the new `Middleware` method on `Worker[T]`. Middleware can be useful for logging, telemetry, or for building higher level abstractions on top of base River functionality.

Despite the interface expansion, users should not encounter any breakage if they're embedding the `WorkerDefaults` type in their workers as recommended. [PR #632](https://github.com/riverqueue/river/pull/632).

### Changed

- **Breaking change:** The advisory lock unique jobs implementation which was deprecated in v0.12.0 has been removed. Users of that feature should first upgrade to v0.12.1 to ensure they don't see any warning logs about using the deprecated advisory lock uniqueness. The new, faster unique implementation will be used automatically as long as the `UniqueOpts.ByState` list hasn't been customized to remove [required states](https://riverqueue.com/docs/unique-jobs#unique-by-state) (`pending`, `scheduled`, `available`, and `running`). As of this release, customizing `ByState` without these required states returns an error. [PR #614](https://github.com/riverqueue/river/pull/614).
Expand Down
68 changes: 49 additions & 19 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ type Config struct {
// deployments.
JobCleanerTimeout time.Duration

// JobInsertMiddleware are optional functions that can be called around job
// insertion.
JobInsertMiddleware []rivertype.JobInsertMiddleware

// JobTimeout is the maximum amount of time a job is allowed to run before its
// context is cancelled. A timeout of zero means JobTimeoutDefault will be
// used, whereas a value of -1 means the job's context will not be cancelled
Expand Down Expand Up @@ -235,6 +239,10 @@ type Config struct {
// (i.e. That it wasn't forgotten by accident.)
Workers *Workers

// WorkerMiddleware are optional functions that can be called around
// all job executions.
WorkerMiddleware []rivertype.WorkerMiddleware

// Scheduler run interval. Shared between the scheduler and producer/job
// executors, but not currently exposed for configuration.
schedulerInterval time.Duration
Expand Down Expand Up @@ -467,6 +475,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
FetchCooldown: valutil.ValOrDefault(config.FetchCooldown, FetchCooldownDefault),
FetchPollInterval: valutil.ValOrDefault(config.FetchPollInterval, FetchPollIntervalDefault),
ID: valutil.ValOrDefaultFunc(config.ID, func() string { return defaultClientID(time.Now().UTC()) }),
JobInsertMiddleware: config.JobInsertMiddleware,
JobTimeout: valutil.ValOrDefault(config.JobTimeout, JobTimeoutDefault),
Logger: logger,
MaxAttempts: valutil.ValOrDefault(config.MaxAttempts, MaxAttemptsDefault),
Expand All @@ -478,6 +487,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
RetryPolicy: retryPolicy,
TestOnly: config.TestOnly,
Workers: config.Workers,
WorkerMiddleware: config.WorkerMiddleware,
schedulerInterval: valutil.ValOrDefault(config.schedulerInterval, maintenance.JobSchedulerIntervalDefault),
time: config.time,
}
Expand Down Expand Up @@ -1168,7 +1178,7 @@ func (c *Client[TTx]) ID() string {
return c.config.ID
}

func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) (*riverdriver.JobInsertFastParams, error) {
func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) (*rivertype.JobInsertParams, error) {
encodedArgs, err := json.Marshal(args)
if err != nil {
return nil, fmt.Errorf("error marshaling args to JSON: %w", err)
Expand Down Expand Up @@ -1230,13 +1240,13 @@ func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, conf
metadata = []byte("{}")
}

insertParams := &riverdriver.JobInsertFastParams{
insertParams := &rivertype.JobInsertParams{
Args: args,
CreatedAt: createdAt,
EncodedArgs: json.RawMessage(encodedArgs),
EncodedArgs: encodedArgs,
Kind: args.Kind(),
MaxAttempts: maxAttempts,
Metadata: json.RawMessage(metadata),
Metadata: metadata,
Priority: priority,
Queue: queue,
State: rivertype.JobStateAvailable,
Expand Down Expand Up @@ -1439,39 +1449,58 @@ func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx,
func (c *Client[TTx]) insertManyShared(
ctx context.Context,
tx riverdriver.ExecutorTx,
params []InsertManyParams,
rawParams []InsertManyParams,
execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error),
) ([]*rivertype.JobInsertResult, error) {
insertParams, err := c.insertManyParams(params)
insertParams, err := c.insertManyParams(rawParams)
if err != nil {
return nil, err
}

inserted, err := execute(ctx, insertParams)
if err != nil {
return inserted, err
}
doInner := func(ctx context.Context) ([]*rivertype.JobInsertResult, error) {
finalInsertParams := sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams {
return (*riverdriver.JobInsertFastParams)(params)
})
results, err := execute(ctx, finalInsertParams)
if err != nil {
return results, err
}

queues := make([]string, 0, 10)
for _, params := range insertParams {
if params.State == rivertype.JobStateAvailable {
queues = append(queues, params.Queue)
queues := make([]string, 0, 10)
for _, params := range insertParams {
if params.State == rivertype.JobStateAvailable {
queues = append(queues, params.Queue)
}
}
if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil {
return nil, err
}
return results, nil
}
if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil {
return nil, err

if len(c.config.JobInsertMiddleware) > 0 {
// Wrap middlewares in reverse order so the one defined first is wrapped
// as the outermost function and is first to receive the operation.
for i := len(c.config.JobInsertMiddleware) - 1; i >= 0; i-- {
middlewareItem := c.config.JobInsertMiddleware[i] // capture the current middleware item
previousDoInner := doInner // Capture the current doInner function
doInner = func(ctx context.Context) ([]*rivertype.JobInsertResult, error) {
return middlewareItem.InsertMany(ctx, insertParams, previousDoInner)
}
}
}
return inserted, nil

return doInner(ctx)
}

// Validates input parameters for a batch insert operation and generates a set
// of batch insert parameters.
func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*riverdriver.JobInsertFastParams, error) {
func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*rivertype.JobInsertParams, error) {
if len(params) < 1 {
return nil, errors.New("no jobs to insert")
}

insertParams := make([]*riverdriver.JobInsertFastParams, len(params))
insertParams := make([]*rivertype.JobInsertParams, len(params))
for i, param := range params {
if err := c.validateJobArgs(param.Args); err != nil {
return nil, err
Expand Down Expand Up @@ -1668,6 +1697,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr
ErrorHandler: c.config.ErrorHandler,
FetchCooldown: c.config.FetchCooldown,
FetchPollInterval: c.config.FetchPollInterval,
GlobalMiddleware: c.config.WorkerMiddleware,
JobTimeout: c.config.JobTimeout,
MaxWorkers: queueConfig.MaxWorkers,
Notifier: c.notifier,
Expand Down
157 changes: 155 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/jackc/pgx/v5/stdlib"
"github.com/robfig/cron/v3"
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"

"github.com/riverqueue/river/internal/dbunique"
"github.com/riverqueue/river/internal/maintenance"
Expand Down Expand Up @@ -589,6 +590,90 @@ func Test_Client(t *testing.T) {
require.Equal(t, `relation "river_job" does not exist`, pgErr.Message)
})

t.Run("WithWorkerMiddleware", func(t *testing.T) {
t.Parallel()

_, bundle := setup(t)
middlewareCalled := false

type privateKey string

middleware := &overridableJobMiddleware{
workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
ctx = context.WithValue(ctx, privateKey("middleware"), "called")
middlewareCalled = true
return doInner(ctx)
},
}
bundle.config.WorkerMiddleware = []rivertype.WorkerMiddleware{middleware}

AddWorker(bundle.config.Workers, WorkFunc(func(ctx context.Context, job *Job[callbackArgs]) error {
require.Equal(t, "called", ctx.Value(privateKey("middleware")))
return nil
}))

driver := riverpgxv5.New(bundle.dbPool)
client, err := NewClient(driver, bundle.config)
require.NoError(t, err)

subscribeChan := subscribe(t, client)
startClient(ctx, t, client)

result, err := client.Insert(ctx, callbackArgs{}, nil)
require.NoError(t, err)

event := riversharedtest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, result.Job.ID, event.Job.ID)
require.True(t, middlewareCalled)
})

t.Run("WithWorkerMiddlewareOnWorker", func(t *testing.T) {
t.Parallel()

_, bundle := setup(t)
middlewareCalled := false

type privateKey string

worker := &workerWithMiddleware[callbackArgs]{
workFunc: func(ctx context.Context, job *Job[callbackArgs]) error {
require.Equal(t, "called", ctx.Value(privateKey("middleware")))
return nil
},
middlewareFunc: func(job *Job[callbackArgs]) []rivertype.WorkerMiddleware {
require.Equal(t, "middleware_test", job.Args.Name, "JSON should be decoded before middleware is called")

return []rivertype.WorkerMiddleware{
&overridableJobMiddleware{
workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
ctx = context.WithValue(ctx, privateKey("middleware"), "called")
middlewareCalled = true
return doInner(ctx)
},
},
}
},
}

AddWorker(bundle.config.Workers, worker)

driver := riverpgxv5.New(bundle.dbPool)
client, err := NewClient(driver, bundle.config)
require.NoError(t, err)

subscribeChan := subscribe(t, client)
startClient(ctx, t, client)

result, err := client.Insert(ctx, callbackArgs{Name: "middleware_test"}, nil)
require.NoError(t, err)

event := riversharedtest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, result.Job.ID, event.Job.ID)
require.True(t, middlewareCalled)
})

t.Run("PauseAndResumeSingleQueue", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -835,6 +920,20 @@ func Test_Client(t *testing.T) {
})
}

type workerWithMiddleware[T JobArgs] struct {
WorkerDefaults[T]
workFunc func(context.Context, *Job[T]) error
middlewareFunc func(*Job[T]) []rivertype.WorkerMiddleware
}

func (w *workerWithMiddleware[T]) Work(ctx context.Context, job *Job[T]) error {
return w.workFunc(ctx, job)
}

func (w *workerWithMiddleware[T]) Middleware(job *Job[T]) []rivertype.WorkerMiddleware {
return w.middlewareFunc(job)
}

func Test_Client_Stop(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -2420,6 +2519,48 @@ func Test_Client_InsertManyTx(t *testing.T) {
require.Len(t, results, 1)
})

t.Run("WithJobInsertMiddleware", func(t *testing.T) {
t.Parallel()

_, bundle := setup(t)
config := newTestConfig(t, nil)
config.Queues = nil

insertCalled := false
var innerResults []*rivertype.JobInsertResult

middleware := &overridableJobMiddleware{
insertManyFunc: func(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) {
insertCalled = true
var err error
for _, params := range manyParams {
params.Metadata, err = sjson.SetBytes(params.Metadata, "middleware", "called")
require.NoError(t, err)
}

results, err := doInner(ctx)
require.NoError(t, err)
innerResults = results
return results, nil
},
}

config.JobInsertMiddleware = []rivertype.JobInsertMiddleware{middleware}
driver := riverpgxv5.New(nil)
client, err := NewClient(driver, config)
require.NoError(t, err)

results, err := client.InsertManyTx(ctx, bundle.tx, []InsertManyParams{{Args: noOpArgs{}}})
require.NoError(t, err)
require.Len(t, results, 1)

require.True(t, insertCalled)
require.Len(t, innerResults, 1)
require.Len(t, results, 1)
require.Equal(t, innerResults[0].Job.ID, results[0].Job.ID)
require.JSONEq(t, `{"middleware": "called"}`, string(results[0].Job.Metadata))
})

t.Run("WithUniqueOpts", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -2998,7 +3139,7 @@ func Test_Client_ErrorHandler(t *testing.T) {
// unknown job.
insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil)
require.NoError(t, err)
_, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams})
_, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)})
require.NoError(t, err)

riversharedtest.WaitOrTimeout(t, bundle.SubscribeChan)
Expand Down Expand Up @@ -4600,7 +4741,7 @@ func Test_Client_UnknownJobKindErrorsTheJob(t *testing.T) {

insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil)
require.NoError(err)
insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams})
insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)})
require.NoError(err)

insertedResult := insertedResults[0]
Expand Down Expand Up @@ -4770,6 +4911,14 @@ func Test_NewClient_Overrides(t *testing.T) {

retryPolicy := &DefaultClientRetryPolicy{}

type noOpInsertMiddleware struct {
JobInsertMiddlewareDefaults
}

type noOpWorkerMiddleware struct {
WorkerMiddlewareDefaults
}

client, err := NewClient(riverpgxv5.New(dbPool), &Config{
AdvisoryLockPrefix: 123_456,
CancelledJobRetentionPeriod: 1 * time.Hour,
Expand All @@ -4778,13 +4927,15 @@ func Test_NewClient_Overrides(t *testing.T) {
ErrorHandler: errorHandler,
FetchCooldown: 123 * time.Millisecond,
FetchPollInterval: 124 * time.Millisecond,
JobInsertMiddleware: []rivertype.JobInsertMiddleware{&noOpInsertMiddleware{}},
JobTimeout: 125 * time.Millisecond,
Logger: logger,
MaxAttempts: 5,
Queues: map[string]QueueConfig{QueueDefault: {MaxWorkers: 1}},
RetryPolicy: retryPolicy,
TestOnly: true, // disables staggered start in maintenance services
Workers: workers,
WorkerMiddleware: []rivertype.WorkerMiddleware{&noOpWorkerMiddleware{}},
})
require.NoError(t, err)

Expand All @@ -4803,10 +4954,12 @@ func Test_NewClient_Overrides(t *testing.T) {
require.Equal(t, errorHandler, client.config.ErrorHandler)
require.Equal(t, 123*time.Millisecond, client.config.FetchCooldown)
require.Equal(t, 124*time.Millisecond, client.config.FetchPollInterval)
require.Len(t, client.config.JobInsertMiddleware, 1)
require.Equal(t, 125*time.Millisecond, client.config.JobTimeout)
require.Equal(t, logger, client.baseService.Logger)
require.Equal(t, 5, client.config.MaxAttempts)
require.Equal(t, retryPolicy, client.config.RetryPolicy)
require.Len(t, client.config.WorkerMiddleware, 1)
}

func Test_NewClient_MissingParameters(t *testing.T) {
Expand Down
Loading

0 comments on commit 2629666

Please sign in to comment.