From d43e3e557399da470874411264201a5d014c4aa3 Mon Sep 17 00:00:00 2001 From: dimkouv Date: Fri, 30 Aug 2024 10:46:19 +0300 Subject: [PATCH] test chainSupport dependency and apply small fixes --- commit/chain_support.go | 20 ++++-- commit/chain_support_test.go | 114 +++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 4 deletions(-) create mode 100644 commit/chain_support_test.go diff --git a/commit/chain_support.go b/commit/chain_support.go index cf9ced9d9..f56145f4e 100644 --- a/commit/chain_support.go +++ b/commit/chain_support.go @@ -2,6 +2,7 @@ package commit import ( "fmt" + "sort" mapset "github.com/deckarep/golang-set/v2" "github.com/smartcontractkit/libocr/commontypes" @@ -36,13 +37,18 @@ type CCIPChainSupport struct { } func (c CCIPChainSupport) KnownSourceChainsSlice() ([]cciptypes.ChainSelector, error) { - knownSourceChains, err := c.homeChain.GetKnownCCIPChains() + allChainsSet, err := c.homeChain.GetKnownCCIPChains() if err != nil { c.lggr.Errorw("error getting known chains", "err", err) return nil, fmt.Errorf("error getting known chains: %w", err) } - knownSourceChainsSlice := knownSourceChains.ToSlice() - return slicelib.Filter(knownSourceChainsSlice, func(ch cciptypes.ChainSelector) bool { return ch != c.destChain }), nil + + allChains := allChainsSet.ToSlice() + sort.Slice(allChains, func(i, j int) bool { return allChains[i] < allChains[j] }) + + sourceChains := slicelib.Filter(allChains, func(ch cciptypes.ChainSelector) bool { return ch != c.destChain }) + + return sourceChains, nil } // SupportedChains returns the set of chains that the given Oracle is configured to access @@ -66,7 +72,13 @@ func (c CCIPChainSupport) SupportsDestChain(oracle commontypes.OracleID) (bool, if err != nil { return false, fmt.Errorf("get chain config: %w", err) } - return destChainConfig.SupportedNodes.Contains(c.oracleIDToP2pID[oracle]), nil + + peerID, ok := c.oracleIDToP2pID[oracle] + if !ok { + return false, fmt.Errorf("oracle ID %d not found in oracleIDToP2pID", oracle) + } + + return destChainConfig.SupportedNodes.Contains(peerID), nil } // Interface compliance check diff --git a/commit/chain_support_test.go b/commit/chain_support_test.go new file mode 100644 index 000000000..982ade924 --- /dev/null +++ b/commit/chain_support_test.go @@ -0,0 +1,114 @@ +package commit + +import ( + "fmt" + "testing" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/ragep2p/types" + "github.com/stretchr/testify/require" + + reader2 "github.com/smartcontractkit/chainlink-ccip/internal/reader" + "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/reader" +) + +const ( + srcChainA = cciptypes.ChainSelector(0xa) + srcChainB = cciptypes.ChainSelector(0xb) + srcChainC = cciptypes.ChainSelector(0xc) + dstChain = cciptypes.ChainSelector(0xde) +) + +func TestCCIPChainSupport_KnownSourceChainsSlice(t *testing.T) { + lggr := logger.Test(t) + homeChainReader := reader.NewMockHomeChain(t) + cs := &CCIPChainSupport{ + lggr: lggr, + homeChain: homeChainReader, + destChain: dstChain, + } + + t.Run("happy path", func(t *testing.T) { + homeChainReader.EXPECT().GetKnownCCIPChains(). + Return(mapset.NewSet(srcChainA, srcChainB, srcChainC), nil).Once() + knownSourceChains, err := cs.KnownSourceChainsSlice() + require.NoError(t, err) + require.Equal(t, []cciptypes.ChainSelector{srcChainA, srcChainB, srcChainC}, knownSourceChains) + }) + + t.Run("error path", func(t *testing.T) { + homeChainReader.EXPECT().GetKnownCCIPChains().Return(nil, fmt.Errorf("some err")).Once() + _, err := cs.KnownSourceChainsSlice() + require.Error(t, err) + }) +} + +func TestCCIPChainSupport_SupportedChains(t *testing.T) { + lggr := logger.Test(t) + homeChainReader := reader.NewMockHomeChain(t) + cs := &CCIPChainSupport{ + lggr: lggr, + homeChain: homeChainReader, + oracleIDToP2pID: map[commontypes.OracleID]types.PeerID{1: [32]byte{1}}, + } + + t.Run("happy path", func(t *testing.T) { + exp := mapset.NewSet(srcChainA, srcChainB, srcChainC) + homeChainReader.EXPECT().GetSupportedChainsForPeer(types.PeerID{1}).Return(exp, nil).Once() + supportedChains, err := cs.SupportedChains(1) + require.NoError(t, err) + require.True(t, exp.Equal(supportedChains)) + }) + + t.Run("oracle not found", func(t *testing.T) { + _, err := cs.SupportedChains(2) + require.Error(t, err) + }) + + t.Run("home chain reader error", func(t *testing.T) { + homeChainReader.EXPECT().GetSupportedChainsForPeer(types.PeerID{1}). + Return(nil, fmt.Errorf("some err")).Once() + _, err := cs.SupportedChains(1) + require.Error(t, err) + }) +} + +func TestCCIPChainSupport_SupportsDestChain(t *testing.T) { + lggr := logger.Test(t) + homeChainReader := reader.NewMockHomeChain(t) + cs := &CCIPChainSupport{ + lggr: lggr, + homeChain: homeChainReader, + destChain: dstChain, + oracleIDToP2pID: map[commontypes.OracleID]types.PeerID{1: [32]byte{1}}, + } + + t.Run("happy path", func(t *testing.T) { + supportedNodes := mapset.NewSet(types.PeerID{1}) + homeChainReader.EXPECT().GetChainConfig(dstChain). + Return(reader2.ChainConfig{SupportedNodes: supportedNodes}, nil).Once() + supports, err := cs.SupportsDestChain(1) + require.NoError(t, err) + require.True(t, supports) + }) + + t.Run("oracle not found error", func(t *testing.T) { + supportedNodes := mapset.NewSet(types.PeerID{1}) + homeChainReader.EXPECT().GetChainConfig(dstChain). + Return(reader2.ChainConfig{SupportedNodes: supportedNodes}, nil).Once() + _, err := cs.SupportsDestChain(2) + require.Error(t, err) + }) + + t.Run("not supported", func(t *testing.T) { + supportedNodes := mapset.NewSet(types.PeerID{2}) + homeChainReader.EXPECT().GetChainConfig(dstChain). + Return(reader2.ChainConfig{SupportedNodes: supportedNodes}, nil).Once() + supports, err := cs.SupportsDestChain(1) + require.NoError(t, err) + require.False(t, supports) + }) +}