Skip to content

Commit

Permalink
split worker and insert middleware, dynamic worker middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
bgentry committed Oct 5, 2024
1 parent aff18c8 commit a3ca8c1
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 55 deletions.
23 changes: 14 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ type Config struct {
// deployments.
JobCleanerTimeout time.Duration

// JobMiddleware are optional functions that can be called around different
// parts of each job's lifecycle.
JobMiddleware []rivertype.JobMiddleware
// JobInsertMiddleware are optional functions that can be called around job
// insertion.
JobInsertMiddleware []rivertype.JobInsertMiddleware

// JobTimeout is the maximum amount of time a job is allowed to run before its
// context is cancelled. A timeout of zero means JobTimeoutDefault will be
Expand Down Expand Up @@ -239,6 +239,10 @@ type Config struct {
// (i.e. That it wasn't forgotten by accident.)
Workers *Workers

// WorkerMiddleware are optional functions that can be called around
// all job executions.
WorkerMiddleware []rivertype.WorkerMiddleware

// Scheduler run interval. Shared between the scheduler and producer/job
// executors, but not currently exposed for configuration.
schedulerInterval time.Duration
Expand Down Expand Up @@ -471,7 +475,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
FetchCooldown: valutil.ValOrDefault(config.FetchCooldown, FetchCooldownDefault),
FetchPollInterval: valutil.ValOrDefault(config.FetchPollInterval, FetchPollIntervalDefault),
ID: valutil.ValOrDefaultFunc(config.ID, func() string { return defaultClientID(time.Now().UTC()) }),
JobMiddleware: config.JobMiddleware,
JobInsertMiddleware: config.JobInsertMiddleware,
JobTimeout: valutil.ValOrDefault(config.JobTimeout, JobTimeoutDefault),
Logger: logger,
MaxAttempts: valutil.ValOrDefault(config.MaxAttempts, MaxAttemptsDefault),
Expand All @@ -483,6 +487,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
RetryPolicy: retryPolicy,
TestOnly: config.TestOnly,
Workers: config.Workers,
WorkerMiddleware: config.WorkerMiddleware,
schedulerInterval: valutil.ValOrDefault(config.schedulerInterval, maintenance.JobSchedulerIntervalDefault),
time: config.time,
}
Expand Down Expand Up @@ -1470,12 +1475,12 @@ func (c *Client[TTx]) insertManyShared(
return results, nil
}

if len(c.config.JobMiddleware) > 0 {
if len(c.config.JobInsertMiddleware) > 0 {
// Wrap middlewares in reverse order so the one defined first is wrapped
// as the outermost function and is first to receive the operation.
for i := len(c.config.JobMiddleware) - 1; i >= 0; i-- {
middlewareItem := c.config.JobMiddleware[i] // capture the current middleware item
previousDoInner := doInner // Capture the current doInner function
for i := len(c.config.JobInsertMiddleware) - 1; i >= 0; i-- {
middlewareItem := c.config.JobInsertMiddleware[i] // capture the current middleware item
previousDoInner := doInner // Capture the current doInner function
doInner = func(ctx context.Context) ([]*rivertype.JobInsertResult, error) {
return middlewareItem.InsertMany(ctx, insertParams, previousDoInner)
}
Expand Down Expand Up @@ -1689,7 +1694,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr
ErrorHandler: c.config.ErrorHandler,
FetchCooldown: c.config.FetchCooldown,
FetchPollInterval: c.config.FetchPollInterval,
JobMiddleware: c.config.JobMiddleware,
GlobalMiddleware: c.config.WorkerMiddleware,
JobTimeout: c.config.JobTimeout,
MaxWorkers: queueConfig.MaxWorkers,
Notifier: c.notifier,
Expand Down
89 changes: 78 additions & 11 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,24 +590,25 @@ func Test_Client(t *testing.T) {
require.Equal(t, `relation "river_job" does not exist`, pgErr.Message)
})

t.Run("WithJobMiddleware", func(t *testing.T) {
t.Run("WithWorkerMiddleware", func(t *testing.T) {
t.Parallel()

_, bundle := setup(t)
middlewareCalled := false

type privateKey string

middleware := &overridableJobMiddleware{
workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
//lint:ignore SA1029 plain string context key is fine for a test
ctx = context.WithValue(ctx, "middleware", "called")
ctx = context.WithValue(ctx, privateKey("middleware"), "called")
middlewareCalled = true
return doInner(ctx)
},
}
bundle.config.JobMiddleware = []rivertype.JobMiddleware{middleware}
bundle.config.WorkerMiddleware = []rivertype.WorkerMiddleware{middleware}

AddWorker(bundle.config.Workers, WorkFunc(func(ctx context.Context, job *Job[callbackArgs]) error {
require.Equal(t, "called", (ctx.Value("middleware").(string)))
require.Equal(t, "called", ctx.Value(privateKey("middleware")))
return nil
}))

Expand All @@ -627,6 +628,52 @@ func Test_Client(t *testing.T) {
require.True(t, middlewareCalled)
})

t.Run("WithWorkerMiddlewareOnWorker", func(t *testing.T) {
t.Parallel()

_, bundle := setup(t)
middlewareCalled := false

type privateKey string

worker := &workerWithMiddleware[callbackArgs]{
workFunc: func(ctx context.Context, job *Job[callbackArgs]) error {
require.Equal(t, "called", ctx.Value(privateKey("middleware")))
return nil
},
middlewareFunc: func(job *Job[callbackArgs]) []rivertype.WorkerMiddleware {
require.Equal(t, "middleware_test", job.Args.Name, "JSON should be decoded before middleware is called")

return []rivertype.WorkerMiddleware{
&overridableJobMiddleware{
workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
ctx = context.WithValue(ctx, privateKey("middleware"), "called")
middlewareCalled = true
return doInner(ctx)
},
},
}
},
}

AddWorker(bundle.config.Workers, worker)

driver := riverpgxv5.New(bundle.dbPool)
client, err := NewClient(driver, bundle.config)
require.NoError(t, err)

subscribeChan := subscribe(t, client)
startClient(ctx, t, client)

result, err := client.Insert(ctx, callbackArgs{Name: "middleware_test"}, nil)
require.NoError(t, err)

event := riversharedtest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, result.Job.ID, event.Job.ID)
require.True(t, middlewareCalled)
})

t.Run("PauseAndResumeSingleQueue", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -873,6 +920,20 @@ func Test_Client(t *testing.T) {
})
}

type workerWithMiddleware[T JobArgs] struct {
WorkerDefaults[T]
workFunc func(context.Context, *Job[T]) error
middlewareFunc func(*Job[T]) []rivertype.WorkerMiddleware
}

func (w *workerWithMiddleware[T]) Work(ctx context.Context, job *Job[T]) error {
return w.workFunc(ctx, job)
}

func (w *workerWithMiddleware[T]) Middleware(job *Job[T]) []rivertype.WorkerMiddleware {
return w.middlewareFunc(job)
}

func Test_Client_Stop(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -2458,7 +2519,7 @@ func Test_Client_InsertManyTx(t *testing.T) {
require.Len(t, results, 1)
})

t.Run("WithJobMiddleware", func(t *testing.T) {
t.Run("WithJobInsertMiddleware", func(t *testing.T) {
t.Parallel()

_, bundle := setup(t)
Expand All @@ -2484,7 +2545,7 @@ func Test_Client_InsertManyTx(t *testing.T) {
},
}

config.JobMiddleware = []rivertype.JobMiddleware{middleware}
config.JobInsertMiddleware = []rivertype.JobInsertMiddleware{middleware}
driver := riverpgxv5.New(nil)
client, err := NewClient(driver, config)
require.NoError(t, err)
Expand Down Expand Up @@ -4850,8 +4911,12 @@ func Test_NewClient_Overrides(t *testing.T) {

retryPolicy := &DefaultClientRetryPolicy{}

type noOpMiddleware struct {
JobMiddlewareDefaults
type noOpInsertMiddleware struct {
JobInsertMiddlewareDefaults
}

type noOpWorkerMiddleware struct {
WorkerMiddlewareDefaults
}

client, err := NewClient(riverpgxv5.New(dbPool), &Config{
Expand All @@ -4862,14 +4927,15 @@ func Test_NewClient_Overrides(t *testing.T) {
ErrorHandler: errorHandler,
FetchCooldown: 123 * time.Millisecond,
FetchPollInterval: 124 * time.Millisecond,
JobMiddleware: []rivertype.JobMiddleware{&noOpMiddleware{}},
JobInsertMiddleware: []rivertype.JobInsertMiddleware{&noOpInsertMiddleware{}},
JobTimeout: 125 * time.Millisecond,
Logger: logger,
MaxAttempts: 5,
Queues: map[string]QueueConfig{QueueDefault: {MaxWorkers: 1}},
RetryPolicy: retryPolicy,
TestOnly: true, // disables staggered start in maintenance services
Workers: workers,
WorkerMiddleware: []rivertype.WorkerMiddleware{&noOpWorkerMiddleware{}},
})
require.NoError(t, err)

Expand All @@ -4888,11 +4954,12 @@ func Test_NewClient_Overrides(t *testing.T) {
require.Equal(t, errorHandler, client.config.ErrorHandler)
require.Equal(t, 123*time.Millisecond, client.config.FetchCooldown)
require.Equal(t, 124*time.Millisecond, client.config.FetchPollInterval)
require.Len(t, client.config.JobMiddleware, 1)
require.Len(t, client.config.JobInsertMiddleware, 1)
require.Equal(t, 125*time.Millisecond, client.config.JobTimeout)
require.Equal(t, logger, client.baseService.Logger)
require.Equal(t, 5, client.config.MaxAttempts)
require.Equal(t, retryPolicy, client.config.RetryPolicy)
require.Len(t, client.config.WorkerMiddleware, 1)
}

func Test_NewClient_MissingParameters(t *testing.T) {
Expand Down
9 changes: 5 additions & 4 deletions internal/maintenance/job_rescuer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ type callbackWorkUnit struct {
timeout time.Duration // defaults to 0, which signals default timeout
}

func (w *callbackWorkUnit) NextRetry() time.Time { return time.Now().Add(30 * time.Second) }
func (w *callbackWorkUnit) Timeout() time.Duration { return w.timeout }
func (w *callbackWorkUnit) Work(ctx context.Context) error { return w.callback(ctx, w.jobRow) }
func (w *callbackWorkUnit) UnmarshalJob() error { return nil }
func (w *callbackWorkUnit) Middleware() []rivertype.WorkerMiddleware { return nil }
func (w *callbackWorkUnit) NextRetry() time.Time { return time.Now().Add(30 * time.Second) }
func (w *callbackWorkUnit) Timeout() time.Duration { return w.timeout }
func (w *callbackWorkUnit) Work(ctx context.Context) error { return w.callback(ctx, w.jobRow) }
func (w *callbackWorkUnit) UnmarshalJob() error { return nil }

type SimpleClientRetryPolicy struct{}

Expand Down
1 change: 1 addition & 0 deletions internal/workunit/work_unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
//
// Implemented by river.wrapperWorkUnit.
type WorkUnit interface {
Middleware() []rivertype.WorkerMiddleware
NextRetry() time.Time
Timeout() time.Duration
UnmarshalJob() error
Expand Down
24 changes: 15 additions & 9 deletions job_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ type jobExecutor struct {
ErrorHandler ErrorHandler
InformProducerDoneFunc func(jobRow *rivertype.JobRow)
JobRow *rivertype.JobRow
JobMiddleware []rivertype.JobMiddleware
GlobalMiddleware []rivertype.WorkerMiddleware
SchedulerInterval time.Duration
WorkUnit workunit.WorkUnit

Expand Down Expand Up @@ -194,11 +194,13 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) {
return &jobExecutorResult{Err: &UnknownJobKindError{Kind: e.JobRow.Kind}}
}

doInner := func(ctx context.Context) error {
if err := e.WorkUnit.UnmarshalJob(); err != nil {
return err
}
if err := e.WorkUnit.UnmarshalJob(); err != nil {
return &jobExecutorResult{Err: err}
}

workerMiddleware := e.WorkUnit.Middleware()

doInner := func(ctx context.Context) error {
jobTimeout := e.WorkUnit.Timeout()
if jobTimeout == 0 {
jobTimeout = e.ClientJobTimeout
Expand All @@ -218,12 +220,16 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) {
return nil
}

if len(e.JobMiddleware) > 0 {
allMiddleware := make([]rivertype.WorkerMiddleware, 0, len(e.GlobalMiddleware)+len(workerMiddleware))
allMiddleware = append(allMiddleware, e.GlobalMiddleware...)
allMiddleware = append(allMiddleware, workerMiddleware...)

if len(allMiddleware) > 0 {
// Wrap middlewares in reverse order so the one defined first is wrapped
// as the outermost function and is first to receive the operation.
for i := len(e.JobMiddleware) - 1; i >= 0; i-- {
middlewareItem := e.JobMiddleware[i] // capture the current middleware item
previousDoInner := doInner // Capture the current doInner function
for i := len(allMiddleware) - 1; i >= 0; i-- {
middlewareItem := allMiddleware[i] // capture the current middleware item
previousDoInner := doInner // Capture the current doInner function
doInner = func(ctx context.Context) error {
return middlewareItem.Work(ctx, e.JobRow, previousDoInner)
}
Expand Down
14 changes: 8 additions & 6 deletions middleware_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ import (
"github.com/riverqueue/river/rivertype"
)

// JobMiddlewareDefaults is an embeddable struct that provides default
// implementations for the rivertype.JobMiddleware. Use of this struct is
// recommended in case rivertype.JobMiddleware is expanded in the future so that
// JobInsertMiddlewareDefaults is an embeddable struct that provides default
// implementations for the rivertype.JobInsertMiddleware. Use of this struct is
// recommended in case rivertype.JobInsertMiddleware is expanded in the future so that
// existing code isn't unexpectedly broken during an upgrade.
type JobMiddlewareDefaults struct{}
type JobInsertMiddlewareDefaults struct{}

func (l *JobMiddlewareDefaults) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) {
func (d *JobInsertMiddlewareDefaults) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) {
return doInner(ctx)
}

func (l *JobMiddlewareDefaults) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
type WorkerMiddlewareDefaults struct{}

func (d *WorkerMiddlewareDefaults) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
return doInner(ctx)
}
5 changes: 4 additions & 1 deletion middleware_defaults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ package river

import "github.com/riverqueue/river/rivertype"

var _ rivertype.JobMiddleware = &JobMiddlewareDefaults{}
var (
_ rivertype.JobInsertMiddleware = &JobInsertMiddlewareDefaults{}
_ rivertype.WorkerMiddleware = &WorkerMiddlewareDefaults{}
)
7 changes: 4 additions & 3 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import (
)

type overridableJobMiddleware struct {
JobMiddlewareDefaults
JobInsertMiddlewareDefaults
WorkerMiddlewareDefaults

insertManyFunc func(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error)
workFunc func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error
Expand All @@ -17,12 +18,12 @@ func (m *overridableJobMiddleware) InsertMany(ctx context.Context, manyParams []
if m.insertManyFunc != nil {
return m.insertManyFunc(ctx, manyParams, doInner)
}
return m.JobMiddlewareDefaults.InsertMany(ctx, manyParams, doInner)
return m.JobInsertMiddlewareDefaults.InsertMany(ctx, manyParams, doInner)
}

func (m *overridableJobMiddleware) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
if m.workFunc != nil {
return m.workFunc(ctx, job, doInner)
}
return m.JobMiddlewareDefaults.Work(ctx, job, doInner)
return m.WorkerMiddlewareDefaults.Work(ctx, job, doInner)
}
8 changes: 4 additions & 4 deletions producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ type producerConfig struct {
// LISTEN/NOTIFY, but this provides a fallback.
FetchPollInterval time.Duration

JobMiddleware []rivertype.JobMiddleware
JobTimeout time.Duration
MaxWorkers int
GlobalMiddleware []rivertype.WorkerMiddleware
JobTimeout time.Duration
MaxWorkers int

// Notifier is a notifier for subscribing to new job inserts and job
// control. If nil, the producer will operate in poll-only mode.
Expand Down Expand Up @@ -580,7 +580,7 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype.
Completer: p.completer,
ErrorHandler: p.errorHandler,
InformProducerDoneFunc: p.handleWorkerDone,
JobMiddleware: p.config.JobMiddleware,
GlobalMiddleware: p.config.GlobalMiddleware,
JobRow: job,
SchedulerInterval: p.config.SchedulerInterval,
WorkUnit: workUnit,
Expand Down
Loading

0 comments on commit a3ca8c1

Please sign in to comment.