diff --git a/.mockery.yaml b/.mockery.yaml index 3c495fbb8..2eb60d839 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -17,3 +17,8 @@ packages: HomeChain: CCIP: TokenPrices: + github.com/smartcontractkit/chainlink-common/pkg/types: + interfaces: + ChainReader: + config: + dir: mocks/cl-common/chainreader diff --git a/Makefile b/Makefile index 003ed101f..5511796ee 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ generate: ensure_go_version mockery test: ensure_go_version - go test -race -fullpath -shuffle on -count $(TEST_COUNT) -coverprofile=$(COVERAGE_FILE) ./... + go test -race -fullpath -shuffle on -count $(TEST_COUNT) -coverprofile=$(COVERAGE_FILE) `go list ./... | grep -Ev 'chainlink-ccip/internal/mocks|chainlink-ccip/mocks'` lint: ensure_go_version golangci-lint run -c .golangci.yml diff --git a/execute/plugin_e2e_test.go b/execute/plugin_e2e_test.go index 138973bce..c2382fff6 100644 --- a/execute/plugin_e2e_test.go +++ b/execute/plugin_e2e_test.go @@ -26,6 +26,7 @@ import ( "github.com/smartcontractkit/chainlink-ccip/internal/mocks" "github.com/smartcontractkit/chainlink-ccip/internal/mocks/inmem" "github.com/smartcontractkit/chainlink-ccip/internal/reader" + chainreadermocks "github.com/smartcontractkit/chainlink-ccip/mocks/cl-common/chainreader" mock_types "github.com/smartcontractkit/chainlink-ccip/mocks/execute/exectypes" "github.com/smartcontractkit/chainlink-ccip/pkg/consts" "github.com/smartcontractkit/chainlink-ccip/pluginconfig" @@ -90,8 +91,11 @@ type nodeSetup struct { TokenDataReader *mock_types.MockTokenDataReader } -func setupHomeChainPoller(lggr logger.Logger, chainConfigInfos []reader.ChainConfigInfo) reader.HomeChain { - homeChainReader := mocks.NewContractReaderMock() +func setupHomeChainPoller( + t *testing.T, + lggr logger.Logger, + chainConfigInfos []reader.ChainConfigInfo) reader.HomeChain { + homeChainReader := chainreadermocks.NewMockChainReader(t) var firstCall = true homeChainReader.On( "GetLatestValue", @@ -231,7 +235,7 @@ func setupSimpleTest( }, } - homeChain := setupHomeChainPoller(lggr, chainConfigInfos) + homeChain := setupHomeChainPoller(t, lggr, chainConfigInfos) err = homeChain.Start(ctx) require.NoError(t, err, "failed to start home chain poller") diff --git a/execute/plugin_test.go b/execute/plugin_test.go index c92d2acbd..ad87611d1 100644 --- a/execute/plugin_test.go +++ b/execute/plugin_test.go @@ -24,7 +24,6 @@ import ( "github.com/smartcontractkit/chainlink-ccip/execute/exectypes" "github.com/smartcontractkit/chainlink-ccip/internal/libs/slicelib" "github.com/smartcontractkit/chainlink-ccip/internal/plugincommon" - "github.com/smartcontractkit/chainlink-ccip/internal/reader" codec_mocks "github.com/smartcontractkit/chainlink-ccip/mocks/execute/internal_/gen" reader_mock "github.com/smartcontractkit/chainlink-ccip/mocks/internal_/reader" "github.com/smartcontractkit/chainlink-ccip/pluginconfig" @@ -209,16 +208,16 @@ func TestPlugin_ValidateObservation_SupportedChainsError(t *testing.T) { func TestPlugin_ValidateObservation_IneligibleObserver(t *testing.T) { lggr := logger.Test(t) + mockHomeChain := reader_mock.NewMockHomeChain(t) + mockHomeChain.EXPECT().GetSupportedChainsForPeer(mock.Anything).Return(mapset.NewSet[cciptypes.ChainSelector](), nil) + defer mockHomeChain.AssertExpectations(t) + p := &Plugin{ - homeChain: setupHomeChainPoller(lggr, []reader.ChainConfigInfo{ - { - ChainSelector: 0, - ChainConfig: reader.HomeChainConfigMapper{}, - }, - }), + homeChain: mockHomeChain, oracleIDToP2pID: map[commontypes.OracleID]libocrtypes.PeerID{ 0: {}, }, + lggr: lggr, } observation := exectypes.NewObservation(nil, exectypes.MessageObservations{ @@ -242,13 +241,12 @@ func TestPlugin_ValidateObservation_IneligibleObserver(t *testing.T) { func TestPlugin_ValidateObservation_ValidateObservedSeqNum_Error(t *testing.T) { lggr := logger.Test(t) + mockHomeChain := reader_mock.NewMockHomeChain(t) + mockHomeChain.EXPECT().GetSupportedChainsForPeer(mock.Anything).Return(mapset.NewSet(cciptypes.ChainSelector(0)), nil) + p := &Plugin{ - homeChain: setupHomeChainPoller(lggr, []reader.ChainConfigInfo{ - { - ChainSelector: 1, - ChainConfig: reader.HomeChainConfigMapper{}, - }, - }), + lggr: lggr, + homeChain: mockHomeChain, oracleIDToP2pID: map[commontypes.OracleID]libocrtypes.PeerID{ 0: {}, }, @@ -283,8 +281,11 @@ func TestPlugin_Observation_BadPreviousOutcome(t *testing.T) { func TestPlugin_Observation_EligibilityCheckFailure(t *testing.T) { lggr := logger.Test(t) + + mockHomeChain := reader_mock.NewMockHomeChain(t) + p := &Plugin{ - homeChain: setupHomeChainPoller(lggr, []reader.ChainConfigInfo{}), + homeChain: mockHomeChain, oracleIDToP2pID: map[commontypes.OracleID]libocrtypes.PeerID{}, lggr: lggr, } @@ -464,8 +465,10 @@ func TestPlugin_ShouldAcceptAttestedReport_ShouldAccept(t *testing.T) { func TestPlugin_ShouldTransmitAcceptReport_ElegibilityCheckFailure(t *testing.T) { lggr := logger.Test(t) + p := &Plugin{ - homeChain: setupHomeChainPoller(lggr, []reader.ChainConfigInfo{}), + lggr: lggr, + homeChain: reader_mock.NewMockHomeChain(t), oracleIDToP2pID: map[commontypes.OracleID]libocrtypes.PeerID{}, } @@ -477,11 +480,16 @@ func TestPlugin_ShouldTransmitAcceptReport_ElegibilityCheckFailure(t *testing.T) func TestPlugin_ShouldTransmitAcceptReport_Ineligible(t *testing.T) { lggr, logs := logger.TestObserved(t, zapcore.DebugLevel) + + mockHomeChain := reader_mock.NewMockHomeChain(t) + mockHomeChain.EXPECT().GetSupportedChainsForPeer(mock.Anything).Return(mapset.NewSet[cciptypes.ChainSelector](), nil) + defer mockHomeChain.AssertExpectations(t) + p := &Plugin{ lggr: lggr, cfg: pluginconfig.ExecutePluginConfig{DestChain: 1}, reportingCfg: ocr3types.ReportingPluginConfig{OracleID: 2}, - homeChain: setupHomeChainPoller(lggr, []reader.ChainConfigInfo{}), + homeChain: mockHomeChain, oracleIDToP2pID: map[commontypes.OracleID]libocrtypes.PeerID{ 2: {}, }, diff --git a/internal/mocks/contract_reader.go b/internal/mocks/contract_reader.go deleted file mode 100644 index 4148b2017..000000000 --- a/internal/mocks/contract_reader.go +++ /dev/null @@ -1,73 +0,0 @@ -package mocks - -import ( - "context" - - "github.com/stretchr/testify/mock" - - "github.com/smartcontractkit/chainlink-common/pkg/types" - "github.com/smartcontractkit/chainlink-common/pkg/types/query" - "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" -) - -type ContractReaderMock struct { - *mock.Mock -} - -func NewContractReaderMock() *ContractReaderMock { - return &ContractReaderMock{ - Mock: &mock.Mock{}, - } -} - -func (cr *ContractReaderMock) GetLatestValue(ctx context.Context, contractName, method string, - confidenceLevel primitives.ConfidenceLevel, params, returnVal any) error { - args := cr.Called(ctx, contractName, method, confidenceLevel, params, returnVal) - return args.Error(0) -} - -func (cr *ContractReaderMock) BatchGetLatestValues(ctx context.Context, - request types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error) { - args := cr.Called(ctx, request) - return args.Get(0).(types.BatchGetLatestValuesResult), args.Error(1) -} - -func (cr *ContractReaderMock) Bind(ctx context.Context, bindings []types.BoundContract) error { - args := cr.Called(ctx, bindings) - return args.Error(0) -} - -func (cr *ContractReaderMock) QueryKey( - ctx context.Context, - contractName string, - filter query.KeyFilter, - limitAndSort query.LimitAndSort, - sequenceDataType any, -) ([]types.Sequence, error) { - args := cr.Called(ctx, contractName, filter, limitAndSort, sequenceDataType) - return args.Get(0).([]types.Sequence), args.Error(1) -} - -func (cr *ContractReaderMock) Start(ctx context.Context) error { - args := cr.Called(ctx) - return args.Error(0) -} - -func (cr *ContractReaderMock) Close() error { - args := cr.Called() - return args.Error(0) -} - -func (cr *ContractReaderMock) Ready() error { - args := cr.Called() - return args.Error(0) -} - -func (cr *ContractReaderMock) HealthReport() map[string]error { - args := cr.Called() - return args.Get(0).(map[string]error) -} - -func (cr *ContractReaderMock) Name() string { - return "ContractReaderMock" -} diff --git a/internal/mocks/inmem/ccipreader_inmem.go b/internal/mocks/inmem/ccipreader_inmem.go index 3f62e0d29..8a1e32fa3 100644 --- a/internal/mocks/inmem/ccipreader_inmem.go +++ b/internal/mocks/inmem/ccipreader_inmem.go @@ -28,6 +28,13 @@ type InMemoryCCIPReader struct { Dest cciptypes.ChainSelector } +// GetExpectedNextSequenceNumber implements reader.CCIP. +func (r InMemoryCCIPReader) GetExpectedNextSequenceNumber( + ctx context.Context, + sourceChainSelector, destChainSelector cciptypes.ChainSelector) (cciptypes.SeqNum, error) { + panic("unimplemented") +} + func (r InMemoryCCIPReader) CommitReportsGTETimestamp( _ context.Context, _ cciptypes.ChainSelector, ts time.Time, limit int, ) ([]plugintypes.CommitPluginReportWithMeta, error) { diff --git a/internal/reader/ccip.go b/internal/reader/ccip.go index b9cf242f2..ade05ccd0 100644 --- a/internal/reader/ccip.go +++ b/internal/reader/ccip.go @@ -53,6 +53,13 @@ type CCIP interface { seqNumRange cciptypes.SeqNumRange, ) ([]cciptypes.Message, error) + // GetExpectedNextSequenceNumber returns the next sequence number to be used + // in the onramp. + GetExpectedNextSequenceNumber( + ctx context.Context, + sourceChainSelector, destChainSelector cciptypes.ChainSelector, + ) (cciptypes.SeqNum, error) + // NextSeqNum reads the destination chain. // Returns the next expected sequence number for each one of the provided chains. // TODO: if destination was a parameter, this could be a capability reused across plugin instances. @@ -387,6 +394,41 @@ func (r *CCIPChainReader) MsgsBetweenSeqNums( return msgs, nil } +// GetExpectedNextSequenceNumber implements CCIP. +func (r *CCIPChainReader) GetExpectedNextSequenceNumber( + ctx context.Context, + sourceChainSelector, destChainSelector cciptypes.ChainSelector) (cciptypes.SeqNum, error) { + if destChainSelector != r.destChain { + return 0, fmt.Errorf("expected destination chain %d, got %d", r.destChain, destChainSelector) + } + + if err := r.validateReaderExistence(sourceChainSelector); err != nil { + return 0, err + } + + bindings := r.contractReaders[sourceChainSelector].GetBindings(consts.ContractNameOnRamp) + if len(bindings) != 1 { + return 0, fmt.Errorf("expected one binding for onRamp contract, got %d", len(bindings)) + } + + var expectedNextSequenceNumber uint64 + err := r.contractReaders[sourceChainSelector].GetLatestValue( + ctx, + consts.ContractNameOnRamp, + consts.MethodNameGetExpectedNextSequenceNumber, + primitives.Unconfirmed, + map[string]any{ + "destChainSelector": destChainSelector, + }, + &expectedNextSequenceNumber, + ) + if err != nil { + return 0, fmt.Errorf("failed to get expected next sequence number from onramp: %w", err) + } + + return cciptypes.SeqNum(expectedNextSequenceNumber), nil +} + func (r *CCIPChainReader) NextSeqNum( ctx context.Context, chains []cciptypes.ChainSelector, ) ([]cciptypes.SeqNum, error) { diff --git a/internal/reader/ccip_test.go b/internal/reader/ccip_test.go index b8098ccc5..bdd171684 100644 --- a/internal/reader/ccip_test.go +++ b/internal/reader/ccip_test.go @@ -12,17 +12,17 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" - "github.com/smartcontractkit/chainlink-ccip/internal/mocks" + chainreadermocks "github.com/smartcontractkit/chainlink-ccip/mocks/cl-common/chainreader" "github.com/smartcontractkit/chainlink-ccip/pkg/consts" ) func TestCCIPChainReader_getSourceChainsConfig(t *testing.T) { - sourceCRs := make(map[cciptypes.ChainSelector]*mocks.ContractReaderMock) + sourceCRs := make(map[cciptypes.ChainSelector]*chainreadermocks.MockChainReader) for _, chain := range []cciptypes.ChainSelector{chainA, chainB} { - sourceCRs[chain] = mocks.NewContractReaderMock() + sourceCRs[chain] = chainreadermocks.NewMockChainReader(t) } - destCR := mocks.NewContractReaderMock() + destCR := chainreadermocks.NewMockChainReader(t) destCR.On( "GetLatestValue", diff --git a/internal/reader/home_chain_test.go b/internal/reader/home_chain_test.go index 6c19b13b0..75660f621 100644 --- a/internal/reader/home_chain_test.go +++ b/internal/reader/home_chain_test.go @@ -10,7 +10,6 @@ import ( libocrtypes "github.com/smartcontractkit/libocr/ragep2p/types" "github.com/smartcontractkit/chainlink-ccip/chainconfig" - "github.com/smartcontractkit/chainlink-ccip/internal/mocks" "github.com/smartcontractkit/chainlink-ccip/pkg/consts" "github.com/smartcontractkit/libocr/commontypes" @@ -21,6 +20,8 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + + chainreadermocks "github.com/smartcontractkit/chainlink-ccip/mocks/cl-common/chainreader" ) var ( @@ -36,7 +37,7 @@ var ( ) func TestHomeChainConfigPoller_HealthReport(t *testing.T) { - homeChainReader := mocks.NewContractReaderMock() + homeChainReader := chainreadermocks.NewMockChainReader(t) homeChainReader.On( "GetLatestValue", mock.Anything, @@ -130,7 +131,7 @@ func Test_PollingWorking(t *testing.T) { }, } - homeChainReader := mocks.NewContractReaderMock() + homeChainReader := chainreadermocks.NewMockChainReader(t) homeChainReader.On( "GetLatestValue", mock.Anything, @@ -149,7 +150,7 @@ func Test_PollingWorking(t *testing.T) { var ( tickTime = 2 * time.Millisecond - totalSleepTime = tickTime * 4 + totalSleepTime = tickTime * 20 ) configPoller := NewHomeChainConfigPoller( @@ -185,7 +186,7 @@ func Test_PollingWorking(t *testing.T) { func Test_HomeChainPoller_GetOCRConfig(t *testing.T) { donID := uint32(1) pluginType := uint8(1) // execution - homeChainReader := mocks.NewContractReaderMock() + homeChainReader := chainreadermocks.NewMockChainReader(t) homeChainReader.On( "GetLatestValue", mock.Anything, diff --git a/internal/reader/onchain_prices_reader.go b/internal/reader/onchain_prices_reader.go index 6545d0c91..42accbd69 100644 --- a/internal/reader/onchain_prices_reader.go +++ b/internal/reader/onchain_prices_reader.go @@ -8,10 +8,9 @@ import ( "github.com/smartcontractkit/chainlink-ccip/pkg/consts" "github.com/smartcontractkit/chainlink-ccip/pluginconfig" - "github.com/smartcontractkit/libocr/offchainreporting2plus/types" ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" - commontyps "github.com/smartcontractkit/chainlink-common/pkg/types" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" "golang.org/x/sync/errgroup" @@ -25,15 +24,15 @@ type TokenPrices interface { type OnchainTokenPricesReader struct { // Reader for the chain that will have the token prices on-chain - ContractReader commontyps.ContractReader - PriceSources map[types.Account]pluginconfig.ArbitrumPriceSource - TokenDecimals map[types.Account]uint8 + ContractReader commontypes.ContractReader + PriceSources map[ocr2types.Account]pluginconfig.ArbitrumPriceSource + TokenDecimals map[ocr2types.Account]uint8 } func NewOnchainTokenPricesReader( - contractReader commontyps.ContractReader, - priceSources map[types.Account]pluginconfig.ArbitrumPriceSource, - tokenDecimals map[types.Account]uint8, + contractReader commontypes.ContractReader, + priceSources map[ocr2types.Account]pluginconfig.ArbitrumPriceSource, + tokenDecimals map[ocr2types.Account]uint8, ) *OnchainTokenPricesReader { return &OnchainTokenPricesReader{ ContractReader: contractReader, @@ -96,7 +95,7 @@ func (pr *OnchainTokenPricesReader) GetTokenPricesUSD( return prices, nil } -func (pr *OnchainTokenPricesReader) getFeedDecimals(ctx context.Context, token types.Account) (uint8, error) { +func (pr *OnchainTokenPricesReader) getFeedDecimals(ctx context.Context, token ocr2types.Account) (uint8, error) { var decimals uint8 if err := pr.ContractReader.GetLatestValue( @@ -104,7 +103,7 @@ func (pr *OnchainTokenPricesReader) getFeedDecimals(ctx context.Context, token t consts.ContractNamePriceAggregator, consts.MethodNameGetDecimals, primitives.Unconfirmed, - nil, + map[string]any{}, &decimals, //boundContract, ); err != nil { @@ -116,7 +115,7 @@ func (pr *OnchainTokenPricesReader) getFeedDecimals(ctx context.Context, token t func (pr *OnchainTokenPricesReader) getRawTokenPriceE18Normalized( ctx context.Context, - token types.Account, + token ocr2types.Account, ) (*big.Int, error) { var latestRoundData LatestRoundData if err := @@ -125,7 +124,7 @@ func (pr *OnchainTokenPricesReader) getRawTokenPriceE18Normalized( consts.ContractNamePriceAggregator, consts.MethodNameGetLatestRoundData, primitives.Unconfirmed, - nil, + map[string]any{}, &latestRoundData, //boundContract, ); err != nil { diff --git a/internal/reader/onchain_prices_reader_test.go b/internal/reader/onchain_prices_reader_test.go index ab0d799a4..7ca9402d1 100644 --- a/internal/reader/onchain_prices_reader_test.go +++ b/internal/reader/onchain_prices_reader_test.go @@ -6,7 +6,10 @@ import ( "math/big" "testing" - "github.com/smartcontractkit/chainlink-ccip/internal/mocks" + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + chainreadermocks "github.com/smartcontractkit/chainlink-ccip/mocks/cl-common/chainreader" "github.com/smartcontractkit/chainlink-ccip/pkg/consts" "github.com/smartcontractkit/chainlink-ccip/pluginconfig" @@ -33,14 +36,13 @@ var ( func TestOnchainTokenPricesReader_GetTokenPricesUSD(t *testing.T) { testCases := []struct { - name string - inputTokens []ocr2types.Account - priceSources map[ocr2types.Account]pluginconfig.ArbitrumPriceSource - tokenDecimals map[ocr2types.Account]uint8 - mockPrices []*big.Int - want []*big.Int - errorAccounts []ocr2types.Account - wantErr bool + name string + inputTokens []ocr2types.Account + priceSources map[ocr2types.Account]pluginconfig.ArbitrumPriceSource + tokenDecimals map[ocr2types.Account]uint8 + want []*big.Int + getChainReader func(t *testing.T) *chainreadermocks.MockChainReader + wantErr bool }{ { name: "On-chain one price", @@ -52,30 +54,93 @@ func TestOnchainTokenPricesReader_GetTokenPricesUSD(t *testing.T) { EthAddr: Decimals18, }, inputTokens: []ocr2types.Account{ArbAddr}, - //TODO: change once we have control to return different prices in mock depending on the token - mockPrices: []*big.Int{ArbPrice}, - want: []*big.Int{ArbPrice}, + want: []*big.Int{ArbPrice}, + getChainReader: func(t *testing.T) *chainreadermocks.MockChainReader { + chainReader := chainreadermocks.NewMockChainReader(t) + // expect a single decimals() call. + chainReader. + EXPECT(). + GetLatestValue( + mock.Anything, + consts.ContractNamePriceAggregator, + consts.MethodNameGetDecimals, + primitives.Unconfirmed, + mock.Anything, + mock.Anything). + Run(func( + ctx context.Context, + contractName, + method string, + confidenceLevel primitives.ConfidenceLevel, + params, + returnVal interface{}) { + returnValUint8, ok := returnVal.(*uint8) + if !ok { + panic("returnVal is not a *uint8") + } + *returnValUint8 = Decimals18 + }). + Return(nil) + // expect a single getLatestRoundData() call. + chainReader. + EXPECT(). + GetLatestValue( + mock.Anything, + consts.ContractNamePriceAggregator, + consts.MethodNameGetLatestRoundData, + primitives.Unconfirmed, + mock.Anything, + mock.Anything). + Run(func( + ctx context.Context, + contractName, + method string, + confidenceLevel primitives.ConfidenceLevel, + params, + returnVal interface{}) { + returnValLatestRoundData := returnVal.(*LatestRoundData) + if returnValLatestRoundData == nil { + panic("returnVal is nil") + } + returnValLatestRoundData.Answer = big.NewInt(ArbPrice.Int64()) + }).Return(nil).Once() + return chainReader + }, }, { - name: "Missing price should error", - priceSources: map[ocr2types.Account]pluginconfig.ArbitrumPriceSource{}, - inputTokens: []ocr2types.Account{ArbAddr}, - mockPrices: []*big.Int{}, - errorAccounts: []ocr2types.Account{EthAddr}, - want: nil, - wantErr: true, + name: "Missing price should error", + priceSources: map[ocr2types.Account]pluginconfig.ArbitrumPriceSource{}, + inputTokens: []ocr2types.Account{ArbAddr}, + getChainReader: func(t *testing.T) *chainreadermocks.MockChainReader { + chainReader := chainreadermocks.NewMockChainReader(t) + // expect a single getLatestRoundData() call that will error + chainReader. + EXPECT(). + GetLatestValue( + mock.Anything, + consts.ContractNamePriceAggregator, + consts.MethodNameGetLatestRoundData, + primitives.Unconfirmed, + mock.Anything, + mock.Anything). + Return(fmt.Errorf("some error")).Once() + // no decimals() call since the above call errors. + return chainReader + }, + wantErr: true, }, } for _, tc := range testCases { - contractReader := createMockReader(tc.mockPrices, tc.errorAccounts) - tokenPricesReader := OnchainTokenPricesReader{ - ContractReader: contractReader, - PriceSources: tc.priceSources, - TokenDecimals: tc.tokenDecimals, - } t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() + contractReader := tc.getChainReader(t) + + tokenPricesReader := OnchainTokenPricesReader{ + ContractReader: contractReader, + PriceSources: tc.priceSources, + TokenDecimals: tc.tokenDecimals, + } + ctx := tests.Context(t) result, err := tokenPricesReader.GetTokenPricesUSD(ctx, tc.inputTokens) if tc.wantErr { @@ -128,49 +193,3 @@ func TestPriceService_calculateUsdPer1e18TokenAmount(t *testing.T) { }) } } - -// nolint unparam -func createMockReader( - mockPrices []*big.Int, - errorAccounts []ocr2types.Account, -) *mocks.ContractReaderMock { - reader := mocks.NewContractReaderMock() - // TODO: Create a list of bound contracts from priceSources and return the price given in mockPrices - reader.On("GetLatestValue", - mock.Anything, - consts.ContractNamePriceAggregator, - consts.MethodNameGetDecimals, - mock.Anything, - nil, - mock.Anything).Run( - func(args mock.Arguments) { - arg := args.Get(5).(*uint8) - *arg = Decimals18 - }).Return(nil) - - for _, price := range mockPrices { - price := price - reader.On("GetLatestValue", - mock.Anything, - consts.ContractNamePriceAggregator, - consts.MethodNameGetLatestRoundData, - mock.Anything, - nil, - mock.Anything).Run( - func(args mock.Arguments) { - arg := args.Get(5).(*LatestRoundData) - arg.Answer = big.NewInt(price.Int64()) - }).Return(nil).Once() - } - - for i := 0; i < len(errorAccounts); i++ { - reader.On("GetLatestValue", - mock.Anything, - consts.ContractNamePriceAggregator, - consts.MethodNameGetLatestRoundData, - mock.Anything, - nil, - mock.Anything).Return(fmt.Errorf("error")).Once() - } - return reader -} diff --git a/mocks/cl-common/chainreader/chain_reader.go b/mocks/cl-common/chainreader/chain_reader.go new file mode 100644 index 000000000..30b1d1143 --- /dev/null +++ b/mocks/cl-common/chainreader/chain_reader.go @@ -0,0 +1,487 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package types + +import ( + context "context" + + query "github.com/smartcontractkit/chainlink-common/pkg/types/query" + primitives "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" + mock "github.com/stretchr/testify/mock" + + types "github.com/smartcontractkit/chainlink-common/pkg/types" +) + +// MockChainReader is an autogenerated mock type for the ChainReader type +type MockChainReader struct { + mock.Mock +} + +type MockChainReader_Expecter struct { + mock *mock.Mock +} + +func (_m *MockChainReader) EXPECT() *MockChainReader_Expecter { + return &MockChainReader_Expecter{mock: &_m.Mock} +} + +// BatchGetLatestValues provides a mock function with given fields: ctx, request +func (_m *MockChainReader) BatchGetLatestValues(ctx context.Context, request types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for BatchGetLatestValues") + } + + var r0 types.BatchGetLatestValuesResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, types.BatchGetLatestValuesRequest) types.BatchGetLatestValuesResult); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.BatchGetLatestValuesResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.BatchGetLatestValuesRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockChainReader_BatchGetLatestValues_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BatchGetLatestValues' +type MockChainReader_BatchGetLatestValues_Call struct { + *mock.Call +} + +// BatchGetLatestValues is a helper method to define mock.On call +// - ctx context.Context +// - request types.BatchGetLatestValuesRequest +func (_e *MockChainReader_Expecter) BatchGetLatestValues(ctx interface{}, request interface{}) *MockChainReader_BatchGetLatestValues_Call { + return &MockChainReader_BatchGetLatestValues_Call{Call: _e.mock.On("BatchGetLatestValues", ctx, request)} +} + +func (_c *MockChainReader_BatchGetLatestValues_Call) Run(run func(ctx context.Context, request types.BatchGetLatestValuesRequest)) *MockChainReader_BatchGetLatestValues_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.BatchGetLatestValuesRequest)) + }) + return _c +} + +func (_c *MockChainReader_BatchGetLatestValues_Call) Return(_a0 types.BatchGetLatestValuesResult, _a1 error) *MockChainReader_BatchGetLatestValues_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockChainReader_BatchGetLatestValues_Call) RunAndReturn(run func(context.Context, types.BatchGetLatestValuesRequest) (types.BatchGetLatestValuesResult, error)) *MockChainReader_BatchGetLatestValues_Call { + _c.Call.Return(run) + return _c +} + +// Bind provides a mock function with given fields: ctx, bindings +func (_m *MockChainReader) Bind(ctx context.Context, bindings []types.BoundContract) error { + ret := _m.Called(ctx, bindings) + + if len(ret) == 0 { + panic("no return value specified for Bind") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []types.BoundContract) error); ok { + r0 = rf(ctx, bindings) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockChainReader_Bind_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Bind' +type MockChainReader_Bind_Call struct { + *mock.Call +} + +// Bind is a helper method to define mock.On call +// - ctx context.Context +// - bindings []types.BoundContract +func (_e *MockChainReader_Expecter) Bind(ctx interface{}, bindings interface{}) *MockChainReader_Bind_Call { + return &MockChainReader_Bind_Call{Call: _e.mock.On("Bind", ctx, bindings)} +} + +func (_c *MockChainReader_Bind_Call) Run(run func(ctx context.Context, bindings []types.BoundContract)) *MockChainReader_Bind_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]types.BoundContract)) + }) + return _c +} + +func (_c *MockChainReader_Bind_Call) Return(_a0 error) *MockChainReader_Bind_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockChainReader_Bind_Call) RunAndReturn(run func(context.Context, []types.BoundContract) error) *MockChainReader_Bind_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockChainReader) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockChainReader_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockChainReader_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockChainReader_Expecter) Close() *MockChainReader_Close_Call { + return &MockChainReader_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockChainReader_Close_Call) Run(run func()) *MockChainReader_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockChainReader_Close_Call) Return(_a0 error) *MockChainReader_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockChainReader_Close_Call) RunAndReturn(run func() error) *MockChainReader_Close_Call { + _c.Call.Return(run) + return _c +} + +// GetLatestValue provides a mock function with given fields: ctx, contractName, method, confidenceLevel, params, returnVal +func (_m *MockChainReader) GetLatestValue(ctx context.Context, contractName string, method string, confidenceLevel primitives.ConfidenceLevel, params interface{}, returnVal interface{}) error { + ret := _m.Called(ctx, contractName, method, confidenceLevel, params, returnVal) + + if len(ret) == 0 { + panic("no return value specified for GetLatestValue") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, primitives.ConfidenceLevel, interface{}, interface{}) error); ok { + r0 = rf(ctx, contractName, method, confidenceLevel, params, returnVal) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockChainReader_GetLatestValue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestValue' +type MockChainReader_GetLatestValue_Call struct { + *mock.Call +} + +// GetLatestValue is a helper method to define mock.On call +// - ctx context.Context +// - contractName string +// - method string +// - confidenceLevel primitives.ConfidenceLevel +// - params interface{} +// - returnVal interface{} +func (_e *MockChainReader_Expecter) GetLatestValue(ctx interface{}, contractName interface{}, method interface{}, confidenceLevel interface{}, params interface{}, returnVal interface{}) *MockChainReader_GetLatestValue_Call { + return &MockChainReader_GetLatestValue_Call{Call: _e.mock.On("GetLatestValue", ctx, contractName, method, confidenceLevel, params, returnVal)} +} + +func (_c *MockChainReader_GetLatestValue_Call) Run(run func(ctx context.Context, contractName string, method string, confidenceLevel primitives.ConfidenceLevel, params interface{}, returnVal interface{})) *MockChainReader_GetLatestValue_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(primitives.ConfidenceLevel), args[4].(interface{}), args[5].(interface{})) + }) + return _c +} + +func (_c *MockChainReader_GetLatestValue_Call) Return(_a0 error) *MockChainReader_GetLatestValue_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockChainReader_GetLatestValue_Call) RunAndReturn(run func(context.Context, string, string, primitives.ConfidenceLevel, interface{}, interface{}) error) *MockChainReader_GetLatestValue_Call { + _c.Call.Return(run) + return _c +} + +// HealthReport provides a mock function with given fields: +func (_m *MockChainReader) HealthReport() map[string]error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for HealthReport") + } + + var r0 map[string]error + if rf, ok := ret.Get(0).(func() map[string]error); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]error) + } + } + + return r0 +} + +// MockChainReader_HealthReport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HealthReport' +type MockChainReader_HealthReport_Call struct { + *mock.Call +} + +// HealthReport is a helper method to define mock.On call +func (_e *MockChainReader_Expecter) HealthReport() *MockChainReader_HealthReport_Call { + return &MockChainReader_HealthReport_Call{Call: _e.mock.On("HealthReport")} +} + +func (_c *MockChainReader_HealthReport_Call) Run(run func()) *MockChainReader_HealthReport_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockChainReader_HealthReport_Call) Return(_a0 map[string]error) *MockChainReader_HealthReport_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockChainReader_HealthReport_Call) RunAndReturn(run func() map[string]error) *MockChainReader_HealthReport_Call { + _c.Call.Return(run) + return _c +} + +// Name provides a mock function with given fields: +func (_m *MockChainReader) Name() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Name") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockChainReader_Name_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Name' +type MockChainReader_Name_Call struct { + *mock.Call +} + +// Name is a helper method to define mock.On call +func (_e *MockChainReader_Expecter) Name() *MockChainReader_Name_Call { + return &MockChainReader_Name_Call{Call: _e.mock.On("Name")} +} + +func (_c *MockChainReader_Name_Call) Run(run func()) *MockChainReader_Name_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockChainReader_Name_Call) Return(_a0 string) *MockChainReader_Name_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockChainReader_Name_Call) RunAndReturn(run func() string) *MockChainReader_Name_Call { + _c.Call.Return(run) + return _c +} + +// QueryKey provides a mock function with given fields: ctx, contractName, filter, limitAndSort, sequenceDataType +func (_m *MockChainReader) QueryKey(ctx context.Context, contractName string, filter query.KeyFilter, limitAndSort query.LimitAndSort, sequenceDataType interface{}) ([]types.Sequence, error) { + ret := _m.Called(ctx, contractName, filter, limitAndSort, sequenceDataType) + + if len(ret) == 0 { + panic("no return value specified for QueryKey") + } + + var r0 []types.Sequence + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, query.KeyFilter, query.LimitAndSort, interface{}) ([]types.Sequence, error)); ok { + return rf(ctx, contractName, filter, limitAndSort, sequenceDataType) + } + if rf, ok := ret.Get(0).(func(context.Context, string, query.KeyFilter, query.LimitAndSort, interface{}) []types.Sequence); ok { + r0 = rf(ctx, contractName, filter, limitAndSort, sequenceDataType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Sequence) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, query.KeyFilter, query.LimitAndSort, interface{}) error); ok { + r1 = rf(ctx, contractName, filter, limitAndSort, sequenceDataType) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockChainReader_QueryKey_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryKey' +type MockChainReader_QueryKey_Call struct { + *mock.Call +} + +// QueryKey is a helper method to define mock.On call +// - ctx context.Context +// - contractName string +// - filter query.KeyFilter +// - limitAndSort query.LimitAndSort +// - sequenceDataType interface{} +func (_e *MockChainReader_Expecter) QueryKey(ctx interface{}, contractName interface{}, filter interface{}, limitAndSort interface{}, sequenceDataType interface{}) *MockChainReader_QueryKey_Call { + return &MockChainReader_QueryKey_Call{Call: _e.mock.On("QueryKey", ctx, contractName, filter, limitAndSort, sequenceDataType)} +} + +func (_c *MockChainReader_QueryKey_Call) Run(run func(ctx context.Context, contractName string, filter query.KeyFilter, limitAndSort query.LimitAndSort, sequenceDataType interface{})) *MockChainReader_QueryKey_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(query.KeyFilter), args[3].(query.LimitAndSort), args[4].(interface{})) + }) + return _c +} + +func (_c *MockChainReader_QueryKey_Call) Return(_a0 []types.Sequence, _a1 error) *MockChainReader_QueryKey_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockChainReader_QueryKey_Call) RunAndReturn(run func(context.Context, string, query.KeyFilter, query.LimitAndSort, interface{}) ([]types.Sequence, error)) *MockChainReader_QueryKey_Call { + _c.Call.Return(run) + return _c +} + +// Ready provides a mock function with given fields: +func (_m *MockChainReader) Ready() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Ready") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockChainReader_Ready_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ready' +type MockChainReader_Ready_Call struct { + *mock.Call +} + +// Ready is a helper method to define mock.On call +func (_e *MockChainReader_Expecter) Ready() *MockChainReader_Ready_Call { + return &MockChainReader_Ready_Call{Call: _e.mock.On("Ready")} +} + +func (_c *MockChainReader_Ready_Call) Run(run func()) *MockChainReader_Ready_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockChainReader_Ready_Call) Return(_a0 error) *MockChainReader_Ready_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockChainReader_Ready_Call) RunAndReturn(run func() error) *MockChainReader_Ready_Call { + _c.Call.Return(run) + return _c +} + +// Start provides a mock function with given fields: _a0 +func (_m *MockChainReader) Start(_a0 context.Context) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Start") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockChainReader_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type MockChainReader_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +// - _a0 context.Context +func (_e *MockChainReader_Expecter) Start(_a0 interface{}) *MockChainReader_Start_Call { + return &MockChainReader_Start_Call{Call: _e.mock.On("Start", _a0)} +} + +func (_c *MockChainReader_Start_Call) Run(run func(_a0 context.Context)) *MockChainReader_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockChainReader_Start_Call) Return(_a0 error) *MockChainReader_Start_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockChainReader_Start_Call) RunAndReturn(run func(context.Context) error) *MockChainReader_Start_Call { + _c.Call.Return(run) + return _c +} + +// NewMockChainReader creates a new instance of MockChainReader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockChainReader(t interface { + mock.TestingT + Cleanup(func()) +}) *MockChainReader { + mock := &MockChainReader{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/mocks/internal_/reader/ccip.go b/mocks/internal_/reader/ccip.go index 8500e1dff..4a0644f71 100644 --- a/mocks/internal_/reader/ccip.go +++ b/mocks/internal_/reader/ccip.go @@ -254,6 +254,64 @@ func (_c *MockCCIP_GasPrices_Call) RunAndReturn(run func(context.Context, []ccip return _c } +// GetExpectedNextSequenceNumber provides a mock function with given fields: ctx, sourceChainSelector, destChainSelector +func (_m *MockCCIP) GetExpectedNextSequenceNumber(ctx context.Context, sourceChainSelector ccipocr3.ChainSelector, destChainSelector ccipocr3.ChainSelector) (ccipocr3.SeqNum, error) { + ret := _m.Called(ctx, sourceChainSelector, destChainSelector) + + if len(ret) == 0 { + panic("no return value specified for GetExpectedNextSequenceNumber") + } + + var r0 ccipocr3.SeqNum + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ccipocr3.ChainSelector, ccipocr3.ChainSelector) (ccipocr3.SeqNum, error)); ok { + return rf(ctx, sourceChainSelector, destChainSelector) + } + if rf, ok := ret.Get(0).(func(context.Context, ccipocr3.ChainSelector, ccipocr3.ChainSelector) ccipocr3.SeqNum); ok { + r0 = rf(ctx, sourceChainSelector, destChainSelector) + } else { + r0 = ret.Get(0).(ccipocr3.SeqNum) + } + + if rf, ok := ret.Get(1).(func(context.Context, ccipocr3.ChainSelector, ccipocr3.ChainSelector) error); ok { + r1 = rf(ctx, sourceChainSelector, destChainSelector) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCCIP_GetExpectedNextSequenceNumber_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetExpectedNextSequenceNumber' +type MockCCIP_GetExpectedNextSequenceNumber_Call struct { + *mock.Call +} + +// GetExpectedNextSequenceNumber is a helper method to define mock.On call +// - ctx context.Context +// - sourceChainSelector ccipocr3.ChainSelector +// - destChainSelector ccipocr3.ChainSelector +func (_e *MockCCIP_Expecter) GetExpectedNextSequenceNumber(ctx interface{}, sourceChainSelector interface{}, destChainSelector interface{}) *MockCCIP_GetExpectedNextSequenceNumber_Call { + return &MockCCIP_GetExpectedNextSequenceNumber_Call{Call: _e.mock.On("GetExpectedNextSequenceNumber", ctx, sourceChainSelector, destChainSelector)} +} + +func (_c *MockCCIP_GetExpectedNextSequenceNumber_Call) Run(run func(ctx context.Context, sourceChainSelector ccipocr3.ChainSelector, destChainSelector ccipocr3.ChainSelector)) *MockCCIP_GetExpectedNextSequenceNumber_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(ccipocr3.ChainSelector), args[2].(ccipocr3.ChainSelector)) + }) + return _c +} + +func (_c *MockCCIP_GetExpectedNextSequenceNumber_Call) Return(_a0 ccipocr3.SeqNum, _a1 error) *MockCCIP_GetExpectedNextSequenceNumber_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCCIP_GetExpectedNextSequenceNumber_Call) RunAndReturn(run func(context.Context, ccipocr3.ChainSelector, ccipocr3.ChainSelector) (ccipocr3.SeqNum, error)) *MockCCIP_GetExpectedNextSequenceNumber_Call { + _c.Call.Return(run) + return _c +} + // MsgsBetweenSeqNums provides a mock function with given fields: ctx, chain, seqNumRange func (_m *MockCCIP) MsgsBetweenSeqNums(ctx context.Context, chain ccipocr3.ChainSelector, seqNumRange ccipocr3.SeqNumRange) ([]ccipocr3.Message, error) { ret := _m.Called(ctx, chain, seqNumRange) diff --git a/pkg/contractreader/extended_test.go b/pkg/contractreader/extended_test.go index 8fdba6f93..8ca6d5a2e 100644 --- a/pkg/contractreader/extended_test.go +++ b/pkg/contractreader/extended_test.go @@ -7,14 +7,14 @@ import ( "github.com/stretchr/testify/assert" - "github.com/smartcontractkit/chainlink-ccip/internal/mocks" - "github.com/smartcontractkit/chainlink-common/pkg/types" + + chainreadermocks "github.com/smartcontractkit/chainlink-ccip/mocks/cl-common/chainreader" ) func TestExtendedContractReader(t *testing.T) { const contractName = "testContract" - cr := mocks.NewContractReaderMock() + cr := chainreadermocks.NewMockChainReader(t) extCr := NewExtendedContractReader(cr) bindings := extCr.GetBindings(contractName)