diff --git a/client.go b/client.go
index b1a2d16c..fbe8f123 100644
--- a/client.go
+++ b/client.go
@@ -23,6 +23,7 @@ import (
 	"github.com/riverqueue/river/internal/maintenance/startstop"
 	"github.com/riverqueue/river/internal/notifier"
 	"github.com/riverqueue/river/internal/rivercommon"
+	"github.com/riverqueue/river/internal/util/maputil"
 	"github.com/riverqueue/river/internal/util/randutil"
 	"github.com/riverqueue/river/internal/util/sliceutil"
 	"github.com/riverqueue/river/internal/util/valutil"
@@ -262,20 +263,16 @@ type QueueConfig struct {
 // multiple instances operating on different databases or Postgres schemas
 // within a single database.
 type Client[TTx any] struct {
-	// BaseService can't be embedded like on other services because its
-	// properties would leak to the external API.
-	baseService baseservice.BaseService
+	// BaseService and BaseStartStop can't be embedded like on other services
+	// because their properties would leak to the external API.
+	baseService   baseservice.BaseService
+	baseStartStop startstop.BaseStartStop
 
 	completer jobcompleter.JobCompleter
 	config    *Config
 	driver    riverdriver.Driver[TTx]
 	elector   *leadership.Elector
 
-	// fetchWorkCancel cancels the context used for fetching new work. This
-	// will be used to stop fetching new work whenever stop is initiated, or
-	// when the context provided to Run is itself cancelled.
-	fetchWorkCancel context.CancelCauseFunc
-
 	monitor              *clientMonitor
 	notifier             *notifier.Notifier
 	producersByQueueName map[string]*producer
@@ -284,10 +281,10 @@ type Client[TTx any] struct {
 	subscriptions        map[int]*eventSubscription
 	subscriptionsMu      sync.Mutex
 	subscriptionsSeq     int // used for generating simple IDs
-	stopComplete         chan struct{}
 	statsAggregate       jobstats.JobStatistics
 	statsMu              sync.Mutex
 	statsNumJobs         int
+	stopped              chan struct{}
 	testSignals          clientTestSignals
 	uniqueInserter       *dbunique.UniqueInserter
@@ -433,7 +430,6 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
 		driver:               driver,
 		monitor:              newClientMonitor(),
 		producersByQueueName: make(map[string]*producer),
-		stopComplete:         make(chan struct{}),
 		subscriptions:        make(map[int]*eventSubscription),
 		testSignals:          clientTestSignals{},
 		uniqueInserter: baseservice.Init(archetype, &dbunique.UniqueInserter{
@@ -575,146 +571,154 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
 // jobs, but will also cancel the context for any currently-running jobs. If
 // using StopAndCancel, there's no need to also call Stop.
 func (c *Client[TTx]) Start(ctx context.Context) error {
-	if !c.config.willExecuteJobs() {
-		return errors.New("client Queues and Workers must be configured for a client to start working")
-	}
-	if c.config.Workers != nil && len(c.config.Workers.workersMap) < 1 {
-		return errors.New("at least one Worker must be added to the Workers bundle")
+	fetchCtx, shouldStart, stopped := c.baseStartStop.StartInit(ctx)
+	if !shouldStart {
+		return nil
 	}
 
-	// Before doing anything else, make an initial connection to the database to
-	// verify that it appears healthy. Many of the subcomponents below start up
-	// in a goroutine and in case of initial failure, only produce a log line,
-	// so even in the case of a fundamental failure like the database not being
-	// available, the client appears to have started even though it's completely
-	// non-functional. Here we try to make an initial assessment of health and
-	// return quickly in case of an apparent problem.
-	_, err := c.driver.GetExecutor().Exec(ctx, "SELECT 1")
-	if err != nil {
-		return fmt.Errorf("error making initial connection to database: %w", err)
-	}
+	c.stopped = stopped
 
-	// In case of error, stop any services that might have started. This
-	// is safe because even services that were never started will still
-	// tolerate being stopped.
-	stopServicesOnError := func() {
-		startstop.StopAllParallel(c.services)
-		c.monitor.Stop()
+	stopProducers := func() {
+		startstop.StopAllParallel(sliceutil.Map(
+			maputil.Values(c.producersByQueueName),
+			func(p *producer) startstop.Service { return p }),
+		)
 	}
 
-	// Monitor should be the first subprocess to start, and the last to stop.
-	// It's not part of the waitgroup because we need to wait for everything else
-	// to shut down prior to closing the monitor.
-	//
-	// Unlike other services, it's given a background context so that it doesn't
-	// cancel on normal stops.
-	if err := c.monitor.Start(context.Background()); err != nil { //nolint:contextcheck
-		return err
-	}
+	var workCtx context.Context
 
-	if c.completer != nil {
-		// The completer is part of the services list below, but although it can
-		// stop gracefully along with all the other services, it needs to be
-		// started with a context that's _not_ fetchCtx. This ensures that even
-		// when fetch is cancelled on shutdown, the completer is still given a
-		// separate opportunity to start stopping only after the producers have
-		// finished up and returned.
-		if err := c.completer.Start(ctx); err != nil {
-			stopServicesOnError()
-			return err
+	// Startup code. Wrapped in a closure so it doesn't have to remember to
+	// close the stopped channel if returning with an error.
+	if err := func() error {
+		if !c.config.willExecuteJobs() {
+			return errors.New("client Queues and Workers must be configured for a client to start working")
+		}
+		if c.config.Workers != nil && len(c.config.Workers.workersMap) < 1 {
+			return errors.New("at least one Worker must be added to the Workers bundle")
 		}
 
-		// Receives job complete notifications from the completer and
-		// distributes them to any subscriptions.
-		c.completer.Subscribe(c.distributeJobCompleterCallback)
-	}
-
-	// We use separate contexts for fetching and working to allow for a graceful
-	// stop. However, both inherit from the provided context so if it is
-	// cancelled a more aggressive stop will be initiated.
-	fetchCtx, fetchWorkCancel := context.WithCancelCause(ctx)
-	c.fetchWorkCancel = fetchWorkCancel
-	workCtx, workCancel := context.WithCancelCause(withClient[TTx](ctx, c))
-	c.workCancel = workCancel
-
-	for _, service := range c.services {
-		// TODO(brandur): Reevaluate the use of fetchCtx here. It's currently
-		// necessary to speed up shutdown so that all services start shutting
-		// down before having to wait for the producers to finish, but as
-		// stopping becomes more normalized (hopefully by making the client
-		// itself a start/stop service), we can likely accomplish that in a
-		// cleaner way.
-		if err := service.Start(fetchCtx); err != nil {
-			stopServicesOnError()
-			if errors.Is(context.Cause(ctx), rivercommon.ErrShutdown) {
-				return nil
-			}
-			return err
+		// Before doing anything else, make an initial connection to the database to
+		// verify that it appears healthy. Many of the subcomponents below start up
+		// in a goroutine and in case of initial failure, only produce a log line,
+		// so even in the case of a fundamental failure like the database not being
+		// available, the client appears to have started even though it's completely
+		// non-functional. Here we try to make an initial assessment of health and
+		// return quickly in case of an apparent problem.
+		_, err := c.driver.GetExecutor().Exec(fetchCtx, "SELECT 1")
+		if err != nil {
+			return fmt.Errorf("error making initial connection to database: %w", err)
 		}
-	}
 
-	for _, producer := range c.producersByQueueName {
-		producer := producer
+		// In case of error, stop any services that might have started. This
+		// is safe because even services that were never started will still
+		// tolerate being stopped.
+		stopServicesOnError := func() {
+			startstop.StopAllParallel(c.services)
+			c.monitor.Stop()
+		}
 
-		if err := producer.StartWorkContext(fetchCtx, workCtx); err != nil {
+		// Monitor should be the first subprocess to start, and the last to stop.
+		// It's not part of the waitgroup because we need to wait for everything else
+		// to shut down prior to closing the monitor.
+		//
+		// Unlike other services, it's given a background context so that it doesn't
+		// cancel on normal stops.
+		if err := c.monitor.Start(context.Background()); err != nil { //nolint:contextcheck
 			return err
 		}
-	}
 
-	go func() {
-		<-fetchCtx.Done()
-		c.signalStopComplete(ctx)
-	}()
+		if c.completer != nil {
+			// The completer is part of the services list below, but although it can
+			// stop gracefully along with all the other services, it needs to be
+			// started with a context that's _not_ fetchCtx. This ensures that even
+			// when fetch is cancelled on shutdown, the completer is still given a
+			// separate opportunity to start stopping only after the producers have
+			// finished up and returned.
+			if err := c.completer.Start(ctx); err != nil {
+				stopServicesOnError()
+				return err
+			}
 
-	c.baseService.Logger.InfoContext(workCtx, "River client successfully started", slog.String("client_id", c.ID()))
+			// Receives job complete notifications from the completer and
+			// distributes them to any subscriptions.
+			c.completer.Subscribe(c.distributeJobCompleterCallback)
+		}
 
-	return nil
-}
+		// We use separate contexts for fetching and working to allow for a graceful
+		// stop. Both inherit from the provided context, so if it's cancelled, a
+		// more aggressive stop will be initiated.
+		workCtx, c.workCancel = context.WithCancelCause(withClient[TTx](ctx, c))
 
-// ctx is used only for logging, not for lifecycle.
-func (c *Client[TTx]) signalStopComplete(ctx context.Context) {
-	for _, producer := range c.producersByQueueName {
-		producer.Stop()
-	}
+		for _, service := range c.services {
+			if err := service.Start(fetchCtx); err != nil {
+				stopServicesOnError()
+				return err
+			}
+		}
 
-	// Stop all mainline services where stop order isn't important.
-	startstop.StopAllParallel(append(
-		// This list of services contains the completer, which should always
-		// stop after the producers so that any remaining work that was enqueued
-		// will have a chance to have its state completed as it finishes.
-		//
-		// TODO: there's a risk here that the completer is stuck on a job that
-		// won't complete. We probably need a timeout or way to move on in those
-		// cases.
-		c.services,
+		for _, producer := range c.producersByQueueName {
+			producer := producer
 
-		// Will only be started if this client was leader, but can tolerate a stop
-		// without having been started.
-		c.queueMaintainer,
-	))
+			if err := producer.StartWorkContext(fetchCtx, workCtx); err != nil {
+				stopProducers()
+				stopServicesOnError()
+				return err
+			}
+		}
 
-	c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": All services stopped")
+		return nil
+	}(); err != nil {
+		defer close(stopped)
+		if errors.Is(context.Cause(fetchCtx), startstop.ErrStop) {
+			return rivercommon.ErrShutdown
+		}
+		return err
+	}
 
-	// As of now, the Adapter doesn't have any async behavior, so we don't need
-	// to wait for it to stop. Once all executors and completers are done, we
-	// know that nothing else is happening that's from us.
+	go func() {
+		defer close(stopped)
 
-	// Remove all subscriptions and close corresponding channels.
-	func() {
-		c.subscriptionsMu.Lock()
-		defer c.subscriptionsMu.Unlock()
+		c.baseService.Logger.InfoContext(ctx, "River client started", slog.String("client_id", c.ID()))
+		defer c.baseService.Logger.InfoContext(ctx, "River client stopped", slog.String("client_id", c.ID()))
 
-		for subID, sub := range c.subscriptions {
-			close(sub.Chan)
-			delete(c.subscriptions, subID)
-		}
-	}()
+		// The call to Stop cancels this context. Block here until shutdown.
+		<-fetchCtx.Done()
+
+		// On stop, have the producers stop fetching first of all.
+		stopProducers()
+
+		// Stop all mainline services where stop order isn't important.
+		startstop.StopAllParallel(append(
+			// This list of services contains the completer, which should always
+			// stop after the producers so that any remaining work that was enqueued
+			// will have a chance to have its state completed as it finishes.
+			//
+			// TODO: there's a risk here that the completer is stuck on a job that
+			// won't complete. We probably need a timeout or way to move on in those
+			// cases.
+			c.services,
+
+			// Will only be started if this client was leader, but can tolerate a
+			// stop without having been started.
+			c.queueMaintainer,
+		))
+
+		// Remove all subscriptions and close corresponding channels.
+		func() {
+			c.subscriptionsMu.Lock()
+			defer c.subscriptionsMu.Unlock()
+
+			for subID, sub := range c.subscriptions {
+				close(sub.Chan)
+				delete(c.subscriptions, subID)
+			}
+		}()
 
-	// Shut down the monitor last so it can broadcast final status updates:
-	c.monitor.Stop()
+		// Shut down the monitor last so it can broadcast final status updates:
+		c.monitor.Stop()
+	}()
 
-	c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Stop complete")
-	close(c.stopComplete)
+	return nil
 }
 
 // Stop performs a graceful shutdown of the Client. It signals all producers
@@ -725,20 +729,17 @@ func (c *Client[TTx]) signalStopComplete(ctx context.Context) {
 // There's no need to call this method if a hard stop has already been initiated
 // by cancelling the context passed to Start or by calling StopAndCancel.
 func (c *Client[TTx]) Stop(ctx context.Context) error {
-	if c.fetchWorkCancel == nil {
-		return errors.New("client not started")
+	shouldStop, stopped, finalizeStop := c.baseStartStop.StopInit()
+	if !shouldStop {
+		return nil
 	}
 
-	c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Stop started")
-	c.fetchWorkCancel(rivercommon.ErrShutdown)
-	return c.awaitStop(ctx)
-}
-
-func (c *Client[TTx]) awaitStop(ctx context.Context) error {
 	select {
-	case <-ctx.Done():
+	case <-ctx.Done(): // stop context cancelled
+		finalizeStop(false) // not stopped; allow Stop to be called again
 		return ctx.Err()
-	case <-c.stopComplete:
+	case <-stopped:
+		finalizeStop(true)
 		return nil
 	}
 }
@@ -754,10 +755,22 @@ func (c *Client[TTx]) awaitStop(ctx context.Context) error {
 // no need to call this method if the context passed to Run is cancelled
 // instead.
 func (c *Client[TTx]) StopAndCancel(ctx context.Context) error {
+	shouldStop, stopped, finalizeStop := c.baseStartStop.StopInit()
+	if !shouldStop {
+		return nil
+	}
+
 	c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Hard stop started; cancelling all work")
-	c.fetchWorkCancel(rivercommon.ErrShutdown)
 	c.workCancel(rivercommon.ErrShutdown)
-	return c.awaitStop(ctx)
+
+	select {
+	case <-ctx.Done(): // stop context cancelled
+		finalizeStop(false) // not stopped; allow Stop to be called again
+		return ctx.Err()
+	case <-stopped:
+		finalizeStop(true)
+		return nil
+	}
 }
 
 // Stopped returns a channel that will be closed when the Client has stopped.
@@ -765,7 +778,7 @@ func (c *Client[TTx]) StopAndCancel(ctx context.Context) error {
 //
 // It is not affected by any contexts passed to Stop or StopAndCancel.
 func (c *Client[TTx]) Stopped() <-chan struct{} {
-	return c.stopComplete
+	return c.stopped
 }
 
 // Subscribe subscribes to the provided kinds of events that occur within the
diff --git a/client_test.go b/client_test.go
index d2f12654..2a6808b1 100644
--- a/client_test.go
+++ b/client_test.go
@@ -24,6 +24,7 @@ import (
 	"github.com/riverqueue/river/internal/maintenance"
 	"github.com/riverqueue/river/internal/rivercommon"
 	"github.com/riverqueue/river/internal/riverinternaltest"
+	"github.com/riverqueue/river/internal/riverinternaltest/startstoptest"
 	"github.com/riverqueue/river/internal/riverinternaltest/testfactory"
 	"github.com/riverqueue/river/internal/util/dbutil"
 	"github.com/riverqueue/river/internal/util/ptrutil"
@@ -116,6 +117,16 @@ func (w *callbackWorker) Work(ctx context.Context, job *Job[callbackArgs]) error
 	return w.fn(ctx, job)
 }
 
+// A small wrapper around Client that gives us a struct that adapts the
+// client's Stop function so that it can implement startstop.Service.
+type clientWithSimpleStop[TTx any] struct {
+	*Client[TTx]
+}
+
+func (c *clientWithSimpleStop[TTx]) Stop() {
+	_ = c.Client.Stop(context.Background())
+}
+
 func newTestConfig(t *testing.T, callback callbackFunc) *Config {
 	t.Helper()
 	workers := NewWorkers()
@@ -433,6 +444,16 @@ func Test_Client(t *testing.T) {
 		require.Equal(t, `relation "river_job" does not exist`, pgErr.Message)
 	})
 
+	t.Run("StartStopStress", func(t *testing.T) {
+		t.Parallel()
+
+		client, _ := setup(t)
+
+		clientWithStop := &clientWithSimpleStop[pgx.Tx]{Client: client}
+
+		startstoptest.StressErr(ctx, t, clientWithStop, rivercommon.ErrShutdown)
+	})
+
 	t.Run("StopAndCancel", func(t *testing.T) {
 		t.Parallel()
 
@@ -515,17 +536,6 @@ func Test_Client_Stop(t *testing.T) {
 		}
 	}
 
-	t.Run("not started", func(t *testing.T) {
-		t.Parallel()
-
-		dbPool := riverinternaltest.TestDB(ctx, t)
-		client := newTestClient(t, dbPool, newTestConfig(t, nil))
-
-		err := client.Stop(ctx)
-		require.Error(t, err)
-		require.Equal(t, "client not started", err.Error())
-	})
-
 	t.Run("no jobs in progress", func(t *testing.T) {
 		t.Parallel()
 		client := runNewTestClient(ctx, t, newTestConfig(t, nil))
diff --git a/internal/maintenance/startstop/start_stop.go b/internal/maintenance/startstop/start_stop.go
index 9196ddf1..12365fbf 100644
--- a/internal/maintenance/startstop/start_stop.go
+++ b/internal/maintenance/startstop/start_stop.go
@@ -2,9 +2,15 @@ package startstop
 
 import (
 	"context"
+	"errors"
 	"sync"
 )
 
+// ErrStop is an error injected into WithCancelCause when context is canceled
+// because a service is stopping. Makes it possible to differentiate a
+// controlled stop from a context cancellation.
+var ErrStop = errors.New("service stopped")
+
 // Service is a generalized interface for a service that starts and stops,
 // usually one backed by embedding BaseStartStop.
 type Service interface {
@@ -48,7 +54,7 @@ type serviceWithStopped interface {
 // A Stop implementation is provided automatically and it's not necessary to
 // override it.
 type BaseStartStop struct {
-	cancelFunc context.CancelFunc
+	cancelFunc context.CancelCauseFunc
 	mu         sync.Mutex
 	started    bool
 	stopped    chan struct{}
@@ -69,6 +75,21 @@ type BaseStartStop struct {
 //	    defer close(stopped)
 //
 //	    ...
+//
+// Be careful to also close it in the event of startup errors, otherwise a
+// service that failed to start once will never be able to start up.
+//
+//	ctx, shouldStart, stopped := s.StartInit(ctx)
+//	if !shouldStart {
+//	    return nil
+//	}
+//
+//	if err := possibleStartUpError(); err != nil {
+//	    close(stopped)
+//	    return err
+//	}
+//
+//	...
 func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, chan struct{}) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -79,7 +100,7 @@ func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, c
 	s.started = true
 	s.stopped = make(chan struct{})
 
-	ctx, s.cancelFunc = context.WithCancel(ctx)
+	ctx, s.cancelFunc = context.WithCancelCause(ctx)
 
 	return ctx, true, s.stopped
 }
@@ -87,19 +108,55 @@ func (s *BaseStartStop) StartInit(ctx context.Context) (context.Context, bool, c
 // Stop is an automatically provided implementation for the maintenance Service
 // interface's Stop.
 func (s *BaseStartStop) Stop() {
+	shouldStop, stopped, finalizeStop := s.StopInit()
+	if !shouldStop {
+		return
+	}
+
+	<-stopped
+	finalizeStop(true)
+}
+
+// StopInit provides a way to build a more customized Stop implementation. It
+// should be avoided unless there's an exceptional reason to use it because Stop
+// should be fine in the vast majority of situations. It returns a boolean
+// indicating whether the service should do any additional work to stop (false
+// is returned if the service was never started), a stopped channel to wait on
+// for full stop, and a finalizeStop function that should be deferred in the
+// stop function to ensure that locks are cleaned up and the struct is reset
+// after stopping.
+//
+//	shouldStop, stopped, finalizeStop := s.StopInit()
+//	if !shouldStop {
+//	    return
+//	}
+//
+//	defer finalizeStop(true)
+//
+//	...
+//
+// finalizeStop takes a boolean which indicates whether the service should indeed
+// be considered stopped. This should usually be true, but callers can pass
+// false to cancel the stop action, keeping the service from starting again, and
+// potentially allowing the service to try another stop.
+func (s *BaseStartStop) StopInit() (bool, <-chan struct{}, func(didStop bool)) {
 	s.mu.Lock()
-	defer s.mu.Unlock()
 
 	// Tolerate being told to stop without having been started.
 	if s.stopped == nil {
-		return
+		s.mu.Unlock()
+		return false, nil, func(didStop bool) {}
 	}
 
-	s.cancelFunc()
+	s.cancelFunc(ErrStop)
 
-	<-s.stopped
-	s.started = false
-	s.stopped = nil
+	return true, s.stopped, func(didStop bool) {
+		defer s.mu.Unlock()
+
+		if didStop {
+			s.started = false
+			s.stopped = nil
+		}
+	}
 }
 
 // Stopped returns a channel that can be waited on for the service to be
diff --git a/internal/maintenance/startstop/start_stop_test.go b/internal/maintenance/startstop/start_stop_test.go
index bd23ebc2..e13f0440 100644
--- a/internal/maintenance/startstop/start_stop_test.go
+++ b/internal/maintenance/startstop/start_stop_test.go
@@ -49,6 +49,15 @@ func testService(t *testing.T, newService func(t *testing.T) serviceWithStopped)
 		return newService(t), &testBundle{}
 	}
 
+	t.Run("StartAndStop", func(t *testing.T) {
+		t.Parallel()
+
+		service, _ := setup(t)
+
+		require.NoError(t, service.Start(ctx))
+		service.Stop()
+	})
+
 	t.Run("DoubleStop", func(t *testing.T) {
 		t.Parallel()
 
@@ -140,6 +149,119 @@ func TestBaseStartStopFunc(t *testing.T) {
 	testService(t, makeFunc)
 }
 
+func TestErrStop(t *testing.T) {
+	t.Parallel()
+
+	var (
+		workCtx context.Context
+		started = make(chan struct{})
+	)
+
+	startStop := StartStopFunc(func(ctx context.Context, shouldStart bool, stopped chan struct{}) error {
+		if !shouldStart {
+			return nil
+		}
+
+		workCtx = ctx
+
+		go func() {
+			close(started)
+			defer close(stopped)
+			<-ctx.Done()
+		}()
+
+		return nil
+	})
+
+	ctx := context.Background()
+
+	require.NoError(t, startStop.Start(ctx))
+	<-started
+	startStop.Stop()
+	require.ErrorIs(t, context.Cause(workCtx), ErrStop)
+}
+
+// A service with the more unusual case of a custom Stop implementation built
+// on StopInit.
+type sampleServiceWithStopInit struct {
+	baseservice.BaseService
+	BaseStartStop
+
+	didStop bool
+
+	// Some simple state in the service which a started service taints. The
+	// purpose of this variable is to allow us to detect a data race allowed by
+	// BaseStartStop.
+	state bool
+}
+
+func (s *sampleServiceWithStopInit) Start(ctx context.Context) error {
+	ctx, shouldStart, stopped := s.StartInit(ctx)
+	if !shouldStart {
+		return nil
+	}
+
+	go func() {
+		defer close(stopped)
+		s.state = true
+		<-ctx.Done()
+	}()
+
+	return nil
+}
+
+func (s *sampleServiceWithStopInit) Stop() {
+	shouldStop, stopped, finalizeStop := s.StopInit()
+	if !shouldStop {
+		return
+	}
+
+	<-stopped
+	finalizeStop(s.didStop)
+}
+
+func TestWithStopInit(t *testing.T) {
+	t.Parallel()
+
+	testService(t, func(t *testing.T) serviceWithStopped { t.Helper(); return &sampleServiceWithStopInit{didStop: true} })
+
+	ctx := context.Background()
+
+	type testBundle struct{}
+
+	setup := func() (*sampleServiceWithStopInit, *testBundle) {
+		return &sampleServiceWithStopInit{}, &testBundle{}
+	}
+
+	t.Run("FinalizeDidStop", func(t *testing.T) {
+		t.Parallel()
+
+		service, _ := setup()
+		service.didStop = true // will set stopped
+
+		require.NoError(t, service.Start(ctx))
+
+		service.Stop()
+
+		require.False(t, service.started)
+		require.Nil(t, service.stopped)
+	})
+
+	t.Run("FinalizeDidNotStop", func(t *testing.T) {
+		t.Parallel()
+
+		service, _ := setup()
+		service.didStop = false // will NOT set stopped
+
+		require.NoError(t, service.Start(ctx))
+
+		service.Stop()
+
+		// service is still started because didStop was set to false
+		require.True(t, service.started)
+		require.NotNil(t, service.stopped)
+	})
+}
+
 func TestStopAllParallel(t *testing.T) {
 	t.Parallel()
diff --git a/internal/riverinternaltest/startstoptest/startstoptest.go b/internal/riverinternaltest/startstoptest/startstoptest.go
index 6725a89a..0948b9fb 100644
--- a/internal/riverinternaltest/startstoptest/startstoptest.go
+++ b/internal/riverinternaltest/startstoptest/startstoptest.go
@@ -3,7 +3,6 @@ package startstoptest
 import (
 	"context"
 	"sync"
-	"testing"
 	"time"
 
 	"github.com/stretchr/testify/require"
@@ -14,7 +13,15 @@ import (
 // Stress is a test helper that puts stress on a service's start and stop
 // functions so that we can detect any data races that it might have due to
 // improper use of BaseStopStart.
-func Stress(ctx context.Context, tb testing.TB, svc startstop.Service) {
+func Stress(ctx context.Context, tb testingT, svc startstop.Service) {
+	StressErr(ctx, tb, svc, nil)
+}
+
+// StressErr is the same as Stress except that the given allowedStartErr is
+// tolerated on start (either no error or an error matching allowedStartErr is
+// allowed). This is useful for services that may want to return an error if
+// they're shut down while they're still starting up.
+func StressErr(ctx context.Context, tb testingT, svc startstop.Service, allowedStartErr error) { //nolint:varnamelen
 	tb.Helper()
 
 	var wg sync.WaitGroup
@@ -25,7 +32,12 @@ func Stress(ctx context.Context, tb testing.TB, svc startstop.Service) {
 			defer wg.Done()
 
 			for j := 0; j < 50; j++ {
-				require.NoError(tb, svc.Start(ctx))
+				err := svc.Start(ctx)
+				if allowedStartErr == nil {
+					require.NoError(tb, err)
+				} else if err != nil {
+					require.ErrorIs(tb, err, allowedStartErr)
+				}
 
 				stopped := make(chan struct{})
 
@@ -45,3 +57,11 @@ func Stress(ctx context.Context, tb testing.TB, svc startstop.Service) {
 
 	wg.Wait()
 }
+
+// Minimal interface for *testing.B/*testing.T that lets us test a failure
+// condition for our test helpers above.
+type testingT interface {
+	Errorf(format string, args ...interface{})
+	FailNow()
+	Helper()
+}
diff --git a/internal/riverinternaltest/startstoptest/startstoptest_test.go b/internal/riverinternaltest/startstoptest/startstoptest_test.go
index dc43c6f1..f7efd156 100644
--- a/internal/riverinternaltest/startstoptest/startstoptest_test.go
+++ b/internal/riverinternaltest/startstoptest/startstoptest_test.go
@@ -2,16 +2,21 @@ package startstoptest
 
 import (
 	"context"
+	"errors"
 	"log/slog"
+	"sync/atomic"
 	"testing"
 
+	"github.com/stretchr/testify/require"
+
 	"github.com/riverqueue/river/internal/maintenance/startstop"
 	"github.com/riverqueue/river/internal/riverinternaltest"
 )
 
 type MyService struct {
 	startstop.BaseStartStop
-	logger *slog.Logger
+	logger   *slog.Logger
+	startErr error
 }
 
 func (s *MyService) Start(ctx context.Context) error {
@@ -20,6 +25,11 @@ func (s *MyService) Start(ctx context.Context) error {
 		return nil
 	}
 
+	if s.startErr != nil {
+		close(stopped)
+		return s.startErr
+	}
+
 	go func() {
 		defer close(stopped)
 
@@ -39,3 +49,31 @@ func TestStress(t *testing.T) {
 
 	Stress(ctx, t, &MyService{logger: riverinternaltest.Logger(t)})
 }
+
+func TestStressErr(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+
+	startErr := errors.New("error returned on start")
+
+	StressErr(ctx, t, &MyService{logger: riverinternaltest.Logger(t), startErr: startErr}, startErr)
+
+	mockT := newMockTestingT(t)
+	StressErr(ctx, mockT, &MyService{logger: riverinternaltest.Logger(t), startErr: errors.New("different error")}, startErr)
+	require.True(t, mockT.failed.Load())
+}
+
+type mockTestingT struct {
+	failed atomic.Bool
+	tb     testing.TB
+}
+
+func newMockTestingT(tb testing.TB) *mockTestingT {
+	tb.Helper()
+	return &mockTestingT{tb: tb}
+}
+
+func (t *mockTestingT) Errorf(format string, args ...interface{}) {}
+func (t *mockTestingT) FailNow()                                  { t.failed.Store(true) }
+func (t *mockTestingT) Helper()                                   { t.tb.Helper() }
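Reviewer note, not part of the diff: to make the new lifecycle concrete, here's a minimal sketch of the service pattern this change standardizes on, assembled from the `StartInit`/`StopInit` doc comments above. `exampleService` and `runStartupCheck` are hypothetical names used only for illustration, and `startstop` is an internal package, so this shape only applies to code inside the river module:

```go
package example

import (
	"context"

	"github.com/riverqueue/river/internal/maintenance/startstop"
)

type exampleService struct {
	startstop.BaseStartStop
}

// runStartupCheck is a hypothetical stand-in for any fallible startup work.
func runStartupCheck(ctx context.Context) error { return nil }

func (s *exampleService) Start(ctx context.Context) error {
	ctx, shouldStart, stopped := s.StartInit(ctx)
	if !shouldStart { // already started; Start is a no-op
		return nil
	}

	// Per the StartInit docs: close stopped on startup errors too, otherwise
	// a service that failed to start once could never be started again.
	if err := runStartupCheck(ctx); err != nil {
		close(stopped)
		return err
	}

	go func() {
		defer close(stopped) // lets the provided BaseStartStop.Stop unblock

		// Stop cancels this context with cause startstop.ErrStop, which is
		// distinguishable from a parent cancellation via context.Cause(ctx).
		<-ctx.Done()
	}()

	return nil
}
```

Because `BaseStartStop` already provides `Stop`, a service built this way only overrides it (via `StopInit`) when it needs custom finalization, as `sampleServiceWithStopInit` does in the tests above.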
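And a hedged sketch of what the client lifecycle looks like to callers after this change, assuming a `client` built with `river.NewClient` against a pgx driver as usual (the 10-second budget is an arbitrary example value): `Stop` and `StopAndCancel` now return nil when the client was never started, and a `Stop` whose context expires finalizes with `false`, so a harder stop can still be attempted afterward:

```go
import (
	"context"
	"log/slog"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/riverqueue/river"
)

// runClient sketches the soft-stop/hard-stop escalation enabled by the new
// StopInit semantics.
func runClient(ctx context.Context, client *river.Client[pgx.Tx]) error {
	if err := client.Start(ctx); err != nil {
		return err
	}

	go func() {
		<-client.Stopped() // closed once the client has fully shut down
		slog.Info("river client stopped")
	}()

	// ... wait for a shutdown signal ...

	// Soft stop: stop fetching new work and wait for in-flight jobs to finish.
	softStopCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	if err := client.Stop(softStopCtx); err == nil {
		return nil
	}

	// The timed-out Stop called finalizeStop(false), so the client is still
	// considered running and escalation is possible: cancel running jobs too.
	return client.StopAndCancel(ctx)
}
```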