diff --git a/api/controller/caasmodeloperator/client.go b/api/controller/caasmodeloperator/client.go index e8dfadc5841..315afc1bd43 100644 --- a/api/controller/caasmodeloperator/client.go +++ b/api/controller/caasmodeloperator/client.go @@ -100,9 +100,9 @@ func (c *Client) SetPassword(password string) error { // WatchModelOperatorProvisioningInfo provides a watcher for changes that affect the // information returned by ModelOperatorProvisioningInfo. -func (c *Client) WatchModelOperatorProvisioningInfo() (watcher.NotifyWatcher, error) { +func (c *Client) WatchModelOperatorProvisioningInfo(ctx context.Context) (watcher.NotifyWatcher, error) { var result params.NotifyWatchResult - if err := c.facade.FacadeCall(context.TODO(), "WatchModelOperatorProvisioningInfo", nil, &result); err != nil { + if err := c.facade.FacadeCall(ctx, "WatchModelOperatorProvisioningInfo", nil, &result); err != nil { return nil, err } if result.Error != nil { diff --git a/apiserver/common/cloudspec/cloudspec.go b/apiserver/common/cloudspec/cloudspec.go index 8341a844eb9..41aab120559 100644 --- a/apiserver/common/cloudspec/cloudspec.go +++ b/apiserver/common/cloudspec/cloudspec.go @@ -12,11 +12,12 @@ import ( "github.com/juju/juju/apiserver/common" apiservererrors "github.com/juju/juju/apiserver/errors" "github.com/juju/juju/apiserver/facade" + "github.com/juju/juju/apiserver/internal" corewatcher "github.com/juju/juju/core/watcher" + "github.com/juju/juju/core/watcher/eventsource" environscloudspec "github.com/juju/juju/environs/cloudspec" "github.com/juju/juju/rpc/params" "github.com/juju/juju/state" - "github.com/juju/juju/state/watcher" ) // CloudSpecer defines the CloudSpec api interface @@ -199,22 +200,32 @@ func (s CloudSpecAPI) watchCloudSpecChanges(ctx context.Context, tag names.Model if err != nil { return result, errors.Trace(err) } - var watch *common.MultiNotifyWatcher + var watcher eventsource.Watcher[struct{}] if credentialContentWatch != nil { - watch = common.NewMultiNotifyWatcher(&watcherAdaptor{cloudWatch}, credentialReferenceWatch, &watcherAdaptor{credentialContentWatch}) + watcher, err = eventsource.NewMultiNotifyWatcher(ctx, + &watcherAdaptor{NotifyWatcher: cloudWatch}, + credentialReferenceWatch, + &watcherAdaptor{NotifyWatcher: credentialContentWatch}, + ) } else { // It's rare but possible that a model does not have a credential. // In this case there is no point trying to 'watch' content changes. - watch = common.NewMultiNotifyWatcher(&watcherAdaptor{cloudWatch}, credentialReferenceWatch) - } - // Consume the initial event. Technically, API - // calls to Watch 'transmit' the initial event - // in the Watch response. But NotifyWatchers - // have no state to transmit. - if _, ok := <-watch.Changes(); ok { - result.NotifyWatcherId = s.resources.Register(watch) - } else { - return result, watcher.EnsureErr(watch) + watcher, err = eventsource.NewMultiNotifyWatcher(ctx, + &watcherAdaptor{NotifyWatcher: cloudWatch}, + credentialReferenceWatch, + ) } + if err != nil { + return result, errors.Trace(err) + } + // Consume the initial result for the API. + _, err = internal.FirstResult[struct{}](ctx, watcher) + if err != nil { + return result, errors.Trace(err) + } + + // Ensure we register the watcher, once we know it's working. + result.NotifyWatcherId = s.resources.Register(watcher) + return result, nil } diff --git a/apiserver/common/watch.go b/apiserver/common/watch.go index 59feb75de72..547a245a2fe 100644 --- a/apiserver/common/watch.go +++ b/apiserver/common/watch.go @@ -5,12 +5,9 @@ package common import ( "context" - "sync" - "time" "github.com/juju/errors" "github.com/juju/names/v5" - "gopkg.in/tomb.v2" apiservererrors "github.com/juju/juju/apiserver/errors" "github.com/juju/juju/apiserver/facade" @@ -86,109 +83,3 @@ func (a *AgentEntityWatcher) Watch(ctx context.Context, args params.Entities) (p } return result, nil } - -// MultiNotifyWatcher implements state.NotifyWatcher, combining -// multiple NotifyWatchers. -type MultiNotifyWatcher struct { - tomb tomb.Tomb - watchers []state.NotifyWatcher - changes chan struct{} -} - -// NewMultiNotifyWatcher creates a NotifyWatcher that combines -// each of the NotifyWatchers passed in. Each watcher's initial -// event is consumed, and a single initial event is sent. -// Subsequent events are not coalesced. -func NewMultiNotifyWatcher(w ...state.NotifyWatcher) *MultiNotifyWatcher { - m := &MultiNotifyWatcher{ - watchers: w, - changes: make(chan struct{}), - } - var wg sync.WaitGroup - wg.Add(len(w)) - staging := make(chan struct{}) - for _, w := range w { - // Consume the first event of each watcher. - <-w.Changes() - go func(wCopy state.NotifyWatcher) { - defer wg.Done() - _ = wCopy.Wait() - }(w) - // Copy events from the watcher to the staging channel. - go copyEvents(staging, w.Changes(), &m.tomb) - } - m.tomb.Go(func() error { - m.loop(staging) - wg.Wait() - return nil - }) - return m -} - -// loop copies events from the input channel to the output channel, -// coalescing events by waiting a short time between receiving and -// sending. -func (w *MultiNotifyWatcher) loop(in <-chan struct{}) { - defer close(w.changes) - // out is initialised to m.changes to send the initial event. - out := w.changes - var timer <-chan time.Time - for { - select { - case <-w.tomb.Dying(): - return - case <-in: - if timer == nil { - // TODO(fwereade): 2016-03-17 lp:1558657 - timer = time.After(10 * time.Millisecond) - } - case <-timer: - timer = nil - out = w.changes - case out <- struct{}{}: - out = nil - } - } -} - -// copyEvents copies channel events from "in" to "out", coalescing. -func copyEvents(out chan<- struct{}, in <-chan struct{}, tomb *tomb.Tomb) { - var outC chan<- struct{} - for { - select { - case <-tomb.Dying(): - return - case _, ok := <-in: - if !ok { - return - } - outC = out - case outC <- struct{}{}: - outC = nil - } - } -} - -func (w *MultiNotifyWatcher) Kill() { - w.tomb.Kill(nil) - for _, w := range w.watchers { - w.Kill() - } -} - -func (w *MultiNotifyWatcher) Wait() error { - return w.tomb.Wait() -} - -func (w *MultiNotifyWatcher) Stop() error { - w.Kill() - return w.Wait() -} - -func (w *MultiNotifyWatcher) Err() error { - return w.tomb.Err() -} - -func (w *MultiNotifyWatcher) Changes() <-chan struct{} { - return w.changes -} diff --git a/apiserver/common/watch_test.go b/apiserver/common/watch_test.go index 4fa94db8e72..2ba2133a89b 100644 --- a/apiserver/common/watch_test.go +++ b/apiserver/common/watch_test.go @@ -9,14 +9,12 @@ import ( "github.com/juju/names/v5" jc "github.com/juju/testing/checkers" - "github.com/juju/worker/v4/workertest" gc "gopkg.in/check.v1" "github.com/juju/juju/apiserver/common" apiservertesting "github.com/juju/juju/apiserver/testing" "github.com/juju/juju/rpc/params" "github.com/juju/juju/state" - statetesting "github.com/juju/juju/state/testing" ) type agentEntityWatcherSuite struct{} @@ -92,38 +90,3 @@ func (*agentEntityWatcherSuite) TestWatchNoArgsNoError(c *gc.C) { c.Assert(err, jc.ErrorIsNil) c.Assert(result.Results, gc.HasLen, 0) } - -type multiNotifyWatcherSuite struct{} - -var _ = gc.Suite(&multiNotifyWatcherSuite{}) - -func (*multiNotifyWatcherSuite) TestMultiNotifyWatcher(c *gc.C) { - w0 := apiservertesting.NewFakeNotifyWatcher() - w1 := apiservertesting.NewFakeNotifyWatcher() - - mw := common.NewMultiNotifyWatcher(w0, w1) - defer workertest.CleanKill(c, mw) - - wc := statetesting.NewNotifyWatcherC(c, mw) - wc.AssertOneChange() - - w0.C <- struct{}{} - wc.AssertOneChange() - w1.C <- struct{}{} - wc.AssertOneChange() - - w0.C <- struct{}{} - w1.C <- struct{}{} - wc.AssertOneChange() -} - -func (*multiNotifyWatcherSuite) TestMultiNotifyWatcherStop(c *gc.C) { - w0 := apiservertesting.NewFakeNotifyWatcher() - w1 := apiservertesting.NewFakeNotifyWatcher() - - mw := common.NewMultiNotifyWatcher(w0, w1) - wc := statetesting.NewNotifyWatcherC(c, mw) - wc.AssertOneChange() - statetesting.AssertCanStopWhenSending(c, mw) - wc.AssertClosed() -} diff --git a/apiserver/facades/agent/proxyupdater/proxyupdater.go b/apiserver/facades/agent/proxyupdater/proxyupdater.go index 39df351fbed..94be1fbfdc5 100644 --- a/apiserver/facades/agent/proxyupdater/proxyupdater.go +++ b/apiserver/facades/agent/proxyupdater/proxyupdater.go @@ -10,15 +10,15 @@ import ( "github.com/juju/names/v5" "github.com/juju/proxy" - "github.com/juju/juju/apiserver/common" apiservererrors "github.com/juju/juju/apiserver/errors" "github.com/juju/juju/apiserver/facade" + "github.com/juju/juju/apiserver/internal" "github.com/juju/juju/controller" "github.com/juju/juju/core/network" + "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/environs/config" "github.com/juju/juju/rpc/params" "github.com/juju/juju/state" - "github.com/juju/juju/state/watcher" ) // ProxyUpdaterV2 defines the public methods for the v2 facade. @@ -88,18 +88,24 @@ func NewAPIV2(controller ControllerBackend, backend Backend, resources facade.Re func (api *API) oneWatch(ctx context.Context) params.NotifyWatchResult { var result params.NotifyWatchResult - watch := common.NewMultiNotifyWatcher( + watch, err := eventsource.NewMultiNotifyWatcher(ctx, api.backend.WatchForModelConfigChanges(), - api.controller.WatchAPIHostPortsForAgents()) + api.controller.WatchAPIHostPortsForAgents(), + ) + if err != nil { + result.Error = apiservererrors.ServerError(err) + return result + } - if _, ok := <-watch.Changes(); ok { - result = params.NotifyWatchResult{ - NotifyWatcherId: api.resources.Register(watch), - } - } else { - result.Error = apiservererrors.ServerError(watcher.EnsureErr(watch)) + _, err = internal.FirstResult[struct{}](ctx, watch) + if err != nil { + result.Error = apiservererrors.ServerError(err) + return result + } + + return params.NotifyWatchResult{ + NotifyWatcherId: api.resources.Register(watch), } - return result } // WatchForProxyConfigAndAPIHostPortChanges watches for changes to the proxy and api host port settings. diff --git a/apiserver/facades/agent/uniter/storage.go b/apiserver/facades/agent/uniter/storage.go index a9327447e1b..7ecbc27218c 100644 --- a/apiserver/facades/agent/uniter/storage.go +++ b/apiserver/facades/agent/uniter/storage.go @@ -13,8 +13,10 @@ import ( "github.com/juju/juju/apiserver/common/storagecommon" apiservererrors "github.com/juju/juju/apiserver/errors" "github.com/juju/juju/apiserver/facade" + "github.com/juju/juju/apiserver/internal" "github.com/juju/juju/core/life" corewatcher "github.com/juju/juju/core/watcher" + "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/rpc/params" "github.com/juju/juju/state" "github.com/juju/juju/state/watcher" @@ -307,12 +309,14 @@ func (s *StorageAPI) watchOneStorageAttachment(ctx context.Context, id params.St if err != nil { return nothing, errors.Trace(err) } - if _, ok := <-watch.Changes(); ok { - return params.NotifyWatchResult{ - NotifyWatcherId: s.resources.Register(watch), - }, nil + + if _, err := internal.FirstResult[struct{}](ctx, watch); err != nil { + return nothing, errors.Trace(err) } - return nothing, watcher.EnsureErr(watch) + + return params.NotifyWatchResult{ + NotifyWatcherId: s.resources.Register(watch), + }, nil } // RemoveStorageAttachments removes the specified storage @@ -407,12 +411,12 @@ func watchStorageAttachment( storageTag names.StorageTag, hostTag names.Tag, unitTag names.UnitTag, -) (state.NotifyWatcher, error) { +) (eventsource.Watcher[struct{}], error) { storageInstance, err := st.StorageInstance(storageTag) if err != nil { return nil, errors.Annotate(err, "getting storage instance") } - var watchers []state.NotifyWatcher + var watchers []eventsource.Watcher[struct{}] switch storageInstance.Kind() { case state.StorageKindBlock: if stVolume == nil { @@ -425,7 +429,7 @@ func watchStorageAttachment( // We need to watch both the volume attachment, and the // machine's block devices. A volume attachment's block // device could change (most likely, become present). - watchers = []state.NotifyWatcher{ + watchers = []eventsource.Watcher[struct{}]{ stVolume.WatchVolumeAttachment(hostTag, volume.VolumeTag()), } @@ -452,14 +456,14 @@ func watchStorageAttachment( if err != nil { return nil, errors.Annotate(err, "getting storage filesystem") } - watchers = []state.NotifyWatcher{ + watchers = []eventsource.Watcher[struct{}]{ stFile.WatchFilesystemAttachment(hostTag, filesystem.FilesystemTag()), } default: return nil, errors.Errorf("invalid storage kind %v", storageInstance.Kind()) } watchers = append(watchers, st.WatchStorageAttachment(storageTag, unitTag)) - return common.NewMultiNotifyWatcher(watchers...), nil + return eventsource.NewMultiNotifyWatcher(ctx, watchers...) } // watcherAdaptor adapts a core watcher to a state watcher. diff --git a/apiserver/facades/agent/uniter/storage_test.go b/apiserver/facades/agent/uniter/storage_test.go index 085ad1244f4..9e0f5280d12 100644 --- a/apiserver/facades/agent/uniter/storage_test.go +++ b/apiserver/facades/agent/uniter/storage_test.go @@ -16,9 +16,9 @@ import ( "github.com/juju/juju/apiserver/facades/agent/uniter" apiservertesting "github.com/juju/juju/apiserver/testing" "github.com/juju/juju/core/watcher" + "github.com/juju/juju/core/watcher/watchertest" "github.com/juju/juju/rpc/params" "github.com/juju/juju/state" - statetesting "github.com/juju/juju/state/testing" "github.com/juju/juju/testing" ) @@ -578,7 +578,7 @@ func (s *watchStorageAttachmentSuite) testWatchStorageAttachment(c *gc.C, change s.unitTag, ) c.Assert(err, jc.ErrorIsNil) - wc := statetesting.NewNotifyWatcherC(c, w) + wc := watchertest.NewNotifyWatcherC(c, w) wc.AssertOneChange() change() wc.AssertOneChange() diff --git a/apiserver/facades/controller/caasapplicationprovisioner/provisioner.go b/apiserver/facades/controller/caasapplicationprovisioner/provisioner.go index de57364bcd1..dcaee5ef121 100644 --- a/apiserver/facades/controller/caasapplicationprovisioner/provisioner.go +++ b/apiserver/facades/controller/caasapplicationprovisioner/provisioner.go @@ -24,6 +24,7 @@ import ( "github.com/juju/juju/apiserver/common/storagecommon" apiservererrors "github.com/juju/juju/apiserver/errors" "github.com/juju/juju/apiserver/facade" + "github.com/juju/juju/apiserver/internal" "github.com/juju/juju/caas" k8sconstants "github.com/juju/juju/caas/kubernetes/provider/constants" "github.com/juju/juju/controller" @@ -32,6 +33,7 @@ import ( "github.com/juju/juju/core/objectstore" "github.com/juju/juju/core/resources" "github.com/juju/juju/core/status" + "github.com/juju/juju/core/watcher/eventsource" storageerrors "github.com/juju/juju/domain/storage/errors" "github.com/juju/juju/environs/bootstrap" "github.com/juju/juju/environs/config" @@ -224,7 +226,7 @@ func (a *API) WatchProvisioningInfo(ctx context.Context, args params.Entities) ( continue } - res, err := a.watchProvisioningInfo(appName) + res, err := a.watchProvisioningInfo(ctx, appName) if err != nil { result.Results[i].Error = apiservererrors.ServerError(err) continue @@ -235,7 +237,7 @@ func (a *API) WatchProvisioningInfo(ctx context.Context, args params.Entities) ( return result, nil } -func (a *API) watchProvisioningInfo(appName names.ApplicationTag) (params.NotifyWatchResult, error) { +func (a *API) watchProvisioningInfo(ctx context.Context, appName names.ApplicationTag) (params.NotifyWatchResult, error) { result := params.NotifyWatchResult{} app, err := a.state.Application(appName.Id()) if err != nil { @@ -252,14 +254,22 @@ func (a *API) watchProvisioningInfo(appName names.ApplicationTag) (params.Notify controllerAPIHostPortsWatcher := a.ctrlSt.WatchAPIHostPortsForAgents() modelConfigWatcher := model.WatchForModelConfigChanges() - multiWatcher := common.NewMultiNotifyWatcher(appWatcher, controllerConfigWatcher, controllerAPIHostPortsWatcher, modelConfigWatcher) + multiWatcher, err := eventsource.NewMultiNotifyWatcher(ctx, + appWatcher, + controllerConfigWatcher, + controllerAPIHostPortsWatcher, + modelConfigWatcher, + ) + if err != nil { + return result, errors.Trace(err) + } - if _, ok := <-multiWatcher.Changes(); ok { - result.NotifyWatcherId = a.resources.Register(multiWatcher) - } else { - return result, watcher.EnsureErr(multiWatcher) + // Consume the initial event and forward it to the result. + if _, err := internal.FirstResult[struct{}](ctx, multiWatcher); err != nil { + return result, errors.Trace(err) } + result.NotifyWatcherId = a.resources.Register(multiWatcher) return result, nil } diff --git a/apiserver/facades/controller/caasapplicationprovisioner/provisioner_test.go b/apiserver/facades/controller/caasapplicationprovisioner/provisioner_test.go index 8ad17c2a894..26190c0eb8b 100644 --- a/apiserver/facades/controller/caasapplicationprovisioner/provisioner_test.go +++ b/apiserver/facades/controller/caasapplicationprovisioner/provisioner_test.go @@ -23,6 +23,7 @@ import ( "github.com/juju/juju/core/config" jujuresource "github.com/juju/juju/core/resources" "github.com/juju/juju/core/status" + "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/internal/docker" "github.com/juju/juju/rpc/params" "github.com/juju/juju/state" @@ -132,7 +133,7 @@ func (s *CAASApplicationProvisionerSuite) TestProvisioningInfoPendingCharmError( url: "ch:gitlab", }, } - result, err := s.api.ProvisioningInfo(context.Background(), params.Entities{Entities: []params.Entity{{"application-gitlab"}}}) + result, err := s.api.ProvisioningInfo(context.Background(), params.Entities{Entities: []params.Entity{{Tag: "application-gitlab"}}}) c.Assert(err, jc.ErrorIsNil) c.Assert(result.Results[0].Error, gc.ErrorMatches, `charm "ch:gitlab" pending not provisioned`) } @@ -168,7 +169,7 @@ func (s *CAASApplicationProvisionerSuite) TestWatchProvisioningInfo(c *gc.C) { c.Assert(results.Results, gc.HasLen, 1) c.Assert(results.Results[0].Error, gc.IsNil) res := s.resources.Get("1") - c.Assert(res, gc.FitsTypeOf, (*common.MultiNotifyWatcher)(nil)) + c.Assert(res, gc.FitsTypeOf, (*eventsource.MultiWatcher[struct{}])(nil)) } func (s *CAASApplicationProvisionerSuite) TestSetOperatorStatus(c *gc.C) { diff --git a/apiserver/facades/controller/caasmodeloperator/operator.go b/apiserver/facades/controller/caasmodeloperator/operator.go index d39a9f11da9..897db19b05c 100644 --- a/apiserver/facades/controller/caasmodeloperator/operator.go +++ b/apiserver/facades/controller/caasmodeloperator/operator.go @@ -14,11 +14,12 @@ import ( "github.com/juju/juju/apiserver/common" apiservererrors "github.com/juju/juju/apiserver/errors" "github.com/juju/juju/apiserver/facade" + "github.com/juju/juju/apiserver/internal" "github.com/juju/juju/controller" + "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/internal/cloudconfig/podcfg" "github.com/juju/juju/internal/docker" "github.com/juju/juju/rpc/params" - "github.com/juju/juju/state/watcher" ) // TODO (manadart 2020-10-21): Remove the ModelUUID method @@ -63,7 +64,7 @@ func NewAPI( // WatchModelOperatorProvisioningInfo provides a watcher for changes that affect the // information returned by ModelOperatorProvisioningInfo. -func (a *API) WatchModelOperatorProvisioningInfo() (params.NotifyWatchResult, error) { +func (a *API) WatchModelOperatorProvisioningInfo(ctx context.Context) (params.NotifyWatchResult, error) { result := params.NotifyWatchResult{} model, err := a.state.Model() @@ -75,14 +76,21 @@ func (a *API) WatchModelOperatorProvisioningInfo() (params.NotifyWatchResult, er controllerAPIHostPortsWatcher := a.ctrlState.WatchAPIHostPortsForAgents() modelConfigWatcher := model.WatchForModelConfigChanges() - multiWatcher := common.NewMultiNotifyWatcher(controllerConfigWatcher, controllerAPIHostPortsWatcher, modelConfigWatcher) + multiWatcher, err := eventsource.NewMultiNotifyWatcher(ctx, + controllerConfigWatcher, + controllerAPIHostPortsWatcher, + modelConfigWatcher, + ) - if _, ok := <-multiWatcher.Changes(); ok { - result.NotifyWatcherId = a.resources.Register(multiWatcher) - } else { - return result, watcher.EnsureErr(multiWatcher) + if err != nil { + return result, errors.Trace(err) + } + + if _, err := internal.FirstResult[struct{}](ctx, multiWatcher); err != nil { + return result, errors.Trace(err) } + result.NotifyWatcherId = a.resources.Register(multiWatcher) return result, nil } diff --git a/apiserver/facades/controller/caasmodeloperator/operator_test.go b/apiserver/facades/controller/caasmodeloperator/operator_test.go index d2d8f540f9e..36a8dc78f46 100644 --- a/apiserver/facades/controller/caasmodeloperator/operator_test.go +++ b/apiserver/facades/controller/caasmodeloperator/operator_test.go @@ -14,6 +14,7 @@ import ( "github.com/juju/juju/apiserver/common" "github.com/juju/juju/apiserver/facades/controller/caasmodeloperator" apiservertesting "github.com/juju/juju/apiserver/testing" + "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/internal/cloudconfig/podcfg" statetesting "github.com/juju/juju/state/testing" coretesting "github.com/juju/juju/testing" @@ -94,9 +95,9 @@ func (m *ModelOperatorSuite) TestWatchProvisioningInfo(c *gc.C) { apiHostPortsForAgentsChanged <- struct{}{} modelConfigChanged <- struct{}{} - results, err := m.api.WatchModelOperatorProvisioningInfo() + results, err := m.api.WatchModelOperatorProvisioningInfo(context.Background()) c.Assert(err, jc.ErrorIsNil) c.Assert(results.Error, gc.IsNil) res := m.resources.Get("1") - c.Assert(res, gc.FitsTypeOf, (*common.MultiNotifyWatcher)(nil)) + c.Assert(res, gc.FitsTypeOf, (*eventsource.MultiWatcher[struct{}])(nil)) } diff --git a/apiserver/facades/controller/crossmodelrelations/crossmodelrelations.go b/apiserver/facades/controller/crossmodelrelations/crossmodelrelations.go index 01c23a7e6da..5bd23d74b4e 100644 --- a/apiserver/facades/controller/crossmodelrelations/crossmodelrelations.go +++ b/apiserver/facades/controller/crossmodelrelations/crossmodelrelations.go @@ -22,10 +22,12 @@ import ( "github.com/juju/juju/apiserver/common/firewall" apiservererrors "github.com/juju/juju/apiserver/errors" "github.com/juju/juju/apiserver/facade" + "github.com/juju/juju/apiserver/internal" "github.com/juju/juju/core/life" coremacaroon "github.com/juju/juju/core/macaroon" "github.com/juju/juju/core/secrets" corewatcher "github.com/juju/juju/core/watcher" + "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/rpc/params" "github.com/juju/juju/state" "github.com/juju/juju/state/watcher" @@ -33,7 +35,7 @@ import ( type egressAddressWatcherFunc func(facade.Resources, firewall.State, params.Entities) (params.StringsWatchResults, error) type relationStatusWatcherFunc func(CrossModelRelationsState, names.RelationTag) (state.StringsWatcher, error) -type offerStatusWatcherFunc func(CrossModelRelationsState, string) (OfferWatcher, error) +type offerStatusWatcherFunc func(context.Context, CrossModelRelationsState, string) (OfferWatcher, error) type consumedSecretsWatcherFunc func(CrossModelRelationsState, string) (state.StringsWatcher, error) // CrossModelRelationsAPIv3 provides access to the CrossModelRelations API facade. @@ -466,7 +468,7 @@ type OfferWatcher interface { } type offerWatcher struct { - *common.MultiNotifyWatcher + *eventsource.MultiWatcher[struct{}] offerUUID string offerName string } @@ -479,7 +481,7 @@ func (w *offerWatcher) OfferName() string { return w.offerName } -func watchOfferStatus(st CrossModelRelationsState, offerUUID string) (OfferWatcher, error) { +func watchOfferStatus(ctx context.Context, st CrossModelRelationsState, offerUUID string) (OfferWatcher, error) { w1, err := st.WatchOfferStatus(offerUUID) if err != nil { return nil, errors.Trace(err) @@ -489,8 +491,11 @@ func watchOfferStatus(st CrossModelRelationsState, offerUUID string) (OfferWatch return nil, errors.Trace(err) } w2 := st.WatchOffer(offer.OfferName) - mw := common.NewMultiNotifyWatcher(w1, w2) - return &offerWatcher{mw, offerUUID, offer.OfferName}, nil + mw, err := eventsource.NewMultiNotifyWatcher(ctx, w1, w2) + if err != nil { + return nil, errors.Trace(err) + } + return &offerWatcher{MultiWatcher: mw, offerUUID: offerUUID, offerName: offer.OfferName}, nil } func watchConsumedSecrets(st CrossModelRelationsState, appName string) (state.StringsWatcher, error) { @@ -516,16 +521,16 @@ func (api *CrossModelRelationsAPIv3) WatchOfferStatus( continue } - w, err := api.offerStatusWatcher(api.st, arg.OfferUUID) + w, err := api.offerStatusWatcher(ctx, api.st, arg.OfferUUID) if err != nil { results.Results[i].Error = apiservererrors.ServerError(err) continue } - _, ok := <-w.Changes() - if !ok { - results.Results[i].Error = apiservererrors.ServerError(watcher.EnsureErr(w)) + if _, err := internal.FirstResult[struct{}](ctx, w); err != nil { + results.Results[i].Error = apiservererrors.ServerError(err) continue } + change, err := commoncrossmodel.GetOfferStatusChange(api.st, arg.OfferUUID, w.OfferName()) if err != nil { results.Results[i].Error = apiservererrors.ServerError(err) diff --git a/apiserver/facades/controller/crossmodelrelations/crossmodelrelations_test.go b/apiserver/facades/controller/crossmodelrelations/crossmodelrelations_test.go index 3d9c5b27a81..fcf0f84b331 100644 --- a/apiserver/facades/controller/crossmodelrelations/crossmodelrelations_test.go +++ b/apiserver/facades/controller/crossmodelrelations/crossmodelrelations_test.go @@ -89,7 +89,7 @@ func (s *crossmodelRelationsSuite) SetUpTest(c *gc.C) { w.changes <- []string{"db2:db django:db"} return w, nil } - offerStatusWatcher := func(st crossmodelrelations.CrossModelRelationsState, offerUUID string) (crossmodelrelations.OfferWatcher, error) { + offerStatusWatcher := func(_ context.Context, st crossmodelrelations.CrossModelRelationsState, offerUUID string) (crossmodelrelations.OfferWatcher, error) { c.Assert(s.st, gc.Equals, st) s.watchedOffers = []string{offerUUID} w := &mockOfferStatusWatcher{ diff --git a/caas/application.go b/caas/application.go index 0b5ef69903c..4fc5428065a 100644 --- a/caas/application.go +++ b/caas/application.go @@ -16,14 +16,12 @@ import ( "github.com/juju/juju/internal/storage" ) -//go:generate go run go.uber.org/mock/mockgen -package mocks -destination mocks/application_mock.go github.com/juju/juju/caas Application - // Application is for interacting with the CAAS substrate. type Application interface { Ensure(config ApplicationConfig) error Exists() (DeploymentState, error) Delete() error - Watch() (watcher.NotifyWatcher, error) + Watch(context.Context) (watcher.NotifyWatcher, error) WatchReplicas() (watcher.NotifyWatcher, error) // ApplicationPodSpec returns the pod spec needed to run the application workload. diff --git a/caas/broker.go b/caas/broker.go index 78a11f92907..c5aae639d17 100644 --- a/caas/broker.go +++ b/caas/broker.go @@ -23,8 +23,6 @@ import ( "github.com/juju/juju/internal/storage" ) -//go:generate go run go.uber.org/mock/mockgen -package mocks -destination mocks/broker_mock.go github.com/juju/juju/caas Broker - // ContainerEnvironProvider represents a computing and storage provider // for a container runtime. type ContainerEnvironProvider interface { diff --git a/caas/kubernetes/provider/application/application.go b/caas/kubernetes/provider/application/application.go index 8b903be0f7a..d23b8266c31 100644 --- a/caas/kubernetes/provider/application/application.go +++ b/caas/kubernetes/provider/application/application.go @@ -46,6 +46,7 @@ import ( "github.com/juju/juju/core/paths" "github.com/juju/juju/core/status" "github.com/juju/juju/core/watcher" + "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/internal/cloudconfig/podcfg" jujustorage "github.com/juju/juju/internal/storage" "github.com/juju/juju/juju/osenv" @@ -1028,7 +1029,7 @@ func (a *app) Delete() error { // Watch returns a watcher which notifies when there // are changes to the application of the specified application. -func (a *app) Watch() (watcher.NotifyWatcher, error) { +func (a *app) Watch(ctx context.Context) (watcher.NotifyWatcher, error) { factory := informers.NewSharedInformerFactoryWithOptions(a.client, 0, informers.WithNamespace(a.namespace), informers.WithTweakListOptions(func(o *metav1.ListOptions) { @@ -1054,7 +1055,7 @@ func (a *app) Watch() (watcher.NotifyWatcher, error) { if err != nil { return nil, errors.Trace(err) } - return watcher.NewMultiNotifyWatcher(w1, w2), nil + return eventsource.NewMultiNotifyWatcher(ctx, w1, w2) } func (a *app) WatchReplicas() (watcher.NotifyWatcher, error) { diff --git a/caas/kubernetes/provider/application/application_test.go b/caas/kubernetes/provider/application/application_test.go index a18dbd5f00a..a052ee54dc5 100644 --- a/caas/kubernetes/provider/application/application_test.go +++ b/caas/kubernetes/provider/application/application_test.go @@ -1350,7 +1350,7 @@ func (s *applicationSuite) TestWatchNotsupported(c *gc.C) { return w, nil } - _, err := app.Watch() + _, err := app.Watch(context.Background()) c.Assert(err, gc.ErrorMatches, `unknown deployment type not supported`) } @@ -1363,7 +1363,7 @@ func (s *applicationSuite) TestWatch(c *gc.C) { return w, nil } - w, err := app.Watch() + w, err := app.Watch(context.Background()) c.Assert(err, jc.ErrorIsNil) select { diff --git a/caas/mocks/application_mock.go b/caas/mocks/application_mock.go index 71d3c144462..31e9b052afe 100644 --- a/caas/mocks/application_mock.go +++ b/caas/mocks/application_mock.go @@ -217,18 +217,18 @@ func (mr *MockApplicationMockRecorder) UpdateService(arg0 any) *gomock.Call { } // Watch mocks base method. -func (m *MockApplication) Watch() (watcher.Watcher[struct{}], error) { +func (m *MockApplication) Watch(arg0 context.Context) (watcher.Watcher[struct{}], error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Watch") + ret := m.ctrl.Call(m, "Watch", arg0) ret0, _ := ret[0].(watcher.Watcher[struct{}]) ret1, _ := ret[1].(error) return ret0, ret1 } // Watch indicates an expected call of Watch. -func (mr *MockApplicationMockRecorder) Watch() *gomock.Call { +func (mr *MockApplicationMockRecorder) Watch(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockApplication)(nil).Watch)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockApplication)(nil).Watch), arg0) } // WatchReplicas mocks base method. diff --git a/caas/package_test.go b/caas/package_test.go index 11bb4bf25e3..63c26d6e288 100644 --- a/caas/package_test.go +++ b/caas/package_test.go @@ -9,6 +9,9 @@ import ( gc "gopkg.in/check.v1" ) +//go:generate go run go.uber.org/mock/mockgen -package mocks -destination mocks/broker_mock.go github.com/juju/juju/caas Broker +//go:generate go run go.uber.org/mock/mockgen -package mocks -destination mocks/application_mock.go github.com/juju/juju/caas Application + func TestAll(t *testing.T) { gc.TestingT(t) } diff --git a/core/watcher/eventsource/multiwatcher.go b/core/watcher/eventsource/multiwatcher.go new file mode 100644 index 00000000000..f621f81fb38 --- /dev/null +++ b/core/watcher/eventsource/multiwatcher.go @@ -0,0 +1,132 @@ +// Copyright 2024 Canonical Ltd. +// Licensed under the AGPLv3, see LICENCE file for details. + +package eventsource + +import ( + "context" + + "github.com/juju/errors" + "github.com/juju/worker/v4" + "github.com/juju/worker/v4/catacomb" +) + +// Applier is a function that applies a change to a value. +type Applier[T any] func(T, T) T + +// MultiWatcher implements Watcher, combining multiple Watchers. +type MultiWatcher[T any] struct { + catacomb catacomb.Catacomb + staging, changes chan T + applier Applier[T] +} + +// NewMultiNotifyWatcher creates a NotifyWatcher that combines +// each of the NotifyWatchers passed in. Each watcher's initial +// event is consumed, and a single initial event is sent. +func NewMultiNotifyWatcher(ctx context.Context, watchers ...Watcher[struct{}]) (*MultiWatcher[struct{}], error) { + applier := func(_, _ struct{}) struct{} { + return struct{}{} + } + return NewMultiWatcher[struct{}](ctx, applier, watchers...) +} + +// NewMultiWatcher creates a NotifyWatcher that combines +// each of the NotifyWatchers passed in. Each watcher's initial +// event is consumed, and a single initial event is sent. +// Subsequent events are not coalesced. +func NewMultiWatcher[T any](ctx context.Context, applier Applier[T], watchers ...Watcher[T]) (*MultiWatcher[T], error) { + workers := make([]worker.Worker, len(watchers)) + for i, w := range watchers { + _, err := ConsumeInitialEvent[T](ctx, w) + if err != nil { + return nil, errors.Trace(err) + } + + workers[i] = w + } + + w := &MultiWatcher[T]{ + staging: make(chan T), + changes: make(chan T), + applier: applier, + } + + if err := catacomb.Invoke(catacomb.Plan{ + Site: &w.catacomb, + Work: w.loop, + Init: workers, + }); err != nil { + return nil, errors.Trace(err) + } + + for _, watcher := range watchers { + // Copy events from the watcher to the staging channel. + go w.copyEvents(watcher.Changes()) + } + + return w, nil +} + +// loop copies events from the input channel to the output channel, +// coalescing events by waiting a short time between receiving and +// sending. +func (w *MultiWatcher[T]) loop() error { + defer close(w.changes) + + out := w.changes + var payload T + for { + select { + case <-w.catacomb.Dying(): + return w.catacomb.ErrDying() + case payload = <-w.staging: + out = w.changes + case out <- payload: + out = nil + } + } +} + +// copyEvents copies channel events from "in" to "out", coalescing. +func (w *MultiWatcher[T]) copyEvents(in <-chan T) { + var ( + outC chan<- T + payload T + ) + for { + select { + case <-w.catacomb.Dying(): + return + case v, ok := <-in: + if !ok { + return + } + payload = w.applier(payload, v) + outC = w.staging + case outC <- payload: + outC = nil + } + } +} + +func (w *MultiWatcher[T]) Kill() { + w.catacomb.Kill(nil) +} + +func (w *MultiWatcher[T]) Wait() error { + return w.catacomb.Wait() +} + +func (w *MultiWatcher[T]) Stop() error { + w.Kill() + return w.Wait() +} + +func (w *MultiWatcher[T]) Err() error { + return w.catacomb.Err() +} + +func (w *MultiWatcher[T]) Changes() <-chan T { + return w.changes +} diff --git a/core/watcher/eventsource/mutliwatcher_test.go b/core/watcher/eventsource/mutliwatcher_test.go new file mode 100644 index 00000000000..8a474807c5a --- /dev/null +++ b/core/watcher/eventsource/mutliwatcher_test.go @@ -0,0 +1,81 @@ +// Copyright 2024 Canonical Ltd. +// Licensed under the AGPLv3, see LICENCE file for details. + +package eventsource + +import ( + "context" + + jc "github.com/juju/testing/checkers" + "github.com/juju/worker/v4/workertest" + gc "gopkg.in/check.v1" + + "github.com/juju/juju/core/watcher/watchertest" +) + +type multiNotifyWatcherSuite struct{} + +var _ = gc.Suite(&multiNotifyWatcherSuite{}) + +func (*multiNotifyWatcherSuite) TestMultiWatcher(c *gc.C) { + ch0 := make(chan struct{}, 1) + w0 := watchertest.NewMockNotifyWatcher(ch0) + defer workertest.DirtyKill(c, w0) + + ch1 := make(chan struct{}, 1) + w1 := watchertest.NewMockNotifyWatcher(ch1) + defer workertest.DirtyKill(c, w1) + + // Initial events are consumed by the multiwatcher. + ch0 <- struct{}{} + ch1 <- struct{}{} + + w, err := NewMultiNotifyWatcher(context.Background(), w0, w1) + c.Assert(err, jc.ErrorIsNil) + + wc := watchertest.NewNotifyWatcherC(c, w) + defer workertest.DirtyKill(c, w) + wc.AssertOneChange() + + ch0 <- struct{}{} + wc.AssertOneChange() + ch1 <- struct{}{} + wc.AssertOneChange() + + ch0 <- struct{}{} + ch1 <- struct{}{} + wc.AssertAtLeastOneChange() + + workertest.CleanKill(c, w) +} + +func (*multiNotifyWatcherSuite) TestMultiWatcherStop(c *gc.C) { + ch0 := make(chan struct{}, 1) + w0 := watchertest.NewMockNotifyWatcher(ch0) + defer workertest.DirtyKill(c, w0) + + ch1 := make(chan struct{}, 1) + w1 := watchertest.NewMockNotifyWatcher(ch1) + defer workertest.DirtyKill(c, w1) + + // Initial events are consumed by the multiwatcher. + ch0 <- struct{}{} + ch1 <- struct{}{} + + w, err := NewMultiNotifyWatcher(context.Background(), w0, w1) + c.Assert(err, jc.ErrorIsNil) + + wc := watchertest.NewNotifyWatcherC(c, w) + defer workertest.DirtyKill(c, w) + wc.AssertOneChange() + + workertest.CleanKill(c, w) + wc.AssertKilled() + + // Ensure that the underlying watchers are also stopped. + wc0 := watchertest.NewNotifyWatcherC(c, w0) + wc0.AssertKilled() + + wc1 := watchertest.NewNotifyWatcherC(c, w1) + wc1.AssertKilled() +} diff --git a/core/watcher/multinotify.go b/core/watcher/multinotify.go deleted file mode 100644 index afb4a0e2e77..00000000000 --- a/core/watcher/multinotify.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2019 Canonical Ltd. -// Licensed under the AGPLv3, see LICENCE file for details. - -package watcher - -import ( - "sync" - "time" - - "gopkg.in/tomb.v2" -) - -// MultiNotifyWatcher implements NotifyWatcher, combining -// multiple NotifyWatchers. -type MultiNotifyWatcher struct { - tomb tomb.Tomb - watchers []NotifyWatcher - changes chan struct{} -} - -// NewMultiNotifyWatcher creates a NotifyWatcher that combines -// each of the NotifyWatchers passed in. Each watcher's initial -// event is consumed, and a single initial event is sent. -// Subsequent events are not coalesced. -func NewMultiNotifyWatcher(w ...NotifyWatcher) *MultiNotifyWatcher { - m := &MultiNotifyWatcher{ - watchers: w, - changes: make(chan struct{}), - } - var wg sync.WaitGroup - wg.Add(len(w)) - staging := make(chan struct{}) - for _, w := range w { - // Consume the first event of each watcher. - <-w.Changes() - go func(wCopy NotifyWatcher) { - defer wg.Done() - _ = wCopy.Wait() - }(w) - // Copy events from the watcher to the staging channel. - go copyEvents(staging, w.Changes(), &m.tomb) - } - m.tomb.Go(func() error { - m.loop(staging) - wg.Wait() - return nil - }) - return m -} - -// loop copies events from the input channel to the output channel, -// coalescing events by waiting a short time between receiving and -// sending. -func (w *MultiNotifyWatcher) loop(in <-chan struct{}) { - defer close(w.changes) - // out is initialised to m.changes to send the initial event. - out := w.changes - var timer <-chan time.Time - for { - select { - case <-w.tomb.Dying(): - return - case <-in: - if timer == nil { - // TODO(fwereade): 2016-03-17 lp:1558657 - timer = time.After(10 * time.Millisecond) - } - case <-timer: - timer = nil - out = w.changes - case out <- struct{}{}: - out = nil - } - } -} - -// copyEvents copies channel events from "in" to "out", coalescing. -func copyEvents(out chan<- struct{}, in <-chan struct{}, tomb *tomb.Tomb) { - var outC chan<- struct{} - for { - select { - case <-tomb.Dying(): - return - case _, ok := <-in: - if !ok { - return - } - outC = out - case outC <- struct{}{}: - outC = nil - } - } -} - -func (w *MultiNotifyWatcher) Kill() { - w.tomb.Kill(nil) - for _, w := range w.watchers { - w.Kill() - } -} - -func (w *MultiNotifyWatcher) Wait() error { - return w.tomb.Wait() -} - -func (w *MultiNotifyWatcher) Stop() error { - w.Kill() - return w.Wait() -} - -func (w *MultiNotifyWatcher) Err() error { - return w.tomb.Err() -} - -func (w *MultiNotifyWatcher) Changes() <-chan struct{} { - return w.changes -} diff --git a/core/watcher/multinotify_test.go b/core/watcher/multinotify_test.go deleted file mode 100644 index c8062525c4c..00000000000 --- a/core/watcher/multinotify_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2019 Canonical Ltd. -// Licensed under the AGPLv3, see LICENCE file for details. - -package watcher_test - -import ( - gc "gopkg.in/check.v1" - - "github.com/juju/juju/core/watcher" - "github.com/juju/juju/core/watcher/watchertest" -) - -type multiNotifyWatcherSuite struct{} - -var _ = gc.Suite(&multiNotifyWatcherSuite{}) - -func (*multiNotifyWatcherSuite) TestMultiNotifyWatcher(c *gc.C) { - ch0 := make(chan struct{}, 1) - w0 := watchertest.NewMockNotifyWatcher(ch0) - ch1 := make(chan struct{}, 1) - w1 := watchertest.NewMockNotifyWatcher(ch1) - - // Initial events are consumed by the multiwatcher. - ch0 <- struct{}{} - ch1 <- struct{}{} - - w := watcher.NewMultiNotifyWatcher(w0, w1) - wc := watchertest.NewNotifyWatcherC(c, w) - defer wc.AssertKilled() - wc.AssertOneChange() - - ch0 <- struct{}{} - wc.AssertOneChange() - ch1 <- struct{}{} - wc.AssertOneChange() - - ch0 <- struct{}{} - ch1 <- struct{}{} - wc.AssertOneChange() -} diff --git a/core/watcher/watchertest/notify.go b/core/watcher/watchertest/notify.go index 2dabb9f6c77..ece12979161 100644 --- a/core/watcher/watchertest/notify.go +++ b/core/watcher/watchertest/notify.go @@ -80,6 +80,17 @@ func (c NotifyWatcherC) AssertOneChange() { c.AssertNoChange() } +// AssertAtLeastOneChange fails if no change is sent before a long time has +// passed. +func (c NotifyWatcherC) AssertAtLeastOneChange() { + select { + case _, ok := <-c.Watcher.Changes(): + c.Assert(ok, jc.IsTrue) + case <-time.After(testing.LongWait): + c.Fatalf("watcher did not send change") + } +} + // AssertChanges asserts that there was a series of changes for a given // duration. If there are any more changes after that period, then it // will fail. diff --git a/internal/worker/caasapplicationprovisioner/application.go b/internal/worker/caasapplicationprovisioner/application.go index 44bd49d7d88..2d16a486aec 100644 --- a/internal/worker/caasapplicationprovisioner/application.go +++ b/internal/worker/caasapplicationprovisioner/application.go @@ -4,6 +4,7 @@ package caasapplicationprovisioner import ( + "context" "fmt" "time" @@ -104,6 +105,9 @@ func (a *appWorker) Wait() error { } func (a *appWorker) loop() error { + ctx, cancel := a.scopedContext() + defer cancel() + // TODO(sidecar): support more than statefulset app := a.broker.Application(a.name, caas.DeploymentStateful) @@ -179,7 +183,7 @@ func (a *appWorker) loop() error { return errors.Annotatef(err, "failed to watch for application %q units changes", a.name) } - done := false + var done bool var ( initial = true @@ -246,7 +250,7 @@ func (a *appWorker) loop() error { } } if appChanges == nil { - appWatcher, err := app.Watch() + appWatcher, err := app.Watch(ctx) if err != nil { return errors.Annotatef(err, "failed to watch for changes to application %q", a.name) } @@ -424,3 +428,11 @@ func (a *appWorker) loop() error { } } } + +// scopedContext returns a context that is in the scope of the watcher lifetime. +// It returns a cancellable context that is cancelled when the action has +// completed. +func (a *appWorker) scopedContext() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return a.catacomb.Context(ctx), cancel +} diff --git a/internal/worker/caasapplicationprovisioner/application_test.go b/internal/worker/caasapplicationprovisioner/application_test.go index 80108c16fa6..7a4559cb833 100644 --- a/internal/worker/caasapplicationprovisioner/application_test.go +++ b/internal/worker/caasapplicationprovisioner/application_test.go @@ -171,7 +171,7 @@ func (s *ApplicationWorkerSuite) TestWorker(c *gc.C) { facade.EXPECT().ProvisioningState("test").Return(nil, nil), facade.EXPECT().WatchProvisioningInfo("test").Return(watchertest.NewMockNotifyWatcher(provisioningInfoChan), nil), ops.EXPECT().AppAlive("test", app, gomock.Any(), gomock.Any(), facade, clk, s.logger).Return(nil), - app.EXPECT().Watch().Return(watchertest.NewMockNotifyWatcher(appChan), nil), + app.EXPECT().Watch(gomock.Any()).Return(watchertest.NewMockNotifyWatcher(appChan), nil), app.EXPECT().WatchReplicas().DoAndReturn(func() (watcher.NotifyWatcher, error) { scaleChan <- struct{}{} return watchertest.NewMockNotifyWatcher(appReplicasChan), nil @@ -272,7 +272,7 @@ func (s *ApplicationWorkerSuite) TestWorkerStatusOnly(c *gc.C) { facade.EXPECT().ProvisioningState("test").Return(¶ms.CAASApplicationProvisioningState{Scaling: true, ScaleTarget: 1}, nil), facade.EXPECT().SetProvisioningState("test", params.CAASApplicationProvisioningState{}).Return(nil), facade.EXPECT().WatchProvisioningInfo("test").Return(watchertest.NewMockNotifyWatcher(provisioningInfoChan), nil), - app.EXPECT().Watch().Return(watchertest.NewMockNotifyWatcher(appChan), nil), + app.EXPECT().Watch(gomock.Any()).Return(watchertest.NewMockNotifyWatcher(appChan), nil), app.EXPECT().WatchReplicas().DoAndReturn(func() (watcher.NotifyWatcher, error) { appChan <- struct{}{} return watchertest.NewMockNotifyWatcher(appReplicasChan), nil diff --git a/internal/worker/caasmodeloperator/modeloperator.go b/internal/worker/caasmodeloperator/modeloperator.go index ef9491ce3b9..f9f622eaf36 100644 --- a/internal/worker/caasmodeloperator/modeloperator.go +++ b/internal/worker/caasmodeloperator/modeloperator.go @@ -21,7 +21,7 @@ import ( type ModelOperatorAPI interface { SetPassword(password string) error ModelOperatorProvisioningInfo() (caasmodeloperator.ModelOperatorProvisioningInfo, error) - WatchModelOperatorProvisioningInfo() (watcher.NotifyWatcher, error) + WatchModelOperatorProvisioningInfo(context.Context) (watcher.NotifyWatcher, error) } // ModelOperatorBroker describes the caas broker interface needed for installing @@ -60,7 +60,10 @@ func (m *ModelOperatorManager) Wait() error { } func (m *ModelOperatorManager) loop() error { - watcher, err := m.api.WatchModelOperatorProvisioningInfo() + ctx, cancel := m.scopedContext() + defer cancel() + + watcher, err := m.api.WatchModelOperatorProvisioningInfo(ctx) if err != nil { return errors.Annotate(err, "cannot watch model operator provisioning info") } @@ -82,6 +85,11 @@ func (m *ModelOperatorManager) loop() error { } } +func (m *ModelOperatorManager) scopedContext() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return m.catacomb.Context(ctx), cancel +} + func (m *ModelOperatorManager) update(ctx context.Context) error { m.logger.Debugf("gathering model operator provisioning information for model %s", m.modelUUID) info, err := m.api.ModelOperatorProvisioningInfo() diff --git a/internal/worker/caasmodeloperator/modeloperator_test.go b/internal/worker/caasmodeloperator/modeloperator_test.go index 22ebab02fc9..2d2b60edb3e 100644 --- a/internal/worker/caasmodeloperator/modeloperator_test.go +++ b/internal/worker/caasmodeloperator/modeloperator_test.go @@ -17,6 +17,7 @@ import ( "github.com/juju/juju/caas" "github.com/juju/juju/core/resources" "github.com/juju/juju/core/watcher" + "github.com/juju/juju/core/watcher/eventsource" "github.com/juju/juju/core/watcher/watchertest" "github.com/juju/juju/internal/worker/caasmodeloperator" ) @@ -65,9 +66,9 @@ func (a *dummyAPI) ModelOperatorProvisioningInfo() (modeloperatorapi.ModelOperat return a.provInfo() } -func (a *dummyAPI) WatchModelOperatorProvisioningInfo() (watcher.NotifyWatcher, error) { +func (a *dummyAPI) WatchModelOperatorProvisioningInfo(ctx context.Context) (watcher.NotifyWatcher, error) { if a.watchProvInfo == nil { - return watcher.NewMultiNotifyWatcher(), nil + return eventsource.NewMultiNotifyWatcher(ctx) } return a.watchProvInfo() }