diff --git a/cmd/jujud-controller/agent/machine_test.go b/cmd/jujud-controller/agent/machine_test.go index b6c0a00eaea..cf68dc4f33e 100644 --- a/cmd/jujud-controller/agent/machine_test.go +++ b/cmd/jujud-controller/agent/machine_test.go @@ -375,7 +375,7 @@ func (s *MachineSuite) TestMachineAgentRunsDiskManagerWorker(c *gc.C) { started := newSignal() newWorker := func(diskmanager.ListBlockDevicesFunc, diskmanager.BlockDeviceSetter) worker.Worker { started.trigger() - return jworker.NewNoOpWorker() + return jworker.NoopWorker() } s.PatchValue(&diskmanager.NewWorker, newWorker) @@ -424,7 +424,7 @@ func (s *MachineSuite) TestMachineAgentRunsMachineStorageWorker(c *gc.C) { c.Check(config.Scope, gc.Equals, m.Tag()) c.Check(config.Validate(), jc.ErrorIsNil) started.trigger() - return jworker.NewNoOpWorker(), nil + return jworker.NoopWorker(), nil } s.PatchValue(&storageprovisioner.NewStorageProvisioner, newWorker) @@ -499,7 +499,7 @@ func (s *MachineSuite) setupIgnoreAddresses(c *gc.C, expectedIgnoreValue bool) c // The test just cares that NewMachiner is called with the correct // value, nothing else is done with the worker. - return newDummyWorker(), nil + return jworker.NoopWorker(), nil }) attrs := coretesting.Attrs{"ignore-machine-addresses": expectedIgnoreValue} diff --git a/cmd/jujud-controller/agent/util_test.go b/cmd/jujud-controller/agent/util_test.go index e3a7198f2ee..de1045adc3e 100644 --- a/cmd/jujud-controller/agent/util_test.go +++ b/cmd/jujud-controller/agent/util_test.go @@ -386,13 +386,6 @@ func runWithTimeout(c *gc.C, r runner) error { return fmt.Errorf("timed out waiting for agent to finish; stop error: %v", err) } -func newDummyWorker() worker.Worker { - return jworker.NewSimpleWorker(func(stop <-chan struct{}) error { - <-stop - return nil - }) -} - type FakeConfig struct { agent.ConfigSetter values map[string]string diff --git a/cmd/jujud/agent/agent_test.go b/cmd/jujud/agent/agent_test.go index a002c58f27e..dcc77a1ba8e 100644 --- a/cmd/jujud/agent/agent_test.go +++ b/cmd/jujud/agent/agent_test.go @@ -19,6 +19,7 @@ import ( "github.com/juju/juju/cmd/jujud/agent/agenttest" "github.com/juju/juju/core/network" imagetesting "github.com/juju/juju/environs/imagemetadata/testing" + jworker "github.com/juju/juju/internal/worker" "github.com/juju/juju/internal/worker/proxyupdater" ) @@ -66,7 +67,7 @@ func (s *AgentSuite) SetUpTest(c *gc.C) { err = st.SetAPIHostPorts(controllerConfig, hostPorts, hostPorts) c.Assert(err, jc.ErrorIsNil) s.PatchValue(&proxyupdater.NewWorker, func(proxyupdater.Config) (worker.Worker, error) { - return newDummyWorker(), nil + return jworker.NoopWorker(), nil }) // Tests should not try to use internet. Ensure base url is empty. diff --git a/cmd/jujud/agent/machine_test.go b/cmd/jujud/agent/machine_test.go index a74ba6deb54..7aba16b2916 100644 --- a/cmd/jujud/agent/machine_test.go +++ b/cmd/jujud/agent/machine_test.go @@ -356,7 +356,7 @@ func (s *MachineSuite) TestMachineAgentRunsDiskManagerWorker(c *gc.C) { started := newSignal() newWorker := func(diskmanager.ListBlockDevicesFunc, diskmanager.BlockDeviceSetter) worker.Worker { started.trigger() - return jworker.NewNoOpWorker() + return jworker.NoopWorker() } s.PatchValue(&diskmanager.NewWorker, newWorker) @@ -405,7 +405,7 @@ func (s *MachineSuite) TestMachineAgentRunsMachineStorageWorker(c *gc.C) { c.Check(config.Scope, gc.Equals, m.Tag()) c.Check(config.Validate(), jc.ErrorIsNil) started.trigger() - return jworker.NewNoOpWorker(), nil + return jworker.NoopWorker(), nil } s.PatchValue(&storageprovisioner.NewStorageProvisioner, newWorker) @@ -427,7 +427,7 @@ func (s *MachineSuite) setupIgnoreAddresses(c *gc.C, expectedIgnoreValue bool) c // The test just cares that NewMachiner is called with the correct // value, nothing else is done with the worker. - return newDummyWorker(), nil + return jworker.NoopWorker(), nil }) attrs := coretesting.Attrs{"ignore-machine-addresses": expectedIgnoreValue} diff --git a/cmd/jujud/agent/util_test.go b/cmd/jujud/agent/util_test.go index c65f44cb9a6..17e3282e277 100644 --- a/cmd/jujud/agent/util_test.go +++ b/cmd/jujud/agent/util_test.go @@ -339,13 +339,6 @@ func runWithTimeout(c *gc.C, r runner) error { return fmt.Errorf("timed out waiting for agent to finish; stop error: %v", err) } -func newDummyWorker() worker.Worker { - return jworker.NewSimpleWorker(func(stop <-chan struct{}) error { - <-stop - return nil - }) -} - type FakeConfig struct { agent.ConfigSetter values map[string]string diff --git a/internal/worker/caasenvironupgrader/worker.go b/internal/worker/caasenvironupgrader/worker.go index 566a2e2e1c7..986a1578ff3 100644 --- a/internal/worker/caasenvironupgrader/worker.go +++ b/internal/worker/caasenvironupgrader/worker.go @@ -4,6 +4,8 @@ package caasenvironupgrader import ( + "context" + "github.com/juju/errors" "github.com/juju/names/v5" "github.com/juju/worker/v4" @@ -61,7 +63,7 @@ func NewWorker(config Config) (worker.Worker, error) { } // There are no upgrade steps for a CAAS model. // We just set the status to available and unlock the gate. - return jujuworker.NewSimpleWorker(func(<-chan struct{}) error { + return jujuworker.NewSimpleWorker(func(context.Context) error { setStatus := func(s status.Status, info string) error { return config.Facade.SetModelStatus(config.ModelTag, s, info, nil) } diff --git a/internal/worker/environupgrader/worker.go b/internal/worker/environupgrader/worker.go index e9995222074..f6eebf37369 100644 --- a/internal/worker/environupgrader/worker.go +++ b/internal/worker/environupgrader/worker.go @@ -190,7 +190,7 @@ func newUpgradeWorker(config Config, targetVersion int) (worker.Worker, error) { return nil, errors.Trace(err) } - return jujuworker.NewSimpleWorker(func(<-chan struct{}) error { + return jujuworker.NewSimpleWorker(func(ctx stdcontext.Context) error { // NOTE(axw) the abort channel is ignored, because upgrade // steps are not interruptible. If we find they need to be // interruptible, we should consider passing through a diff --git a/internal/worker/identityfilewriter/manifold.go b/internal/worker/identityfilewriter/manifold.go index 66662348765..1cd61939e0f 100644 --- a/internal/worker/identityfilewriter/manifold.go +++ b/internal/worker/identityfilewriter/manifold.go @@ -49,7 +49,7 @@ func newWorker(ctx context.Context, a agent.Agent, apiCaller base.APICaller) (wo } var NewWorker = func(agentConfig agent.Config) (worker.Worker, error) { - inner := func(<-chan struct{}) error { + inner := func(ctx context.Context) error { return agent.WriteSystemIdentityFile(agentConfig) } return jworker.NewSimpleWorker(inner), nil diff --git a/internal/worker/logsender/worker.go b/internal/worker/logsender/worker.go index 246a76607a6..bc714800417 100644 --- a/internal/worker/logsender/worker.go +++ b/internal/worker/logsender/worker.go @@ -26,7 +26,7 @@ type LogSenderAPI interface { // New starts a logsender worker which reads log message structs from // a channel and sends them to the controller via the logsink API. func New(logs LogRecordCh, logSenderAPI LogSenderAPI) worker.Worker { - loop := func(stop <-chan struct{}) error { + loop := func(ctx context.Context) error { // It has been observed that sometimes the logsender.API gets wedged // attempting to get the LogWriter while the agent is being torn down, // and the call to logSenderAPI.LogWriter() doesn't return. This stops @@ -40,13 +40,13 @@ func New(logs LogRecordCh, logSenderAPI LogSenderAPI) worker.Worker { if err != nil { select { case errChan <- err: - case <-stop: + case <-ctx.Done(): } return } select { case sender <- logWriter: - case <-stop: + case <-ctx.Done(): logWriter.Close() } }() @@ -56,7 +56,7 @@ func New(logs LogRecordCh, logSenderAPI LogSenderAPI) worker.Worker { case logWriter = <-sender: case err = <-errChan: return errors.Annotate(err, "logsender dial failed") - case <-stop: + case <-ctx.Done(): return nil } defer logWriter.Close() @@ -101,7 +101,7 @@ func New(logs LogRecordCh, logSenderAPI LogSenderAPI) worker.Worker { } } - case <-stop: + case <-ctx.Done(): return nil } } diff --git a/internal/worker/noopworker.go b/internal/worker/noopworker.go index 855c19e7c1d..9d50222502f 100644 --- a/internal/worker/noopworker.go +++ b/internal/worker/noopworker.go @@ -5,16 +5,15 @@ package worker import ( + "context" + "github.com/juju/worker/v4" ) -func NewNoOpWorker() worker.Worker { - return NewSimpleWorker(doNothing) -} - -func doNothing(stop <-chan struct{}) error { - select { - case <-stop: +// NoopWorker returns a worker that waits for the context to be done. +func NoopWorker() worker.Worker { + return NewSimpleWorker(func(ctx context.Context) error { + <-ctx.Done() return nil - } + }) } diff --git a/internal/worker/periodicworker_test.go b/internal/worker/periodicworker_test.go index 0320e408a88..588d9882cdb 100644 --- a/internal/worker/periodicworker_test.go +++ b/internal/worker/periodicworker_test.go @@ -28,18 +28,18 @@ func (s *periodicWorkerSuite) TestWait(c *gc.C) { funcHasRun := make(chan struct{}) doWork := func(_ <-chan struct{}) error { funcHasRun <- struct{}{} - return testError + return errTest } w := NewPeriodicWorker(doWork, defaultPeriod, NewTimer) - defer func() { c.Assert(worker.Stop(w), gc.Equals, testError) }() + defer func() { c.Assert(worker.Stop(w), gc.Equals, errTest) }() select { case <-funcHasRun: case <-time.After(testing.ShortWait): c.Fatalf("The doWork function should have been called by now") } w.Kill() - c.Assert(w.Wait(), gc.Equals, testError) + c.Assert(w.Wait(), gc.Equals, errTest) select { case <-funcHasRun: c.Fatalf("After the kill we don't expect anymore calls to the function") @@ -146,7 +146,7 @@ func (s *periodicWorkerSuite) TestKill(c *gc.C) { ExpectedValue error }{ {nil, nil}, - {testError, testError}, + {errTest, errTest}, {ErrKilled, nil}, } diff --git a/internal/worker/simpleworker.go b/internal/worker/simpleworker.go index cb22190f841..a95da54206f 100644 --- a/internal/worker/simpleworker.go +++ b/internal/worker/simpleworker.go @@ -4,6 +4,8 @@ package worker import ( + "context" + "github.com/juju/worker/v4" "gopkg.in/tomb.v2" ) @@ -16,10 +18,13 @@ type simpleWorker struct { // NewSimpleWorker returns a worker that runs the given function. The // stopCh argument will be closed when the worker is killed. The error returned // by the doWork function will be returned by the worker's Wait function. -func NewSimpleWorker(doWork func(stopCh <-chan struct{}) error) worker.Worker { +func NewSimpleWorker(doWork func(context.Context) error) worker.Worker { w := &simpleWorker{} w.tomb.Go(func() error { - return doWork(w.tomb.Dying()) + ctx, cancel := context.WithCancel(w.tomb.Context(context.Background())) + defer cancel() + + return doWork(ctx) }) return w } diff --git a/internal/worker/simpleworker_test.go b/internal/worker/simpleworker_test.go index ecdc44846c3..fd84edbe6cc 100644 --- a/internal/worker/simpleworker_test.go +++ b/internal/worker/simpleworker_test.go @@ -4,6 +4,7 @@ package worker import ( + "context" "errors" gc "gopkg.in/check.v1" @@ -17,19 +18,19 @@ type simpleWorkerSuite struct { var _ = gc.Suite(&simpleWorkerSuite{}) -var testError = errors.New("test error") +var errTest = errors.New("test error") func (s *simpleWorkerSuite) TestWait(c *gc.C) { - doWork := func(_ <-chan struct{}) error { - return testError + doWork := func(context.Context) error { + return errTest } w := NewSimpleWorker(doWork) - c.Assert(w.Wait(), gc.Equals, testError) + c.Assert(w.Wait(), gc.Equals, errTest) } func (s *simpleWorkerSuite) TestWaitNil(c *gc.C) { - doWork := func(_ <-chan struct{}) error { + doWork := func(context.Context) error { return nil } @@ -38,14 +39,14 @@ func (s *simpleWorkerSuite) TestWaitNil(c *gc.C) { } func (s *simpleWorkerSuite) TestKill(c *gc.C) { - doWork := func(stopCh <-chan struct{}) error { - <-stopCh - return testError + doWork := func(ctx context.Context) error { + <-ctx.Done() + return errTest } w := NewSimpleWorker(doWork) w.Kill() - c.Assert(w.Wait(), gc.Equals, testError) + c.Assert(w.Wait(), gc.Equals, errTest) // test we can kill again without a panic w.Kill() diff --git a/internal/worker/undertaker/undertaker.go b/internal/worker/undertaker/undertaker.go index 47323fc6f33..affcabb863e 100644 --- a/internal/worker/undertaker/undertaker.go +++ b/internal/worker/undertaker/undertaker.go @@ -155,11 +155,12 @@ func (u *Undertaker) run() (errOut error) { // Watch for changes to model destroy values, if so, cancel the context // and restart the worker. - err = u.catacomb.Add(worker.NewSimpleWorker(func(stopCh <-chan struct{}) error { + err = u.catacomb.Add(worker.NewSimpleWorker(func(ctx context.Context) error { for { select { - case <-stopCh: + case <-ctx.Done(): return nil + case <-modelWatcher.Changes(): result, err := u.config.Facade.ModelInfo() if errors.Is(err, errors.NotFound) || err != nil || result.Error != nil {