Skip to content

Commit

Permalink
Add middleware system for jobs
Browse files Browse the repository at this point in the history
Here, experiment with a middleware-like system that adds middleware
functions to job lifecycles, which results in them being invoked during
specific phases of a job like as it's being inserted or worked.

The most obvious unlock for this is telemetry (e.g. logging, metrics),
but it also acts as a building block for features like encrypted jobs.

Co-authored-by: Blake Gentry <[email protected]>
  • Loading branch information
brandur and bgentry committed Oct 5, 2024
1 parent 4876b52 commit e495633
Show file tree
Hide file tree
Showing 14 changed files with 159 additions and 49 deletions.
63 changes: 44 additions & 19 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ type Config struct {
// deployments.
JobCleanerTimeout time.Duration

// JobMiddleware are optional functions that can be called around different
// parts of each job's lifecycle.
JobMiddleware []rivertype.JobMiddleware

// JobTimeout is the maximum amount of time a job is allowed to run before its
// context is cancelled. A timeout of zero means JobTimeoutDefault will be
// used, whereas a value of -1 means the job's context will not be cancelled
Expand Down Expand Up @@ -467,6 +471,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
FetchCooldown: valutil.ValOrDefault(config.FetchCooldown, FetchCooldownDefault),
FetchPollInterval: valutil.ValOrDefault(config.FetchPollInterval, FetchPollIntervalDefault),
ID: valutil.ValOrDefaultFunc(config.ID, func() string { return defaultClientID(time.Now().UTC()) }),
JobMiddleware: config.JobMiddleware,
JobTimeout: valutil.ValOrDefault(config.JobTimeout, JobTimeoutDefault),
Logger: logger,
MaxAttempts: valutil.ValOrDefault(config.MaxAttempts, MaxAttemptsDefault),
Expand Down Expand Up @@ -1165,7 +1170,7 @@ func (c *Client[TTx]) ID() string {
return c.config.ID
}

func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) (*riverdriver.JobInsertFastParams, error) {
func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) (*rivertype.JobInsertParams, error) {
encodedArgs, err := json.Marshal(args)
if err != nil {
return nil, fmt.Errorf("error marshaling args to JSON: %w", err)
Expand Down Expand Up @@ -1227,13 +1232,13 @@ func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, conf
metadata = []byte("{}")
}

insertParams := &riverdriver.JobInsertFastParams{
insertParams := &rivertype.JobInsertParams{
Args: args,
CreatedAt: createdAt,
EncodedArgs: json.RawMessage(encodedArgs),
EncodedArgs: encodedArgs,
Kind: args.Kind(),
MaxAttempts: maxAttempts,
Metadata: json.RawMessage(metadata),
Metadata: metadata,
Priority: priority,
Queue: queue,
State: rivertype.JobStateAvailable,
Expand Down Expand Up @@ -1436,39 +1441,58 @@ func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx,
func (c *Client[TTx]) insertManyShared(
ctx context.Context,
tx riverdriver.ExecutorTx,
params []InsertManyParams,
rawParams []InsertManyParams,
execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error),
) ([]*rivertype.JobInsertResult, error) {
insertParams, err := c.insertManyParams(params)
insertParams, err := c.insertManyParams(rawParams)
if err != nil {
return nil, err
}

inserted, err := execute(ctx, insertParams)
if err != nil {
return inserted, err
}
doInner := func(ctx context.Context) ([]*rivertype.JobInsertResult, error) {
finalInsertParams := sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams {
return (*riverdriver.JobInsertFastParams)(params)
})
results, err := execute(ctx, finalInsertParams)
if err != nil {
return results, err
}

queues := make([]string, 0, 10)
for _, params := range insertParams {
if params.State == rivertype.JobStateAvailable {
queues = append(queues, params.Queue)
queues := make([]string, 0, 10)
for _, params := range insertParams {
if params.State == rivertype.JobStateAvailable {
queues = append(queues, params.Queue)
}
}
if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil {
return nil, err
}
return results, nil
}
if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil {
return nil, err

if len(c.config.JobMiddleware) > 0 {
// Wrap middlewares in reverse order so the one defined first is wrapped
// as the outermost function and is first to receive the operation.
for i := len(c.config.JobMiddleware) - 1; i >= 0; i-- {
middlewareItem := c.config.JobMiddleware[i] // capture the current middleware item
previousDoInner := doInner // Capture the current doInner function
doInner = func(ctx context.Context) ([]*rivertype.JobInsertResult, error) {
return middlewareItem.InsertMany(ctx, insertParams, previousDoInner)
}
}
}
return inserted, nil

return doInner(ctx)
}

// Validates input parameters for a batch insert operation and generates a set
// of batch insert parameters.
func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*riverdriver.JobInsertFastParams, error) {
func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*rivertype.JobInsertParams, error) {
if len(params) < 1 {
return nil, errors.New("no jobs to insert")
}

insertParams := make([]*riverdriver.JobInsertFastParams, len(params))
insertParams := make([]*rivertype.JobInsertParams, len(params))
for i, param := range params {
if err := c.validateJobArgs(param.Args); err != nil {
return nil, err
Expand Down Expand Up @@ -1665,6 +1689,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr
ErrorHandler: c.config.ErrorHandler,
FetchCooldown: c.config.FetchCooldown,
FetchPollInterval: c.config.FetchPollInterval,
JobMiddleware: c.config.JobMiddleware,
JobTimeout: c.config.JobTimeout,
MaxWorkers: queueConfig.MaxWorkers,
Notifier: c.notifier,
Expand Down
4 changes: 2 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2998,7 +2998,7 @@ func Test_Client_ErrorHandler(t *testing.T) {
// unknown job.
insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil)
require.NoError(t, err)
_, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams})
_, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)})
require.NoError(t, err)

riversharedtest.WaitOrTimeout(t, bundle.SubscribeChan)
Expand Down Expand Up @@ -4600,7 +4600,7 @@ func Test_Client_UnknownJobKindErrorsTheJob(t *testing.T) {

insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil)
require.NoError(err)
insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams})
insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)})
require.NoError(err)

