diff --git a/internal/reader/home_chain.go b/internal/reader/home_chain.go index cae8a865e..86db229b3 100644 --- a/internal/reader/home_chain.go +++ b/internal/reader/home_chain.go @@ -2,6 +2,7 @@ package reader import ( "context" + "errors" "fmt" "sync" "time" @@ -247,11 +248,16 @@ func (r *homeChainPoller) GetOCRConfigs( } func (r *homeChainPoller) Close() error { - return r.sync.StopOnce(r.Name(), func() error { + err := r.sync.StopOnce(r.Name(), func() error { defer r.wg.Wait() close(r.stopCh) return nil }) + + if errors.Is(err, services.ErrAlreadyStopped) { + return nil + } + return err } func (r *homeChainPoller) Ready() error { diff --git a/pkg/reader/rmn_home.go b/pkg/reader/rmn_home.go index 7e697ed06..6fc52e249 100644 --- a/pkg/reader/rmn_home.go +++ b/pkg/reader/rmn_home.go @@ -3,6 +3,7 @@ package reader import ( "context" "crypto/ed25519" + "errors" "fmt" "math/big" "sync" @@ -223,11 +224,17 @@ func (r *rmnHomePoller) GetAllConfigDigests() ( } func (r *rmnHomePoller) Close() error { - return r.sync.StopOnce(r.Name(), func() error { + err := r.sync.StopOnce(r.Name(), func() error { defer r.wg.Wait() close(r.stopCh) return nil }) + + if errors.Is(err, services.ErrAlreadyStopped) { + return nil + } + + return err } func (r *rmnHomePoller) Ready() error { diff --git a/pkg/reader/rmn_home_test.go b/pkg/reader/rmn_home_test.go index 64093246f..cb3a32753 100644 --- a/pkg/reader/rmn_home_test.go +++ b/pkg/reader/rmn_home_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + tests "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -271,6 +272,44 @@ func Test_RMNHomePollingWorking(t *testing.T) { } } +func Test_RMNHomePoller_Close(t *testing.T) { + homeChainReader := readermock.NewMockContractReaderFacade(t) + poller := NewRMNHomePoller( + homeChainReader, + rmnHomeBoundContract, + logger.Test(t), + 10*time.Millisecond, + ).(*rmnHomePoller) + + homeChainReader.On("GetLatestValue", + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything, + ).Run(func(args mock.Arguments) { + result := args.Get(4).(*GetAllConfigsResponse) + *result = GetAllConfigsResponse{ + ActiveConfig: VersionedConfig{ + ConfigDigest: [32]byte{1}, + Version: 1, + StaticConfig: StaticConfig{Nodes: []Node{}}, + DynamicConfig: DynamicConfig{SourceChains: []SourceChain{}}, + }, + } + }).Return(nil) + + ctx := tests.Context(t) + require.NoError(t, poller.Start(ctx)) + + err1 := poller.Close() + require.NoError(t, err1) + err2 := poller.Close() + require.NoError(t, err2) + err3 := poller.Close() + require.NoError(t, err3) +} + func TestIsNodeObserver(t *testing.T) { tests := []struct { name string