diff --git a/CHANGELOG.md b/CHANGELOG.md index d2e767ee..5336e06a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- Fix `StopAndCancel` to not hang if called in parallel to an ongoing `Stop` call. [PR #376](https://github.com/riverqueue/river/pull/376). + ## [0.6.1] - 2024-05-21 ### Fixed diff --git a/client.go b/client.go index e0749bb5..698dce53 100644 --- a/client.go +++ b/client.go @@ -798,14 +798,14 @@ func (c *Client[TTx]) Stop(ctx context.Context) error { // no need to call this method if the context passed to Run is cancelled // instead. func (c *Client[TTx]) StopAndCancel(ctx context.Context) error { + c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Hard stop started; cancelling all work") + c.workCancel(rivercommon.ErrShutdown) + shouldStop, stopped, finalizeStop := c.baseStartStop.StopInit() if !shouldStop { return nil } - c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Hard stop started; cancelling all work") - c.workCancel(rivercommon.ErrShutdown) - select { case <-ctx.Done(): // stop context cancelled finalizeStop(false) // not stopped; allow Stop to be called again diff --git a/client_test.go b/client_test.go index ce4e47c7..876e10cb 100644 --- a/client_test.go +++ b/client_test.go @@ -644,39 +644,83 @@ func Test_Client(t *testing.T) { t.Run("StopAndCancel", func(t *testing.T) { t.Parallel() - client, _ := setup(t) - jobStartedChan := make(chan int64) - jobDoneChan := make(chan struct{}) - - type JobArgs struct { - JobArgsReflectKind[JobArgs] + type testBundle struct { + jobDoneChan chan struct{} + jobStartedChan chan int64 } - AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { - jobStartedChan <- job.ID - <-ctx.Done() - require.ErrorIs(t, context.Cause(ctx), rivercommon.ErrShutdown) - close(jobDoneChan) - return nil - })) + setupStopAndCancel := func(t *testing.T) (*Client[pgx.Tx], *testBundle) { + t.Helper() - startClient(ctx, t, client) + client, _ := setup(t) + jobStartedChan := make(chan int64) + jobDoneChan := make(chan struct{}) - insertRes, err := client.Insert(ctx, &JobArgs{}, nil) - require.NoError(t, err) + type JobArgs struct { + JobArgsReflectKind[JobArgs] + } - startedJobID := riverinternaltest.WaitOrTimeout(t, jobStartedChan) - require.Equal(t, insertRes.Job.ID, startedJobID) + AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + jobStartedChan <- job.ID + <-ctx.Done() + require.ErrorIs(t, context.Cause(ctx), rivercommon.ErrShutdown) + close(jobDoneChan) + return nil + })) - select { - case <-client.Stopped(): - t.Fatal("expected client to not be stopped yet") - default: + startClient(ctx, t, client) + + insertRes, err := client.Insert(ctx, &JobArgs{}, nil) + require.NoError(t, err) + + startedJobID := riverinternaltest.WaitOrTimeout(t, jobStartedChan) + require.Equal(t, insertRes.Job.ID, startedJobID) + + select { + case <-client.Stopped(): + t.Fatal("expected client to not be stopped yet") + default: + } + + return client, &testBundle{ + jobDoneChan: jobDoneChan, + jobStartedChan: jobStartedChan, + } } - require.NoError(t, client.StopAndCancel(ctx)) + t.Run("OnItsOwn", func(t *testing.T) { + t.Parallel() + + client, _ := setupStopAndCancel(t) + + require.NoError(t, client.StopAndCancel(ctx)) + riverinternaltest.WaitOrTimeout(t, client.Stopped()) + }) + + t.Run("AfterStop", func(t *testing.T) { + t.Parallel() + + client, bundle := setupStopAndCancel(t) + + go func() { + require.NoError(t, client.Stop(ctx)) + }() + + select { + case <-client.Stopped(): + t.Fatal("expected client to not be stopped yet") + case <-time.After(500 * time.Millisecond): + } + + require.NoError(t, client.StopAndCancel(ctx)) + riverinternaltest.WaitOrTimeout(t, client.Stopped()) - riverinternaltest.WaitOrTimeout(t, client.Stopped()) + select { + case <-bundle.jobDoneChan: + default: + t.Fatal("expected job to be have exited") + } + }) }) } diff --git a/job_executor.go b/job_executor.go index 740d8add..33747bf6 100644 --- a/job_executor.go +++ b/job_executor.go @@ -247,7 +247,7 @@ func (e *jobExecutor) reportResult(ctx context.Context, res *jobExecutorResult) if res.Err != nil && errors.As(res.Err, &snoozeErr) { e.Logger.InfoContext(ctx, e.Name+": Job snoozed", slog.Int64("job_id", e.JobRow.ID), - slog.String("job_kind", e.JobRow.Kind), + slog.String("job_kind", e.JobRow.Kind), slog.Duration("duration", snoozeErr.duration), ) nextAttemptScheduledAt := time.Now().Add(snoozeErr.duration)