insertedResult := insertedResults[0]
Expand Down
8 changes: 3 additions & 5 deletions internal/dbunique/db_unique.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/rivershared/baseservice"
"github.com/riverqueue/river/rivershared/util/sliceutil"
"github.com/riverqueue/river/rivertype"
Expand Down Expand Up @@ -64,7 +63,7 @@ func (o *UniqueOpts) StateBitmask() byte {
return UniqueStatesToBitmask(states)
}

func UniqueKey(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *riverdriver.JobInsertFastParams) ([]byte, error) {
func UniqueKey(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *rivertype.JobInsertParams) ([]byte, error) {
uniqueKeyString, err := buildUniqueKeyString(timeGen, uniqueOpts, params)
if err != nil {
return nil, err
Expand All @@ -74,9 +73,8 @@ func UniqueKey(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params
}

// Builds a unique key made up of the unique options in place. The key is hashed
// to become a value for `unique_key` in the fast insertion path, or hashed and
// used for an advisory lock on the slow insertion path.
func buildUniqueKeyString(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *riverdriver.JobInsertFastParams) (string, error) {
// to become a value for `unique_key`.
func buildUniqueKeyString(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *rivertype.JobInsertParams) (string, error) {
var sb strings.Builder

if !uniqueOpts.ExcludeKind {
Expand Down
3 changes: 1 addition & 2 deletions internal/dbunique/db_unique_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/stretchr/testify/require"

"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/rivershared/riversharedtest"
"github.com/riverqueue/river/rivertype"
)
Expand Down Expand Up @@ -229,7 +228,7 @@ func TestUniqueKey(t *testing.T) {
states = tt.uniqueOpts.ByState
}

jobParams := &riverdriver.JobInsertFastParams{
jobParams := &rivertype.JobInsertParams{
Args: args,
CreatedAt: &now,
EncodedArgs: encodedArgs,
Expand Down
6 changes: 3 additions & 3 deletions internal/maintenance/periodic_job_enqueuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (ts *PeriodicJobEnqueuerTestSignals) Init() {
// river.PeriodicJobArgs, but needs a separate type because the enqueuer is in a
// subpackage.
type PeriodicJob struct {
ConstructorFunc func() (*riverdriver.JobInsertFastParams, error)
ConstructorFunc func() (*rivertype.JobInsertParams, error)
RunOnStart bool
ScheduleFunc func(time.Time) time.Time

Expand Down Expand Up @@ -373,7 +373,7 @@ func (s *PeriodicJobEnqueuer) insertBatch(ctx context.Context, insertParamsMany
s.TestSignals.InsertedJobs.Signal(struct{}{})
}

func (s *PeriodicJobEnqueuer) insertParamsFromConstructor(ctx context.Context, constructorFunc func() (*riverdriver.JobInsertFastParams, error), scheduledAt time.Time) (*riverdriver.JobInsertFastParams, bool) {
func (s *PeriodicJobEnqueuer) insertParamsFromConstructor(ctx context.Context, constructorFunc func() (*rivertype.JobInsertParams, error), scheduledAt time.Time) (*riverdriver.JobInsertFastParams, bool) {
insertParams, err := constructorFunc()
if err != nil {
if errors.Is(err, ErrNoJobToInsert) {
Expand All @@ -389,7 +389,7 @@ func (s *PeriodicJobEnqueuer) insertParamsFromConstructor(ctx context.Context, c
insertParams.ScheduledAt = &scheduledAt
}

return insertParams, true
return (*riverdriver.JobInsertFastParams)(insertParams), true
}

const periodicJobEnqueuerVeryLongDuration = 24 * time.Hour
Expand Down
10 changes: 5 additions & 5 deletions internal/maintenance/periodic_job_enqueuer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ func TestPeriodicJobEnqueuer(t *testing.T) {
stubSvc := &riversharedtest.TimeStub{}
stubSvc.StubNowUTC(time.Now().UTC())

jobConstructorWithQueueFunc := func(name string, unique bool, queue string) func() (*riverdriver.JobInsertFastParams, error) {
return func() (*riverdriver.JobInsertFastParams, error) {
params := &riverdriver.JobInsertFastParams{
jobConstructorWithQueueFunc := func(name string, unique bool, queue string) func() (*rivertype.JobInsertParams, error) {
return func() (*rivertype.JobInsertParams, error) {
params := &rivertype.JobInsertParams{
Args: noOpArgs{},
EncodedArgs: []byte("{}"),
Kind: name,
Expand All @@ -66,7 +66,7 @@ func TestPeriodicJobEnqueuer(t *testing.T) {
}
}

jobConstructorFunc := func(name string, unique bool) func() (*riverdriver.JobInsertFastParams, error) {
jobConstructorFunc := func(name string, unique bool) func() (*rivertype.JobInsertParams, error) {
return jobConstructorWithQueueFunc(name, unique, rivercommon.QueueDefault)
}

Expand Down Expand Up @@ -256,7 +256,7 @@ func TestPeriodicJobEnqueuer(t *testing.T) {

svc.AddMany([]*PeriodicJob{
// skip this insert when it returns nil:
{ScheduleFunc: periodicIntervalSchedule(time.Second), ConstructorFunc: func() (*riverdriver.JobInsertFastParams, error) {
{ScheduleFunc: periodicIntervalSchedule(time.Second), ConstructorFunc: func() (*rivertype.JobInsertParams, error) {
return nil, ErrNoJobToInsert
}, RunOnStart: true},
})
Expand Down
4 changes: 2 additions & 2 deletions internal/maintenance/queue_maintainer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ import (

"github.com/riverqueue/river/internal/riverinternaltest"
"github.com/riverqueue/river/internal/riverinternaltest/sharedtx"
"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/riverdriver/riverpgxv5"
"github.com/riverqueue/river/rivershared/baseservice"
"github.com/riverqueue/river/rivershared/riversharedtest"
"github.com/riverqueue/river/rivershared/startstop"
"github.com/riverqueue/river/rivershared/startstoptest"
"github.com/riverqueue/river/rivershared/testsignal"
"github.com/riverqueue/river/rivertype"
)

type testService struct {
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestQueueMaintainer(t *testing.T) {
NewPeriodicJobEnqueuer(archetype, &PeriodicJobEnqueuerConfig{
PeriodicJobs: []*PeriodicJob{
{
ConstructorFunc: func() (*riverdriver.JobInsertFastParams, error) {
ConstructorFunc: func() (*rivertype.JobInsertParams, error) {
return nil, ErrNoJobToInsert
},
ScheduleFunc: cron.Every(15 * time.Minute).Next,
Expand Down
29 changes: 24 additions & 5 deletions job_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ type jobExecutor struct {
ErrorHandler ErrorHandler
InformProducerDoneFunc func(jobRow *rivertype.JobRow)
JobRow *rivertype.JobRow
JobMiddleware []rivertype.JobMiddleware
SchedulerInterval time.Duration
WorkUnit workunit.WorkUnit

Expand Down Expand Up @@ -193,11 +194,11 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) {
return &jobExecutorResult{Err: &UnknownJobKindError{Kind: e.JobRow.Kind}}
}

if err := e.WorkUnit.UnmarshalJob(); err != nil {
return &jobExecutorResult{Err: err}
}
doInner := func(ctx context.Context) error {
if err := e.WorkUnit.UnmarshalJob(); err != nil {
return err
}

{
jobTimeout := e.WorkUnit.Timeout()
if jobTimeout == 0 {
jobTimeout = e.ClientJobTimeout
Expand All @@ -210,8 +211,26 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) {
defer cancel()
}

return &jobExecutorResult{Err: e.WorkUnit.Work(ctx)}
if err := e.WorkUnit.Work(ctx); err != nil {
return err
}

return nil
}

if len(e.JobMiddleware) > 0 {
// Wrap middlewares in reverse order so the one defined first is wrapped
// as the outermost function and is first to receive the operation.
for i := len(e.JobMiddleware) - 1; i >= 0; i-- {
middlewareItem := e.JobMiddleware[i] // capture the current middleware item
previousDoInner := doInner // Capture the current doInner function
doInner = func(ctx context.Context) error {
return middlewareItem.Work(ctx, e.JobRow, previousDoInner)
}
}
}

return &jobExecutorResult{Err: doInner(ctx)}
}

func (e *jobExecutor) invokeErrorHandler(ctx context.Context, res *jobExecutorResult) bool {
Expand Down
21 changes: 21 additions & 0 deletions middleware_defaults.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package river

import (
"context"

"github.com/riverqueue/river/rivertype"
)

// JobMiddlewareDefaults is an embeddable struct that provides default
// implementations for the rivertype.JobMiddleware. Use of this struct is
// recommended in case rivertype.JobMiddleware is expanded in the future so that
// existing code isn't unexpectedly broken during an upgrade.
type JobMiddlewareDefaults struct{}

func (l *JobMiddlewareDefaults) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) {
return doInner(ctx)
}

func (l *JobMiddlewareDefaults) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
return doInner(ctx)
}
5 changes: 5 additions & 0 deletions middleware_defaults_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package river

import "github.com/riverqueue/river/rivertype"

var _ rivertype.JobMiddleware = &JobMiddlewareDefaults{}
3 changes: 1 addition & 2 deletions periodic_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"time"

"github.com/riverqueue/river/internal/maintenance"
"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/rivershared/util/sliceutil"
"github.com/riverqueue/river/rivertype"
)
Expand Down Expand Up @@ -180,7 +179,7 @@ func (b *PeriodicJobBundle) toInternal(periodicJob *PeriodicJob) *maintenance.Pe
opts = periodicJob.opts
}
return &maintenance.PeriodicJob{
ConstructorFunc: func() (*riverdriver.JobInsertFastParams, error) {
ConstructorFunc: func() (*rivertype.JobInsertParams, error) {
args, options := periodicJob.constructorFunc()
if args == nil {
return nil, maintenance.ErrNoJobToInsert
Expand Down
6 changes: 4 additions & 2 deletions producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ type producerConfig struct {
// LISTEN/NOTIFY, but this provides a fallback.
FetchPollInterval time.Duration

JobTimeout time.Duration
MaxWorkers int
JobMiddleware []rivertype.JobMiddleware
JobTimeout time.Duration
MaxWorkers int

// Notifier is a notifier for subscribing to new job inserts and job
// control. If nil, the producer will operate in poll-only mode.
Expand Down Expand Up @@ -579,6 +580,7 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype.
Completer: p.completer,
ErrorHandler: p.errorHandler,
InformProducerDoneFunc: p.handleWorkerDone,
JobMiddleware: p.config.JobMiddleware,
JobRow: job,
SchedulerInterval: p.config.SchedulerInterval,
WorkUnit: workUnit,
Expand Down
Loading

0 comments on commit e495633

Please sign in to comment.