From 119df08eec3609a41880a5b471c466e90fff36f8 Mon Sep 17 00:00:00 2001 From: ilija42 <57732589+ilija42@users.noreply.github.com> Date: Fri, 17 May 2024 14:42:05 +0200 Subject: [PATCH] BCF-3225 - Implement forwarder fallback if forwarder not present as a transmitter on OCR2 aggregator (#13221) * Implement forwarder OCR2 fallback if fwd not present as a transmitter * Add changeset --- .changeset/hungry-carpets-flow.md | 5 + common/txmgr/mocks/tx_manager.go | 28 +++++ common/txmgr/txmgr.go | 14 +++ common/txmgr/types/forwarder_manager.go | 1 + common/txmgr/types/mocks/forwarder_manager.go | 28 +++++ .../evm/forwarders/forwarder_manager.go | 38 ++++++ .../evm/forwarders/forwarder_manager_test.go | 110 +++++++++++++++++- core/services/ocr2/delegate.go | 12 +- core/services/ocr2/delegate_test.go | 11 +- .../smoke/forwarders_ocr2_test.go | 6 +- 10 files changed, 244 insertions(+), 9 deletions(-) create mode 100644 .changeset/hungry-carpets-flow.md diff --git a/.changeset/hungry-carpets-flow.md b/.changeset/hungry-carpets-flow.md new file mode 100644 index 00000000000..19835b99c17 --- /dev/null +++ b/.changeset/hungry-carpets-flow.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +Added a mechanism to validate forwarders for OCR2 and fallback to EOA if necessary #added diff --git a/common/txmgr/mocks/tx_manager.go b/common/txmgr/mocks/tx_manager.go index 935e7313817..a3e8c489314 100644 --- a/common/txmgr/mocks/tx_manager.go +++ b/common/txmgr/mocks/tx_manager.go @@ -301,6 +301,34 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor return r0, r1 } +// GetForwarderForEOAOCR2Feeds provides a mock function with given fields: eoa, ocr2AggregatorID +func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(eoa ADDR, ocr2AggregatorID ADDR) (ADDR, error) { + ret := _m.Called(eoa, ocr2AggregatorID) + + if len(ret) == 0 { + panic("no return value specified for GetForwarderForEOAOCR2Feeds") + } + + var r0 ADDR + var r1 error + if rf, ok := ret.Get(0).(func(ADDR, ADDR) (ADDR, error)); ok { + return rf(eoa, ocr2AggregatorID) + } + if rf, ok := ret.Get(0).(func(ADDR, ADDR) ADDR); ok { + r0 = rf(eoa, ocr2AggregatorID) + } else { + r0 = ret.Get(0).(ADDR) + } + + if rf, ok := ret.Get(1).(func(ADDR, ADDR) error); ok { + r1 = rf(eoa, ocr2AggregatorID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // HealthReport provides a mock function with given fields: func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) HealthReport() map[string]error { ret := _m.Called() diff --git a/common/txmgr/txmgr.go b/common/txmgr/txmgr.go index 4d4eabe5c40..1c8b59a55cc 100644 --- a/common/txmgr/txmgr.go +++ b/common/txmgr/txmgr.go @@ -47,6 +47,7 @@ type TxManager[ Trigger(addr ADDR) CreateTransaction(ctx context.Context, txRequest txmgrtypes.TxRequest[ADDR, TX_HASH]) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) GetForwarderForEOA(eoa ADDR) (forwarder ADDR, err error) + GetForwarderForEOAOCR2Feeds(eoa, ocr2AggregatorID ADDR) (forwarder ADDR, err error) RegisterResumeCallback(fn ResumeCallback) SendNativeToken(ctx context.Context, chainID CHAIN_ID, from, to ADDR, value big.Int, gasLimit uint64) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) Reset(addr ADDR, abandon bool) error @@ -553,6 +554,15 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForward return } +// GetForwarderForEOAOCR2Feeds calls forwarderMgr to get a proper forwarder for a given EOA and checks if its set as a transmitter on the OCR2Aggregator contract. +func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) { + if !b.txConfig.ForwardersEnabled() { + return forwarder, fmt.Errorf("forwarding is not enabled, to enable set Transactions.ForwardersEnabled =true") + } + forwarder, err = b.fwdMgr.ForwarderForOCR2Feeds(eoa, ocr2Aggregator) + return +} + func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) checkEnabled(ctx context.Context, addr ADDR) error { if err := b.keyStore.CheckEnabled(ctx, addr, b.chainID); err != nil { return fmt.Errorf("cannot send transaction from %s on chain ID %s: %w", addr, b.chainID.String(), err) @@ -649,6 +659,10 @@ func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) Cre func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(addr ADDR) (fwdr ADDR, err error) { return fwdr, err } +func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(_, _ ADDR) (fwdr ADDR, err error) { + return fwdr, err +} + func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) Reset(addr ADDR, abandon bool) error { return nil } diff --git a/common/txmgr/types/forwarder_manager.go b/common/txmgr/types/forwarder_manager.go index 4d70b730004..3e51ffb1524 100644 --- a/common/txmgr/types/forwarder_manager.go +++ b/common/txmgr/types/forwarder_manager.go @@ -9,6 +9,7 @@ import ( type ForwarderManager[ADDR types.Hashable] interface { services.Service ForwarderFor(addr ADDR) (forwarder ADDR, err error) + ForwarderForOCR2Feeds(eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) // Converts payload to be forwarder-friendly ConvertPayload(dest ADDR, origPayload []byte) ([]byte, error) } diff --git a/common/txmgr/types/mocks/forwarder_manager.go b/common/txmgr/types/mocks/forwarder_manager.go index fe40e7bb5e2..1021e776e9d 100644 --- a/common/txmgr/types/mocks/forwarder_manager.go +++ b/common/txmgr/types/mocks/forwarder_manager.go @@ -91,6 +91,34 @@ func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { return r0, r1 } +// ForwarderForOCR2Feeds provides a mock function with given fields: eoa, ocr2Aggregator +func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(eoa ADDR, ocr2Aggregator ADDR) (ADDR, error) { + ret := _m.Called(eoa, ocr2Aggregator) + + if len(ret) == 0 { + panic("no return value specified for ForwarderForOCR2Feeds") + } + + var r0 ADDR + var r1 error + if rf, ok := ret.Get(0).(func(ADDR, ADDR) (ADDR, error)); ok { + return rf(eoa, ocr2Aggregator) + } + if rf, ok := ret.Get(0).(func(ADDR, ADDR) ADDR); ok { + r0 = rf(eoa, ocr2Aggregator) + } else { + r0 = ret.Get(0).(ADDR) + } + + if rf, ok := ret.Get(1).(func(ADDR, ADDR) error); ok { + r1 = rf(eoa, ocr2Aggregator) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // HealthReport provides a mock function with given fields: func (_m *ForwarderManager[ADDR]) HealthReport() map[string]error { ret := _m.Called() diff --git a/core/chains/evm/forwarders/forwarder_manager.go b/core/chains/evm/forwarders/forwarder_manager.go index 7a7a274127f..15e3534e8cb 100644 --- a/core/chains/evm/forwarders/forwarder_manager.go +++ b/core/chains/evm/forwarders/forwarder_manager.go @@ -2,6 +2,7 @@ package forwarders import ( "context" + "slices" "sync" "time" @@ -9,6 +10,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" pkgerrors "github.com/pkg/errors" + "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -131,6 +133,42 @@ func (f *FwdMgr) ForwarderFor(addr common.Address) (forwarder common.Address, er return common.Address{}, pkgerrors.Errorf("Cannot find forwarder for given EOA") } +func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forwarder common.Address, err error) { + fwdrs, err := f.ORM.FindForwardersByChain(f.ctx, big.Big(*f.evmClient.ConfiguredChainID())) + if err != nil { + return common.Address{}, err + } + + offchainAggregator, err := ocr2aggregator.NewOCR2Aggregator(ocr2Aggregator, f.evmClient) + if err != nil { + return common.Address{}, err + } + + transmitters, err := offchainAggregator.GetTransmitters(&bind.CallOpts{Context: f.ctx}) + if err != nil { + return common.Address{}, pkgerrors.Errorf("failed to get ocr2 aggregator transmitters: %s", err.Error()) + } + + for _, fwdr := range fwdrs { + if !slices.Contains(transmitters, fwdr.Address) { + f.logger.Criticalw("Forwarder is not set as a transmitter", "forwarder", fwdr.Address, "ocr2Aggregator", ocr2Aggregator, "err", err) + continue + } + + eoas, err := f.getContractSenders(fwdr.Address) + if err != nil { + f.logger.Errorw("Failed to get forwarder senders", "forwarder", fwdr.Address, "err", err) + continue + } + for _, addr := range eoas { + if addr == eoa { + return fwdr.Address, nil + } + } + } + return common.Address{}, pkgerrors.Errorf("Cannot find forwarder for given EOA") +} + func (f *FwdMgr) ConvertPayload(dest common.Address, origPayload []byte) ([]byte, error) { databytes, err := f.getForwardedPayload(dest, origPayload) if err != nil { diff --git a/core/chains/evm/forwarders/forwarder_manager_test.go b/core/chains/evm/forwarders/forwarder_manager_test.go index 3a515e7ab39..993efacac4a 100644 --- a/core/chains/evm/forwarders/forwarder_manager_test.go +++ b/core/chains/evm/forwarders/forwarder_manager_test.go @@ -2,19 +2,23 @@ package forwarders_test import ( "math/big" + "slices" "testing" "time" - "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" - + "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind/backends" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/libocr/gethwrappers2/testocr2aggregator" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/testhelpers" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/forwarders" @@ -150,3 +154,105 @@ func TestFwdMgr_AccountUnauthorizedToForward_SkipsForwarding(t *testing.T) { err = fwdMgr.Close() require.NoError(t, err) } + +func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { + lggr := logger.Test(t) + db := pgtest.NewSqlxDB(t) + ctx := testutils.Context(t) + cfg := configtest.NewTestGeneralConfig(t) + evmcfg := evmtest.NewChainScopedConfig(t, cfg) + owner := testutils.MustNewSimTransactor(t) + ec := backends.NewSimulatedBackend(map[common.Address]core.GenesisAccount{ + owner.From: { + Balance: big.NewInt(0).Mul(big.NewInt(10), big.NewInt(1e18)), + }, + }, 10e6) + t.Cleanup(func() { ec.Close() }) + linkAddr := common.HexToAddress("0x01BE23585060835E02B77ef475b0Cc51aA1e0709") + operatorAddr, _, _, err := operator_wrapper.DeployOperator(owner, ec, linkAddr, owner.From) + require.NoError(t, err) + + forwarderAddr, _, forwarder, err := authorized_forwarder.DeployAuthorizedForwarder(owner, ec, linkAddr, owner.From, operatorAddr, []byte{}) + require.NoError(t, err) + ec.Commit() + + accessAddress, _, _, err := testocr2aggregator.DeploySimpleWriteAccessController(owner, ec) + require.NoError(t, err, "failed to deploy test access controller contract") + ocr2Address, _, ocr2, err := testocr2aggregator.DeployOCR2Aggregator( + owner, + ec, + linkAddr, + big.NewInt(0), + big.NewInt(10), + accessAddress, + accessAddress, + 9, + "TEST", + ) + require.NoError(t, err, "failed to deploy ocr2 test aggregator") + ec.Commit() + + evmClient := client.NewSimulatedBackendClient(t, ec, testutils.FixtureChainID) + lpOpts := logpoller.Opts{ + PollPeriod: 100 * time.Millisecond, + FinalityDepth: 2, + BackfillBatchSize: 3, + RpcBatchSize: 2, + KeepFinalizedBlocksDepth: 1000, + } + lp := logpoller.NewLogPoller(logpoller.NewORM(testutils.FixtureChainID, db, lggr), evmClient, lggr, lpOpts) + fwdMgr := forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM()) + fwdMgr.ORM = forwarders.NewORM(db) + + _, err = fwdMgr.ORM.CreateForwarder(ctx, forwarderAddr, ubig.Big(*testutils.FixtureChainID)) + require.NoError(t, err) + lst, err := fwdMgr.ORM.FindForwardersByChain(ctx, ubig.Big(*testutils.FixtureChainID)) + require.NoError(t, err) + require.Equal(t, len(lst), 1) + require.Equal(t, lst[0].Address, forwarderAddr) + + fwdMgr = forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM()) + require.NoError(t, fwdMgr.Start(testutils.Context(t))) + // cannot find forwarder because it isn't authorized nor added as a transmitter + addr, err := fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + require.ErrorContains(t, err, "Cannot find forwarder for given EOA") + require.True(t, utils.IsZero(addr)) + + _, err = forwarder.SetAuthorizedSenders(owner, []common.Address{owner.From}) + require.NoError(t, err) + ec.Commit() + + authorizedSenders, err := forwarder.GetAuthorizedSenders(&bind.CallOpts{Context: ctx}) + require.NoError(t, err) + require.Equal(t, owner.From, authorizedSenders[0]) + + // cannot find forwarder because it isn't added as a transmitter + addr, err = fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + require.ErrorContains(t, err, "Cannot find forwarder for given EOA") + require.True(t, utils.IsZero(addr)) + + onchainConfig, err := testhelpers.GenerateDefaultOCR2OnchainConfig(big.NewInt(0), big.NewInt(10)) + require.NoError(t, err) + + _, err = ocr2.SetConfig(owner, + []common.Address{testutils.NewAddress(), testutils.NewAddress(), testutils.NewAddress(), testutils.NewAddress()}, + []common.Address{forwarderAddr, testutils.NewAddress(), testutils.NewAddress(), testutils.NewAddress()}, + 1, + onchainConfig, + 0, + []byte{}) + require.NoError(t, err) + ec.Commit() + + transmitters, err := ocr2.GetTransmitters(&bind.CallOpts{Context: ctx}) + require.NoError(t, err) + require.True(t, slices.Contains(transmitters, forwarderAddr)) + + // create new fwd to have an empty cache that has to fetch authorized forwarders from log poller + fwdMgr = forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM()) + require.NoError(t, fwdMgr.Start(testutils.Context(t))) + addr, err = fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + require.NoError(t, err, "forwarder should be valid and found because it is both authorized and set as a transmitter") + require.Equal(t, forwarderAddr, addr) + require.NoError(t, fwdMgr.Close()) +} diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index 3ebf9a8fefd..5d149d140f1 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -495,14 +495,22 @@ func GetEVMEffectiveTransmitterID(jb *job.Job, chain legacyevm.Chain, lggr logge if chain == nil { return "", fmt.Errorf("job forwarding requires non-nil chain") } - effectiveTransmitterID, err := chain.TxManager().GetForwarderForEOA(common.HexToAddress(spec.TransmitterID.String)) + + var err error + var effectiveTransmitterID common.Address + // Median forwarders need special handling because of OCR2Aggregator transmitters whitelist. + if spec.PluginType == types.Median { + effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOAOCR2Feeds(common.HexToAddress(spec.TransmitterID.String), common.HexToAddress(spec.ContractID)) + } else { + effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOA(common.HexToAddress(spec.TransmitterID.String)) + } + if err == nil { return effectiveTransmitterID.String(), nil } else if !spec.TransmitterID.Valid { return "", errors.New("failed to get forwarder address and transmitterID is not set") } lggr.Warnw("Skipping forwarding for job, will fallback to default behavior", "job", jb.Name, "err", err) - // this shouldn't happen unless behaviour above was changed } return spec.TransmitterID.String, nil diff --git a/core/services/ocr2/delegate_test.go b/core/services/ocr2/delegate_test.go index bae1f5f3e78..8f204f57091 100644 --- a/core/services/ocr2/delegate_test.go +++ b/core/services/ocr2/delegate_test.go @@ -67,10 +67,17 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = tc.sendingKeys jb.ForwardingAllowed = tc.forwardingEnabled + args := []interface{}{tc.getForwarderForEOAArg} + getForwarderMethodName := "GetForwarderForEOA" + if tc.pluginType == types.Median { + getForwarderMethodName = "GetForwarderForEOAOCR2Feeds" + args = append(args, common.HexToAddress(jb.OCR2OracleSpec.ContractID)) + } + if tc.forwardingEnabled && tc.getForwarderForEOAErr { - txManager.Mock.On("GetForwarderForEOA", tc.getForwarderForEOAArg).Return(common.HexToAddress("0x0"), errors.New("random error")).Once() + txManager.Mock.On(getForwarderMethodName, args...).Return(common.HexToAddress("0x0"), errors.New("random error")).Once() } else if tc.forwardingEnabled { - txManager.Mock.On("GetForwarderForEOA", tc.getForwarderForEOAArg).Return(common.HexToAddress(tc.expectedTransmitterID), nil).Once() + txManager.Mock.On(getForwarderMethodName, args...).Return(common.HexToAddress(tc.expectedTransmitterID), nil).Once() } } diff --git a/integration-tests/smoke/forwarders_ocr2_test.go b/integration-tests/smoke/forwarders_ocr2_test.go index ee86e8cc4b6..036236fdf0f 100644 --- a/integration-tests/smoke/forwarders_ocr2_test.go +++ b/integration-tests/smoke/forwarders_ocr2_test.go @@ -92,9 +92,6 @@ func TestForwarderOCR2Basic(t *testing.T) { ocrInstances, err := actions_seth.DeployOCRv2Contracts(l, sethClient, 1, common.HexToAddress(lt.Address()), transmitters, ocrOffchainOptions) require.NoError(t, err, "Error deploying OCRv2 contracts with forwarders") - err = actions.CreateOCRv2JobsLocal(ocrInstances, bootstrapNode, workerNodes, env.MockAdapter, "ocr2", 5, uint64(sethClient.ChainID), true, false) - require.NoError(t, err, "Error creating OCRv2 jobs with forwarders") - ocrv2Config, err := actions.BuildMedianOCR2ConfigLocal(workerNodes, ocrOffchainOptions) require.NoError(t, err, "Error building OCRv2 config") ocrv2Config.Transmitters = authorizedForwarders @@ -102,6 +99,9 @@ func TestForwarderOCR2Basic(t *testing.T) { err = actions_seth.ConfigureOCRv2AggregatorContracts(ocrv2Config, ocrInstances) require.NoError(t, err, "Error configuring OCRv2 aggregator contracts") + err = actions.CreateOCRv2JobsLocal(ocrInstances, bootstrapNode, workerNodes, env.MockAdapter, "ocr2", 5, uint64(sethClient.ChainID), true, false) + require.NoError(t, err, "Error creating OCRv2 jobs with forwarders") + err = actions_seth.WatchNewOCRRound(l, sethClient, 1, contracts.V2OffChainAgrregatorToOffChainAggregatorWithRounds(ocrInstances), time.Duration(10*time.Minute)) require.NoError(t, err, "error watching for new OCRv2 round")