From 4091846e9a0e3a5125337fd51a9c4c207f16a42e Mon Sep 17 00:00:00 2001 From: Brandur Date: Fri, 15 Mar 2024 19:23:12 -0700 Subject: [PATCH] Make client a start/stop service Here, continue service refactoring to its ultimate conclusion, and make the client itself a start/stop service, having a couple advantages: * (IMO) considerably simplifies the start and stop code, putting it all in one place, and largely even one function. Fewer helpers to follow around to understand what's going on. * Makes the client behave gracefully on double starts/stops. * Allows the client to easily handle start/stop under duress. Any number of goroutines can be starting or stopping it simultaneously and it's guaranteed free from races. Because the client's stop functions are different from the signature of other start/stop services (taking a context and returning an error), I had to modify the base start/stop service somewhat and allow for an additional `StopInit` helper (like `StartInit`) that lets stop behavior be customized while still providing race-free operations. --- client.go | 289 +++++++++--------- client_test.go | 36 ++- internal/maintenance/startstop/start_stop.go | 73 ++++- .../maintenance/startstop/start_stop_test.go | 122 ++++++++ .../startstoptest/startstoptest.go | 26 +- .../startstoptest/startstoptest_test.go | 40 ++- 6 files changed, 423 insertions(+), 163 deletions(-) diff --git a/client.go b/client.go index b1a2d16c..fbe8f123 100644 --- a/client.go +++ b/client.go @@ -23,6 +23,7 @@ import ( "github.com/riverqueue/river/internal/maintenance/startstop" "github.com/riverqueue/river/internal/notifier" "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/internal/util/maputil" "github.com/riverqueue/river/internal/util/randutil" "github.com/riverqueue/river/internal/util/sliceutil" "github.com/riverqueue/river/internal/util/valutil" @@ -262,20 +263,16 @@ type QueueConfig struct { // multiple instances operating on different databases or Postgres schemas // within a single database. type Client[TTx any] struct { - // BaseService can't be embedded like on other services because its - // properties would leak to the external API. - baseService baseservice.BaseService + // BaseService and BaseStartStop can't be embedded like on other services + // because their properties would leak to the external API. + baseService baseservice.BaseService + baseStartStop startstop.BaseStartStop completer jobcompleter.JobCompleter config *Config driver riverdriver.Driver[TTx] elector *leadership.Elector - // fetchWorkCancel cancels the context used for fetching new work. This - // will be used to stop fetching new work whenever stop is initiated, or - // when the context provided to Run is itself cancelled. - fetchWorkCancel context.CancelCauseFunc - monitor *clientMonitor notifier *notifier.Notifier producersByQueueName map[string]*producer @@ -284,10 +281,10 @@ type Client[TTx any] struct { subscriptions map[int]*eventSubscription subscriptionsMu sync.Mutex subscriptionsSeq int // used for generating simple IDs - stopComplete chan struct{} statsAggregate jobstats.JobStatistics statsMu sync.Mutex statsNumJobs int + stopped chan struct{} testSignals clientTestSignals uniqueInserter *dbunique.UniqueInserter @@ -433,7 +430,6 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client driver: driver, monitor: newClientMonitor(), producersByQueueName: make(map[string]*producer), - stopComplete: make(chan struct{}), subscriptions: make(map[int]*eventSubscription), testSignals: clientTestSignals{}, uniqueInserter: baseservice.Init(archetype, &dbunique.UniqueInserter{ @@ -575,146 +571,154 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client // jobs, but will also cancel the context for any currently-running jobs. If // using StopAndCancel, there's no need to also call Stop. func (c *Client[TTx]) Start(ctx context.Context) error { - if !c.config.willExecuteJobs() { - return errors.New("client Queues and Workers must be configured for a client to start working") - } - if c.config.Workers != nil && len(c.config.Workers.workersMap) < 1 { - return errors.New("at least one Worker must be added to the Workers bundle") + fetchCtx, shouldStart, stopped := c.baseStartStop.StartInit(ctx) + if !shouldStart { + return nil } - // Before doing anything else, make an initial connection to the database to - // verify that it appears healthy. Many of the subcomponents below start up - // in a goroutine and in case of initial failure, only produce a log line, - // so even in the case of a fundamental failure like the database not being - // available, the client appears to have started even though it's completely - // non-functional. Here we try to make an initial assessment of health and - // return quickly in case of an apparent problem. - _, err := c.driver.GetExecutor().Exec(ctx, "SELECT 1") - if err != nil { - return fmt.Errorf("error making initial connection to database: %w", err) - } + c.stopped = stopped - // In case of error, stop any services that might have started. This - // is safe because even services that were never started will still - // tolerate being stopped. - stopServicesOnError := func() { - startstop.StopAllParallel(c.services) - c.monitor.Stop() + stopProducers := func() { + startstop.StopAllParallel(sliceutil.Map( + maputil.Values(c.producersByQueueName), + func(p *producer) startstop.Service { return p }), + ) } - // Monitor should be the first subprocess to start, and the last to stop. - // It's not part of the waitgroup because we need to wait for everything else - // to shut down prior to closing the monitor. - // - // Unlike other services, it's given a background context so that it doesn't - // cancel on normal stops. - if err := c.monitor.Start(context.Background()); err != nil { //nolint:contextcheck - return err - } + var workCtx context.Context - if c.completer != nil { - // The completer is part of the services list below, but although it can - // stop gracefully along with all the other services, it needs to be - // started with a context that's _not_ fetchCtx. This ensures that even - // when fetch is cancelled on shutdown, the completer is still given a - // separate opportunity to start stopping only after the producers have - // finished up and returned. - if err := c.completer.Start(ctx); err != nil { - stopServicesOnError() - return err + // Startup code. Wrapped in a closure so it doesn't have to remember to + // close the stopped channel if returning with an error. + if err := func() error { + if !c.config.willExecuteJobs() { + return errors.New("client Queues and Workers must be configured for a client to start working") + } + if c.config.Workers != nil && len(c.config.Workers.workersMap) < 1 { + return errors.New("at least one Worker must be added to the Workers bundle") } - // Receives job complete notifications from the completer and - // distributes them to any subscriptions. - c.completer.Subscribe(c.distributeJobCompleterCallback) - } - - // We use separate contexts for fetching and working to allow for a graceful - // stop. However, both inherit from the provided context so if it is - // cancelled a more aggressive stop will be initiated. - fetchCtx, fetchWorkCancel := context.WithCancelCause(ctx) - c.fetchWorkCancel = fetchWorkCancel - workCtx, workCancel := context.WithCancelCause(withClient[TTx](ctx, c)) - c.workCancel = workCancel - - for _, service := range c.services { - // TODO(brandur): Reevaluate the use of fetchCtx here. It's currently - // necessary to speed up shutdown so that all services start shutting - // down before having to wait for the producers to finish, but as - // stopping becomes more normalized (hopefully by making the client - // itself a start/stop service), we can likely accomplish that in a - // cleaner way. - if err := service.Start(fetchCtx); err != nil { - stopServicesOnError() - if errors.Is(context.Cause(ctx), rivercommon.ErrShutdown) { - return nil - } - return err + // Before doing anything else, make an initial connection to the database to + // verify that it appears healthy. Many of the subcomponents below start up + // in a goroutine and in case of initial failure, only produce a log line, + // so even in the case of a fundamental failure like the database not being + // available, the client appears to have started even though it's completely + // non-functional. Here we try to make an initial assessment of health and + // return quickly in case of an apparent problem. + _, err := c.driver.GetExecutor().Exec(fetchCtx, "SELECT 1") + if err != nil { + return fmt.Errorf("error making initial connection to database: %w", err) } - } - for _, producer := range c.producersByQueueName { - producer := producer + // In case of error, stop any services that might have started. This + // is safe because even services that were never started will still + // tolerate being stopped. + stopServicesOnError := func() { + startstop.StopAllParallel(c.services) + c.monitor.Stop() + } - if err := producer.StartWorkContext(fetchCtx, workCtx); err != nil { + // Monitor should be the first subprocess to start, and the last to stop. + // It's not part of the waitgroup because we need to wait for everything else + // to shut down prior to closing the monitor. + // + // Unlike other services, it's given a background context so that it doesn't + // cancel on normal stops. + if err := c.monitor.Start(context.Background()); err != nil { //nolint:contextcheck return err } - } - go func() { - <-fetchCtx.Done() - c.signalStopComplete(ctx) - }() + if c.completer != nil { + // The completer is part of the services list below, but although it can + // stop gracefully along with all the other services, it needs to be + // started with a context that's _not_ fetchCtx. This ensures that even + // when fetch is cancelled on shutdown, the completer is still given a + // separate opportunity to start stopping only after the producers have + // finished up and returned. + if err := c.completer.Start(ctx); err != nil { + stopServicesOnError() + return err + } - c.baseService.Logger.InfoContext(workCtx, "River client successfully started", slog.String("client_id", c.ID())) + // Receives job complete notifications from the completer and + // distributes them to any subscriptions. + c.completer.Subscribe(c.distributeJobCompleterCallback) + } - return nil -} + // We use separate contexts for fetching and working to allow for a graceful + // stop. Both inherit from the provided context, so if it's cancelled, a + // more aggressive stop will be initiated. + workCtx, c.workCancel = context.WithCancelCause(withClient[TTx](ctx, c)) -// ctx is used only for logging, not for lifecycle. -func (c *Client[TTx]) signalStopComplete(ctx context.Context) { - for _, producer := range c.producersByQueueName { - producer.Stop() - } + for _, service := range c.services { + if err := service.Start(fetchCtx); err != nil { + stopServicesOnError() + return err + } + } - // Stop all mainline services where stop order isn't important. - startstop.StopAllParallel(append( - // This list of services contains the completer, which should always - // stop after the producers so that any remaining work that was enqueued - // will have a chance to have its state completed as it finishes. - // - // TODO: there's a risk here that the completer is stuck on a job that - // won't complete. We probably need a timeout or way to move on in those - // cases. - c.services, + for _, producer := range c.producersByQueueName { + producer := producer - // Will only be started if this client was leader, but can tolerate a stop - // without having been started. - c.queueMaintainer, - )) + if err := producer.StartWorkContext(fetchCtx, workCtx); err != nil { + stopProducers() + stopServicesOnError() + return err + } + } - c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": All services stopped") + return nil + }(); err != nil { + defer close(stopped) + if errors.Is(context.Cause(fetchCtx), startstop.ErrStop) { + return rivercommon.ErrShutdown + } + return err + } - // As of now, the Adapter doesn't have any async behavior, so we don't need - // to wait for it to stop. Once all executors and completers are done, we - // know that nothing else is happening that's from us. + go func() { + defer close(stopped) - // Remove all subscriptions and close corresponding channels. - func() { - c.subscriptionsMu.Lock() - defer c.subscriptionsMu.Unlock() + c.baseService.Logger.InfoContext(ctx, "River client started", slog.String("client_id", c.ID())) + defer c.baseService.Logger.InfoContext(ctx, "River client stopped", slog.String("client_id", c.ID())) - for subID, sub := range c.subscriptions { - close(sub.Chan) - delete(c.subscriptions, subID) - } - }() + // The call to Stop cancels this context. Block here until shutdown. + <-fetchCtx.Done() + + // On stop, have the producers stop fetching first of all. + stopProducers() + + // Stop all mainline services where stop order isn't important. + startstop.StopAllParallel(append( + // This list of services contains the completer, which should always + // stop after the producers so that any remaining work that was enqueued + // will have a chance to have its state completed as it finishes. + // + // TODO: there's a risk here that the completer is stuck on a job that + // won't complete. We probably need a timeout or way to move on in those + // cases. + c.services, + + // Will only be started if this client was leader, but can tolerate a + // stop without having been started. + c.queueMaintainer, + )) + + // Remove all subscriptions and close corresponding channels. + func() { + c.subscriptionsMu.Lock() + defer c.subscriptionsMu.Unlock() + + for subID, sub := range c.subscriptions { + close(sub.Chan) + delete(c.subscriptions, subID) + } + }() - // Shut down the monitor last so it can broadcast final status updates: - c.monitor.Stop() + // Shut down the monitor last so it can broadcast final status updates: + c.monitor.Stop() + }() - c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Stop complete") - close(c.stopComplete) + return nil } // Stop performs a graceful shutdown of the Client. It signals all producers @@ -725,20 +729,17 @@ func (c *Client[TTx]) signalStopComplete(ctx context.Context) { // There's no need to call this method if a hard stop has already been initiated // by cancelling the context passed to Start or by calling StopAndCancel. func (c *Client[TTx]) Stop(ctx context.Context) error { - if c.fetchWorkCancel == nil { - return errors.New("client not started") + shouldStop, stopped, finalizeStop := c.baseStartStop.StopInit() + if !shouldStop { + return nil } - c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Stop started") - c.fetchWorkCancel(rivercommon.ErrShutdown) - return c.awaitStop(ctx) -} - -func (c *Client[TTx]) awaitStop(ctx context.Context) error { select { - case <-ctx.Done(): + case <-ctx.Done(): // stop context cancelled + finalizeStop(false) // not stopped; allow Stop to be called again return ctx.Err() - case <-c.stopComplete: + case <-stopped: + finalizeStop(true) return nil } } @@ -754,10 +755,22 @@ func (c *Client[TTx]) awaitStop(ctx context.Context) error { // no need to call this method if the context passed to Run is cancelled // instead. func (c *Client[TTx]) StopAndCancel(ctx context.Context) error { + shouldStop, stopped, finalizeStop := c.baseStartStop.StopInit() + if !shouldStop { + return nil + } + c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Hard stop started; cancelling all work") - c.fetchWorkCancel(rivercommon.ErrShutdown) c.workCancel(rivercommon.ErrShutdown) - return c.awaitStop(ctx) + + select { + case <-ctx.Done(): // stop context cancelled + finalizeStop(false) // not stopped; allow Stop to be called again + return ctx.Err() + case <-stopped: + finalizeStop(true) + return nil + } } // Stopped returns a channel that will be closed when the Client has stopped. @@ -765,7 +778,7 @@ func (c *Client[TTx]) StopAndCancel(ctx context.Context) error { // // It is not affected by any contexts passed to Stop or StopAndCancel. func (c *Client[TTx]) Stopped() <-chan struct{} { - return c.stopComplete + return c.stopped } // Subscribe subscribes to the provided kinds of events that occur within the diff --git a/client_test.go b/client_test.go index d2f12654..2a6808b1 100644 --- a/client_test.go +++ b/client_test.go @@ -24,6 +24,7 @@ import ( "github.com/riverqueue/river/internal/maintenance" "github.com/riverqueue/river/internal/rivercommon" "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/internal/riverinternaltest/startstoptest" "github.com/riverqueue/river/internal/riverinternaltest/testfactory" "github.com/riverqueue/river/internal/util/dbutil" "github.com/riverqueue/river/internal/util/ptrutil" @@ -116,6 +117,16 @@ func (w *callbackWorker) Work(ctx context.Context, job *Job[callbackArgs]) error return w.fn(ctx, job) } +// A small wrapper around Client that gives us a struct that corrects the +// client's Stop function so that it can implement startstop.Service. +type clientWithSimpleStop[TTx any] struct { + *Client[TTx] +} + +func (c *clientWithSimpleStop[TTx]) Stop() { + _ = c.Client.Stop(context.Background()) +} + func newTestConfig(t *testing.T, callback callbackFunc) *Config { t.Helper() workers := NewWorkers() @@ -433,6 +444,16 @@ func Test_Client(t *testing.T) { require.Equal(t, `relation "river_job" does not exist`, pgErr.Message) }) + t.Run("StartStopStress", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + + clientWithStop := &clientWithSimpleStop[pgx.Tx]{Client: client} + + startstoptest.StressErr(ctx, t, clientWithStop, rivercommon.ErrShutdown) + }) + t.Run("StopAndCancel", func(t *testing.T) { t.Parallel() @@ -515,17 +536,6 @@ func Test_Client_Stop(t *testing.T) { } } - t.Run("not started", func(t *testing.T) { - t.Parallel() - - dbPool := riverinternaltest.TestDB(ctx, t) - client := newTestClient(t, dbPool, newTestConfig(t, nil)) - - err := client.Stop(ctx) - require.Error(t, err) - require.Equal(t, "client not started", err.Error()) - }) - t.Run("no jobs in progress", func(t *testing.T) { t.Parallel() client := runNewTestClient(ctx, t, newTestConfig(t, nil)) @@ -2564,8 +2574,8 @@ func Test_Client_InsertTriggersImmediateWork(t *testing.T) { ctx := context.Background() require := require.New(t) - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - t.Cleanup(cancel) + // ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + // t.Cleanup(cancel) doneCh := make(chan struct{}) close(doneCh) // don't need to block any jobs from completing diff --git a/internal/maintenance/startstop/start_stop.go b/internal/maintenance/startstop/start_stop.go index 9196ddf1..12365fbf 100644 --- a/internal/maintenance/startstop/start_stop.go +++ b/internal/maintenance/startstop/start_stop.go @@ -2,9 +2,15 @@ package startstop import ( "context" + "errors" "sync" ) +// ErrStop is an error injected into WithCancelCause when context is canceled +// because a service is stopping. Makes it possible to differentiate a +// controlled stop from a context cancellation. +var ErrStop = errors.New("service stopped") + // Service is a generalized interface for a service that starts and stops, // usually one backed by embedding BaseStartStop. type Service interface { @@ -48,7 +54,7 @@ type serviceWithStopped interface { // A Stop implementation is provided automatically and it's not necessary to // override it. type BaseStartStop struct { - cancelFunc context.CancelFunc + cancelFunc context.CancelCauseFunc mu sync.Mutex started bool stopped chan struct{} @@ -69,6 +75,21 @@ type BaseStartStop struct { // defer close(stopped) // // ... +// +// Be careful to also close it in the event of startup errors, otherwise a +// service that failed to start once will never be able to start up. +// +// ctx, shouldStart, stopped := s.StartInit(ctx) +// if !shouldStart { +// return nil +// } +// +// if err := possibleStartUpError(); err != nil { +// close(stopped) +// return err +// } +// +// ... func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, chan struct{}) { s.mu.Lock() defer s.mu.Unlock() @@ -79,7 +100,7 @@ func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, c s.started = true s.stopped = make(chan struct{}) - ctx, s.cancelFunc = context.WithCancel(ctx) + ctx, s.cancelFunc = context.WithCancelCause(ctx) return ctx, true, s.stopped } @@ -87,19 +108,55 @@ func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, c // Stop is an automatically provided implementation for the maintenance Service // interface's Stop. func (s *BaseStartStop) Stop() { + shouldStop, stopped, finalizeStop := s.StopInit() + if !shouldStop { + return + } + + <-stopped + finalizeStop(true) +} + +// StopInit provides a way to build a more customized Stop implementation. It +// should be avoided unless there'a an exception reason not to because Stop +// should be fine in the vast majority of situations. It returns a boolean +// indicating whether the service should do any additional work to stop (false +// is returned if the service was never started), a stopped channel to wait on +// for full stop, and a finalizeStop function that should be deferred in the +// stop function to ensure that locks are cleaned up and the struct is reset +// after stopping. +// +// shouldStop, stopped, finalizeStop := s.StartInit(ctx) +// if !shouldStop { +// return +// } +// +// defer finalizeStop(true) +// +// ... +// +// finalizeStop takes a boolean which indicates where the service should indeed +// be considered stopped. This should usually be true, but callers can pass +// false to cancel the stop action, keeping the service from starting again, and +// potentially allowing the service to try another stop. +func (s *BaseStartStop) StopInit() (bool, <-chan struct{}, func(didStop bool)) { s.mu.Lock() - defer s.mu.Unlock() // Tolerate being told to stop without having been started. if s.stopped == nil { - return + s.mu.Unlock() + return false, nil, func(didStop bool) {} } - s.cancelFunc() + s.cancelFunc(ErrStop) - <-s.stopped - s.started = false - s.stopped = nil + return true, s.stopped, func(didStop bool) { + defer s.mu.Unlock() + if didStop { + s.started = false + s.stopped = nil + } + } } // Stopped returns a channel that can be waited on for the service to be diff --git a/internal/maintenance/startstop/start_stop_test.go b/internal/maintenance/startstop/start_stop_test.go index bd23ebc2..e13f0440 100644 --- a/internal/maintenance/startstop/start_stop_test.go +++ b/internal/maintenance/startstop/start_stop_test.go @@ -49,6 +49,15 @@ func testService(t *testing.T, newService func(t *testing.T) serviceWithStopped) return newService(t), &testBundle{} } + t.Run("StopAndStart", func(t *testing.T) { + t.Parallel() + + service, _ := setup(t) + + require.NoError(t, service.Start(ctx)) + service.Stop() + }) + t.Run("DoubleStop", func(t *testing.T) { t.Parallel() @@ -140,6 +149,119 @@ func TestBaseStartStopFunc(t *testing.T) { testService(t, makeFunc) } +func TestErrStop(t *testing.T) { + t.Parallel() + + var ( + workCtx context.Context + started = make(chan struct{}) + ) + + startStop := StartStopFunc(func(ctx context.Context, shouldStart bool, stopped chan struct{}) error { + if !shouldStart { + return nil + } + + workCtx = ctx + + go func() { + close(started) + defer close(stopped) + <-ctx.Done() + }() + + return nil + }) + + ctx := context.Background() + + require.NoError(t, startStop.Start(ctx)) + <-started + startStop.Stop() + require.ErrorIs(t, context.Cause(workCtx), ErrStop) +} + +// A service with the more unusual case. +type sampleServiceWithStopInit struct { + baseservice.BaseService + BaseStartStop + + didStop bool + + // Some simple state in the service which a started service taints. The + // purpose of this variable is to allow us to detect a data race allowed by + // BaseStartStop. + state bool +} + +func (s *sampleServiceWithStopInit) Start(ctx context.Context) error { + ctx, shouldStart, stopped := s.StartInit(ctx) + if !shouldStart { + return nil + } + + go func() { + defer close(stopped) + s.state = true + <-ctx.Done() + }() + + return nil +} + +func (s *sampleServiceWithStopInit) Stop() { + shouldStop, stopped, finalizeStop := s.StopInit() + if !shouldStop { + return + } + + <-stopped + finalizeStop(s.didStop) +} + +func TestWithStopInit(t *testing.T) { + t.Parallel() + + testService(t, func(t *testing.T) serviceWithStopped { t.Helper(); return &sampleServiceWithStopInit{didStop: true} }) + + ctx := context.Background() + + type testBundle struct{} + + setup := func() (*sampleServiceWithStopInit, *testBundle) { + return &sampleServiceWithStopInit{}, &testBundle{} + } + + t.Run("FinalizeDidStop", func(t *testing.T) { + t.Parallel() + + service, _ := setup() + service.didStop = true // will set stopped + + require.NoError(t, service.Start(ctx)) + + service.Stop() + + require.False(t, service.started) + require.Nil(t, service.stopped) + }) + + t.Run("FinalizeDidNotStop", func(t *testing.T) { + t.Parallel() + + service, _ := setup() + service.didStop = false // will NOT set stopped + + require.NoError(t, service.Start(ctx)) + + service.Stop() + + // service is still started because didStop was set to false + require.True(t, service.started) + require.NotNil(t, service.stopped) + }) +} + func TestStopAllParallel(t *testing.T) { t.Parallel() diff --git a/internal/riverinternaltest/startstoptest/startstoptest.go b/internal/riverinternaltest/startstoptest/startstoptest.go index 6725a89a..0948b9fb 100644 --- a/internal/riverinternaltest/startstoptest/startstoptest.go +++ b/internal/riverinternaltest/startstoptest/startstoptest.go @@ -3,7 +3,6 @@ package startstoptest import ( "context" "sync" - "testing" "time" "github.com/stretchr/testify/require" @@ -14,7 +13,15 @@ import ( // Stress is a test helper that puts stress on a service's start and stop // functions so that we can detect any data races that it might have due to // improper use of BaseStopStart. -func Stress(ctx context.Context, tb testing.TB, svc startstop.Service) { +func Stress(ctx context.Context, tb testingT, svc startstop.Service) { + StressErr(ctx, tb, svc, nil) +} + +// StressErr is the same as Stress except that the given allowedStartErr is +// tolerated on start (either no error or an error that is allowedStartErr is +// allowed). This is useful for services that may want to return an error if +// they're shut down as they're still starting up. +func StressErr(ctx context.Context, tb testingT, svc startstop.Service, allowedStartErr error) { //nolint:varnamelen tb.Helper() var wg sync.WaitGroup @@ -25,7 +32,12 @@ func Stress(ctx context.Context, tb testing.TB, svc startstop.Service) { defer wg.Done() for j := 0; j < 50; j++ { - require.NoError(tb, svc.Start(ctx)) + err := svc.Start(ctx) + if allowedStartErr == nil { + require.NoError(tb, err) + } else if err != nil { + require.ErrorIs(tb, err, allowedStartErr) + } stopped := make(chan struct{}) @@ -45,3 +57,11 @@ func Stress(ctx context.Context, tb testing.TB, svc startstop.Service) { wg.Wait() } + +// Minimal interface for *testing.B/*testing.T that lets us test a failure +// condition for our test helpers above. +type testingT interface { + Errorf(format string, args ...interface{}) + FailNow() + Helper() +} diff --git a/internal/riverinternaltest/startstoptest/startstoptest_test.go b/internal/riverinternaltest/startstoptest/startstoptest_test.go index dc43c6f1..f7efd156 100644 --- a/internal/riverinternaltest/startstoptest/startstoptest_test.go +++ b/internal/riverinternaltest/startstoptest/startstoptest_test.go @@ -2,16 +2,21 @@ package startstoptest import ( "context" + "errors" "log/slog" + "sync/atomic" "testing" + "github.com/stretchr/testify/require" + "github.com/riverqueue/river/internal/maintenance/startstop" "github.com/riverqueue/river/internal/riverinternaltest" ) type MyService struct { startstop.BaseStartStop - logger *slog.Logger + logger *slog.Logger + startErr error } func (s *MyService) Start(ctx context.Context) error { @@ -20,6 +25,11 @@ func (s *MyService) Start(ctx context.Context) error { return nil } + if s.startErr != nil { + close(stopped) + return s.startErr + } + go func() { defer close(stopped) @@ -39,3 +49,31 @@ func TestStress(t *testing.T) { Stress(ctx, t, &MyService{logger: riverinternaltest.Logger(t)}) } + +func TestStressErr(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + startErr := errors.New("error returned on start") + + StressErr(ctx, t, &MyService{logger: riverinternaltest.Logger(t), startErr: startErr}, startErr) + + mockT := newMockTestingT(t) + StressErr(ctx, mockT, &MyService{logger: riverinternaltest.Logger(t), startErr: errors.New("different error")}, startErr) + require.True(t, mockT.failed.Load()) +} + +type mockTestingT struct { + failed atomic.Bool + tb testing.TB +} + +func newMockTestingT(tb testing.TB) *mockTestingT { + tb.Helper() + return &mockTestingT{tb: tb} +} + +func (t *mockTestingT) Errorf(format string, args ...interface{}) {} +func (t *mockTestingT) FailNow() { t.failed.Store(true) } +func (t *mockTestingT) Helper() { t.tb.Helper() }