diff --git a/client.go b/client.go index b78ea617..6e4e3a0c 100644 --- a/client.go +++ b/client.go @@ -139,9 +139,9 @@ type Config struct { // deployments. JobCleanerTimeout time.Duration - // JobMiddleware are optional functions that can be called around different - // parts of each job's lifecycle. - JobMiddleware []rivertype.JobMiddleware + // JobInsertMiddleware are optional functions that can be called around job + // insertion. + JobInsertMiddleware []rivertype.JobInsertMiddleware // JobTimeout is the maximum amount of time a job is allowed to run before its // context is cancelled. A timeout of zero means JobTimeoutDefault will be @@ -239,6 +239,10 @@ type Config struct { // (i.e. That it wasn't forgotten by accident.) Workers *Workers + // WorkerMiddleware are optional functions that can be called around + // all job executions. + WorkerMiddleware []rivertype.WorkerMiddleware + // Scheduler run interval. Shared between the scheduler and producer/job // executors, but not currently exposed for configuration. schedulerInterval time.Duration @@ -471,7 +475,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client FetchCooldown: valutil.ValOrDefault(config.FetchCooldown, FetchCooldownDefault), FetchPollInterval: valutil.ValOrDefault(config.FetchPollInterval, FetchPollIntervalDefault), ID: valutil.ValOrDefaultFunc(config.ID, func() string { return defaultClientID(time.Now().UTC()) }), - JobMiddleware: config.JobMiddleware, + JobInsertMiddleware: config.JobInsertMiddleware, JobTimeout: valutil.ValOrDefault(config.JobTimeout, JobTimeoutDefault), Logger: logger, MaxAttempts: valutil.ValOrDefault(config.MaxAttempts, MaxAttemptsDefault), @@ -483,6 +487,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client RetryPolicy: retryPolicy, TestOnly: config.TestOnly, Workers: config.Workers, + WorkerMiddleware: config.WorkerMiddleware, schedulerInterval: valutil.ValOrDefault(config.schedulerInterval, maintenance.JobSchedulerIntervalDefault), time: config.time, } @@ -1470,12 +1475,12 @@ func (c *Client[TTx]) insertManyShared( return results, nil } - if len(c.config.JobMiddleware) > 0 { + if len(c.config.JobInsertMiddleware) > 0 { // Wrap middlewares in reverse order so the one defined first is wrapped // as the outermost function and is first to receive the operation. - for i := len(c.config.JobMiddleware) - 1; i >= 0; i-- { - middlewareItem := c.config.JobMiddleware[i] // capture the current middleware item - previousDoInner := doInner // Capture the current doInner function + for i := len(c.config.JobInsertMiddleware) - 1; i >= 0; i-- { + middlewareItem := c.config.JobInsertMiddleware[i] // capture the current middleware item + previousDoInner := doInner // Capture the current doInner function doInner = func(ctx context.Context) ([]*rivertype.JobInsertResult, error) { return middlewareItem.InsertMany(ctx, insertParams, previousDoInner) } @@ -1689,7 +1694,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr ErrorHandler: c.config.ErrorHandler, FetchCooldown: c.config.FetchCooldown, FetchPollInterval: c.config.FetchPollInterval, - JobMiddleware: c.config.JobMiddleware, + GlobalMiddleware: c.config.WorkerMiddleware, JobTimeout: c.config.JobTimeout, MaxWorkers: queueConfig.MaxWorkers, Notifier: c.notifier, diff --git a/client_test.go b/client_test.go index dee5fae0..84207bed 100644 --- a/client_test.go +++ b/client_test.go @@ -590,24 +590,25 @@ func Test_Client(t *testing.T) { require.Equal(t, `relation "river_job" does not exist`, pgErr.Message) }) - t.Run("WithJobMiddleware", func(t *testing.T) { + t.Run("WithWorkerMiddleware", func(t *testing.T) { t.Parallel() _, bundle := setup(t) middlewareCalled := false + type privateKey string + middleware := &overridableJobMiddleware{ workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { - //lint:ignore SA1029 plain string context key is fine for a test - ctx = context.WithValue(ctx, "middleware", "called") + ctx = context.WithValue(ctx, privateKey("middleware"), "called") middlewareCalled = true return doInner(ctx) }, } - bundle.config.JobMiddleware = []rivertype.JobMiddleware{middleware} + bundle.config.WorkerMiddleware = []rivertype.WorkerMiddleware{middleware} AddWorker(bundle.config.Workers, WorkFunc(func(ctx context.Context, job *Job[callbackArgs]) error { - require.Equal(t, "called", (ctx.Value("middleware").(string))) + require.Equal(t, "called", ctx.Value(privateKey("middleware"))) return nil })) @@ -627,6 +628,52 @@ func Test_Client(t *testing.T) { require.True(t, middlewareCalled) }) + t.Run("WithWorkerMiddlewareOnWorker", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + middlewareCalled := false + + type privateKey string + + worker := &workerWithMiddleware[callbackArgs]{ + workFunc: func(ctx context.Context, job *Job[callbackArgs]) error { + require.Equal(t, "called", ctx.Value(privateKey("middleware"))) + return nil + }, + middlewareFunc: func(job *Job[callbackArgs]) []rivertype.WorkerMiddleware { + require.Equal(t, "middleware_test", job.Args.Name, "JSON should be decoded before middleware is called") + + return []rivertype.WorkerMiddleware{ + &overridableJobMiddleware{ + workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { + ctx = context.WithValue(ctx, privateKey("middleware"), "called") + middlewareCalled = true + return doInner(ctx) + }, + }, + } + }, + } + + AddWorker(bundle.config.Workers, worker) + + driver := riverpgxv5.New(bundle.dbPool) + client, err := NewClient(driver, bundle.config) + require.NoError(t, err) + + subscribeChan := subscribe(t, client) + startClient(ctx, t, client) + + result, err := client.Insert(ctx, callbackArgs{Name: "middleware_test"}, nil) + require.NoError(t, err) + + event := riversharedtest.WaitOrTimeout(t, subscribeChan) + require.Equal(t, EventKindJobCompleted, event.Kind) + require.Equal(t, result.Job.ID, event.Job.ID) + require.True(t, middlewareCalled) + }) + t.Run("PauseAndResumeSingleQueue", func(t *testing.T) { t.Parallel() @@ -873,6 +920,20 @@ func Test_Client(t *testing.T) { }) } +type workerWithMiddleware[T JobArgs] struct { + WorkerDefaults[T] + workFunc func(context.Context, *Job[T]) error + middlewareFunc func(*Job[T]) []rivertype.WorkerMiddleware +} + +func (w *workerWithMiddleware[T]) Work(ctx context.Context, job *Job[T]) error { + return w.workFunc(ctx, job) +} + +func (w *workerWithMiddleware[T]) Middleware(job *Job[T]) []rivertype.WorkerMiddleware { + return w.middlewareFunc(job) +} + func Test_Client_Stop(t *testing.T) { t.Parallel() @@ -2458,7 +2519,7 @@ func Test_Client_InsertManyTx(t *testing.T) { require.Len(t, results, 1) }) - t.Run("WithJobMiddleware", func(t *testing.T) { + t.Run("WithJobInsertMiddleware", func(t *testing.T) { t.Parallel() _, bundle := setup(t) @@ -2484,7 +2545,7 @@ func Test_Client_InsertManyTx(t *testing.T) { }, } - config.JobMiddleware = []rivertype.JobMiddleware{middleware} + config.JobInsertMiddleware = []rivertype.JobInsertMiddleware{middleware} driver := riverpgxv5.New(nil) client, err := NewClient(driver, config) require.NoError(t, err) @@ -4850,8 +4911,12 @@ func Test_NewClient_Overrides(t *testing.T) { retryPolicy := &DefaultClientRetryPolicy{} - type noOpMiddleware struct { - JobMiddlewareDefaults + type noOpInsertMiddleware struct { + JobInsertMiddlewareDefaults + } + + type noOpWorkerMiddleware struct { + WorkerMiddlewareDefaults } client, err := NewClient(riverpgxv5.New(dbPool), &Config{ @@ -4862,7 +4927,7 @@ func Test_NewClient_Overrides(t *testing.T) { ErrorHandler: errorHandler, FetchCooldown: 123 * time.Millisecond, FetchPollInterval: 124 * time.Millisecond, - JobMiddleware: []rivertype.JobMiddleware{&noOpMiddleware{}}, + JobInsertMiddleware: []rivertype.JobInsertMiddleware{&noOpInsertMiddleware{}}, JobTimeout: 125 * time.Millisecond, Logger: logger, MaxAttempts: 5, @@ -4870,6 +4935,7 @@ func Test_NewClient_Overrides(t *testing.T) { RetryPolicy: retryPolicy, TestOnly: true, // disables staggered start in maintenance services Workers: workers, + WorkerMiddleware: []rivertype.WorkerMiddleware{&noOpWorkerMiddleware{}}, }) require.NoError(t, err) @@ -4888,11 +4954,12 @@ func Test_NewClient_Overrides(t *testing.T) { require.Equal(t, errorHandler, client.config.ErrorHandler) require.Equal(t, 123*time.Millisecond, client.config.FetchCooldown) require.Equal(t, 124*time.Millisecond, client.config.FetchPollInterval) - require.Len(t, client.config.JobMiddleware, 1) + require.Len(t, client.config.JobInsertMiddleware, 1) require.Equal(t, 125*time.Millisecond, client.config.JobTimeout) require.Equal(t, logger, client.baseService.Logger) require.Equal(t, 5, client.config.MaxAttempts) require.Equal(t, retryPolicy, client.config.RetryPolicy) + require.Len(t, client.config.WorkerMiddleware, 1) } func Test_NewClient_MissingParameters(t *testing.T) { diff --git a/internal/maintenance/job_rescuer_test.go b/internal/maintenance/job_rescuer_test.go index ec938875..f23941af 100644 --- a/internal/maintenance/job_rescuer_test.go +++ b/internal/maintenance/job_rescuer_test.go @@ -38,10 +38,11 @@ type callbackWorkUnit struct { timeout time.Duration // defaults to 0, which signals default timeout } -func (w *callbackWorkUnit) NextRetry() time.Time { return time.Now().Add(30 * time.Second) } -func (w *callbackWorkUnit) Timeout() time.Duration { return w.timeout } -func (w *callbackWorkUnit) Work(ctx context.Context) error { return w.callback(ctx, w.jobRow) } -func (w *callbackWorkUnit) UnmarshalJob() error { return nil } +func (w *callbackWorkUnit) Middleware() []rivertype.WorkerMiddleware { return nil } +func (w *callbackWorkUnit) NextRetry() time.Time { return time.Now().Add(30 * time.Second) } +func (w *callbackWorkUnit) Timeout() time.Duration { return w.timeout } +func (w *callbackWorkUnit) Work(ctx context.Context) error { return w.callback(ctx, w.jobRow) } +func (w *callbackWorkUnit) UnmarshalJob() error { return nil } type SimpleClientRetryPolicy struct{} diff --git a/internal/workunit/work_unit.go b/internal/workunit/work_unit.go index 746454f2..fc1f541d 100644 --- a/internal/workunit/work_unit.go +++ b/internal/workunit/work_unit.go @@ -15,6 +15,7 @@ import ( // // Implemented by river.wrapperWorkUnit. type WorkUnit interface { + Middleware() []rivertype.WorkerMiddleware NextRetry() time.Time Timeout() time.Duration UnmarshalJob() error diff --git a/job_executor.go b/job_executor.go index 0d13f973..ce730edc 100644 --- a/job_executor.go +++ b/job_executor.go @@ -131,7 +131,7 @@ type jobExecutor struct { ErrorHandler ErrorHandler InformProducerDoneFunc func(jobRow *rivertype.JobRow) JobRow *rivertype.JobRow - JobMiddleware []rivertype.JobMiddleware + GlobalMiddleware []rivertype.WorkerMiddleware SchedulerInterval time.Duration WorkUnit workunit.WorkUnit @@ -194,11 +194,13 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) { return &jobExecutorResult{Err: &UnknownJobKindError{Kind: e.JobRow.Kind}} } - doInner := func(ctx context.Context) error { - if err := e.WorkUnit.UnmarshalJob(); err != nil { - return err - } + if err := e.WorkUnit.UnmarshalJob(); err != nil { + return &jobExecutorResult{Err: err} + } + workerMiddleware := e.WorkUnit.Middleware() + + doInner := func(ctx context.Context) error { jobTimeout := e.WorkUnit.Timeout() if jobTimeout == 0 { jobTimeout = e.ClientJobTimeout @@ -218,12 +220,16 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) { return nil } - if len(e.JobMiddleware) > 0 { + allMiddleware := make([]rivertype.WorkerMiddleware, 0, len(e.GlobalMiddleware)+len(workerMiddleware)) + allMiddleware = append(allMiddleware, e.GlobalMiddleware...) + allMiddleware = append(allMiddleware, workerMiddleware...) + + if len(allMiddleware) > 0 { // Wrap middlewares in reverse order so the one defined first is wrapped // as the outermost function and is first to receive the operation. - for i := len(e.JobMiddleware) - 1; i >= 0; i-- { - middlewareItem := e.JobMiddleware[i] // capture the current middleware item - previousDoInner := doInner // Capture the current doInner function + for i := len(allMiddleware) - 1; i >= 0; i-- { + middlewareItem := allMiddleware[i] // capture the current middleware item + previousDoInner := doInner // Capture the current doInner function doInner = func(ctx context.Context) error { return middlewareItem.Work(ctx, e.JobRow, previousDoInner) } diff --git a/middleware_defaults.go b/middleware_defaults.go index c0a60a54..3a9195cb 100644 --- a/middleware_defaults.go +++ b/middleware_defaults.go @@ -6,16 +6,18 @@ import ( "github.com/riverqueue/river/rivertype" ) -// JobMiddlewareDefaults is an embeddable struct that provides default -// implementations for the rivertype.JobMiddleware. Use of this struct is -// recommended in case rivertype.JobMiddleware is expanded in the future so that +// JobInsertMiddlewareDefaults is an embeddable struct that provides default +// implementations for the rivertype.JobInsertMiddleware. Use of this struct is +// recommended in case rivertype.JobInsertMiddleware is expanded in the future so that // existing code isn't unexpectedly broken during an upgrade. -type JobMiddlewareDefaults struct{} +type JobInsertMiddlewareDefaults struct{} -func (l *JobMiddlewareDefaults) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { +func (d *JobInsertMiddlewareDefaults) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { return doInner(ctx) } -func (l *JobMiddlewareDefaults) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { +type WorkerMiddlewareDefaults struct{} + +func (d *WorkerMiddlewareDefaults) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { return doInner(ctx) } diff --git a/middleware_defaults_test.go b/middleware_defaults_test.go index 323e0415..f1520dac 100644 --- a/middleware_defaults_test.go +++ b/middleware_defaults_test.go @@ -2,4 +2,7 @@ package river import "github.com/riverqueue/river/rivertype" -var _ rivertype.JobMiddleware = &JobMiddlewareDefaults{} +var ( + _ rivertype.JobInsertMiddleware = &JobInsertMiddlewareDefaults{} + _ rivertype.WorkerMiddleware = &WorkerMiddlewareDefaults{} +) diff --git a/middleware_test.go b/middleware_test.go index 955c08d2..4e9b17dc 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -7,7 +7,8 @@ import ( ) type overridableJobMiddleware struct { - JobMiddlewareDefaults + JobInsertMiddlewareDefaults + WorkerMiddlewareDefaults insertManyFunc func(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) workFunc func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error @@ -17,12 +18,12 @@ func (m *overridableJobMiddleware) InsertMany(ctx context.Context, manyParams [] if m.insertManyFunc != nil { return m.insertManyFunc(ctx, manyParams, doInner) } - return m.JobMiddlewareDefaults.InsertMany(ctx, manyParams, doInner) + return m.JobInsertMiddlewareDefaults.InsertMany(ctx, manyParams, doInner) } func (m *overridableJobMiddleware) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { if m.workFunc != nil { return m.workFunc(ctx, job, doInner) } - return m.JobMiddlewareDefaults.Work(ctx, job, doInner) + return m.WorkerMiddlewareDefaults.Work(ctx, job, doInner) } diff --git a/producer.go b/producer.go index 8d0d0e6c..6d6392ad 100644 --- a/producer.go +++ b/producer.go @@ -63,9 +63,9 @@ type producerConfig struct { // LISTEN/NOTIFY, but this provides a fallback. FetchPollInterval time.Duration - JobMiddleware []rivertype.JobMiddleware - JobTimeout time.Duration - MaxWorkers int + GlobalMiddleware []rivertype.WorkerMiddleware + JobTimeout time.Duration + MaxWorkers int // Notifier is a notifier for subscribing to new job inserts and job // control. If nil, the producer will operate in poll-only mode. @@ -580,7 +580,7 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. Completer: p.completer, ErrorHandler: p.errorHandler, InformProducerDoneFunc: p.handleWorkerDone, - JobMiddleware: p.config.JobMiddleware, + GlobalMiddleware: p.config.GlobalMiddleware, JobRow: job, SchedulerInterval: p.config.SchedulerInterval, WorkUnit: workUnit, diff --git a/rivertype/river_type.go b/rivertype/river_type.go index 8d545ebb..256531ab 100644 --- a/rivertype/river_type.go +++ b/rivertype/river_type.go @@ -246,29 +246,31 @@ type JobInsertParams struct { UniqueStates byte } -// JobMiddleware provides an interface for middleware that integrations can use -// to encapsulate common logic around various phases of a job's lifecycle. +// JobInsertMiddleware provides an interface for middleware that integrations can +// use to encapsulate common logic around job insertion. // // Implementations should embed river.JobMiddlewareDefaults to inherit default // implementations for phases where no custom code is needed, and for forward // compatibility in case new functions are added to this interface. -type JobMiddleware interface { +type JobInsertMiddleware interface { // InsertMany is invoked around a batch insert operation. Implementations // must always include a call to doInner to call down the middleware stack // and perfom the batch insertion, and may run custom code before and after. // // Returning an error from this function will fail the overarching insert // operation, even if the inner insertion originally succeeded. - InsertMany(ctx context.Context, manyParams []*JobInsertParams, doInner func(ctx context.Context) ([]*JobInsertResult, error)) ([]*JobInsertResult, error) + InsertMany(ctx context.Context, manyParams []*JobInsertParams, doInner func(context.Context) ([]*JobInsertResult, error)) ([]*JobInsertResult, error) +} - // Work is invoked around a job's JSON args being unmarshaled and the job - // worked. Implementations must always include a call to doInner to call - // down the middleware stack and perfom the batch insertion, and may run +type WorkerMiddleware interface { + // Work is invoked after a job's JSON args being unmarshaled and before the + // job is worked. Implementations must always include a call to doInner to + // call down the middleware stack and perfom the batch insertion, and may run // custom code before and after. // // Returning an error from this function will fail the overarching work // operation, even if the inner work originally succeeded. - Work(ctx context.Context, job *JobRow, doInner func(ctx context.Context) error) error + Work(ctx context.Context, job *JobRow, doInner func(context.Context) error) error } // PeriodicJobHandle is a reference to a dynamically added periodic job diff --git a/work_unit_wrapper.go b/work_unit_wrapper.go index 9e741ee7..939426a6 100644 --- a/work_unit_wrapper.go +++ b/work_unit_wrapper.go @@ -29,6 +29,13 @@ func (w *wrapperWorkUnit[T]) NextRetry() time.Time { return w.worker.N func (w *wrapperWorkUnit[T]) Timeout() time.Duration { return w.worker.Timeout(w.job) } func (w *wrapperWorkUnit[T]) Work(ctx context.Context) error { return w.worker.Work(ctx, w.job) } +func (w *wrapperWorkUnit[T]) Middleware() []rivertype.WorkerMiddleware { + if provider, ok := w.worker.(workerMiddlewareProvider[T]); ok { + return provider.Middleware(w.job) + } + return nil +} + func (w *wrapperWorkUnit[T]) UnmarshalJob() error { w.job = &Job[T]{ JobRow: w.jobRow, @@ -36,3 +43,8 @@ func (w *wrapperWorkUnit[T]) UnmarshalJob() error { return json.Unmarshal(w.jobRow.EncodedArgs, &w.job.Args) } + +type workerMiddlewareProvider[T JobArgs] interface { + // Middleware returns the type-specific middleware for this job. + Middleware(job *Job[T]) []rivertype.WorkerMiddleware +}