diff --git a/core/services/ocr2/plugins/ccip/commit_reporting_plugin.go b/core/services/ocr2/plugins/ccip/commit_reporting_plugin.go index a0345ec284..9de8a133a4 100644 --- a/core/services/ocr2/plugins/ccip/commit_reporting_plugin.go +++ b/core/services/ocr2/plugins/ccip/commit_reporting_plugin.go @@ -164,7 +164,6 @@ func (rf *CommitReportingPluginFactory) NewReportingPlugin(config types.Reportin rf.config.destLP, rf.config.offRamp, rf.destPriceRegReader, - rf.config.destClient, int64(rf.config.commitStore.OffchainConfig().DestFinalityDepth), ), gasPriceEstimator: rf.config.commitStore.GasPriceEstimator(), diff --git a/core/services/ocr2/plugins/ccip/commit_reporting_plugin_test.go b/core/services/ocr2/plugins/ccip/commit_reporting_plugin_test.go index 4517d39979..9653c1e237 100644 --- a/core/services/ocr2/plugins/ccip/commit_reporting_plugin_test.go +++ b/core/services/ocr2/plugins/ccip/commit_reporting_plugin_test.go @@ -302,8 +302,8 @@ func TestCommitReportingPlugin_Report(t *testing.T) { aos := make([]types.AttributedObservation, 0, len(tc.observations)) for _, o := range tc.observations { - obs, err := o.Marshal() - assert.NoError(t, err) + obs, err2 := o.Marshal() + assert.NoError(t, err2) aos = append(aos, types.AttributedObservation{Observation: obs}) } diff --git a/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go b/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go index 833c4014f1..1dc72992c0 100644 --- a/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go +++ b/core/services/ocr2/plugins/ccip/execution_reporting_plugin.go @@ -21,7 +21,6 @@ import ( evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" - "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/custom_token_pool" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_offramp" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal" @@ -66,19 +65,18 @@ type ExecutionPluginStaticConfig struct { type ExecutionReportingPlugin struct { config ExecutionPluginStaticConfig - F int - lggr logger.Logger - inflightReports *inflightExecReportsContainer - snoozedRoots cache.SnoozedRoots - destPriceRegistry ccipdata.PriceRegistryReader - destWrappedNative common.Address - onchainConfig ccipdata.ExecOnchainConfig - offchainConfig ccipdata.ExecOffchainConfig - cachedSourceFeeTokens cache.AutoSync[[]common.Address] - cachedDestTokens cache.AutoSync[cache.CachedTokens] - cachedTokenPools cache.AutoSync[map[common.Address]common.Address] - customTokenPoolFactory func(ctx context.Context, poolAddress common.Address, bind bind.ContractBackend) (custom_token_pool.CustomTokenPoolInterface, error) - gasPriceEstimator prices.GasPriceEstimatorExec + F int + lggr logger.Logger + inflightReports *inflightExecReportsContainer + snoozedRoots cache.SnoozedRoots + destPriceRegistry ccipdata.PriceRegistryReader + destWrappedNative common.Address + onchainConfig ccipdata.ExecOnchainConfig + offchainConfig ccipdata.ExecOffchainConfig + cachedSourceFeeTokens cache.AutoSync[[]common.Address] + cachedDestTokens cache.AutoSync[cache.CachedTokens] + cachedTokenPools cache.AutoSync[map[common.Address]common.Address] + gasPriceEstimator prices.GasPriceEstimatorExec } type ExecutionReportingPluginFactory struct { @@ -157,10 +155,7 @@ func (rf *ExecutionReportingPluginFactory) NewReportingPlugin(config types.Repor cachedDestTokens: cachedDestTokens, cachedSourceFeeTokens: cachedSourceFeeTokens, cachedTokenPools: cachedTokenPools, - customTokenPoolFactory: func(ctx context.Context, poolAddress common.Address, contractBackend bind.ContractBackend) (custom_token_pool.CustomTokenPoolInterface, error) { - return custom_token_pool.NewCustomTokenPool(poolAddress, contractBackend) - }, - gasPriceEstimator: rf.config.offRampReader.GasPriceEstimator(), + gasPriceEstimator: rf.config.offRampReader.GasPriceEstimator(), }, types.ReportingPluginInfo{ Name: "CCIPExecution", // Setting this to false saves on calldata since OffRamp doesn't require agreement between NOPs @@ -358,49 +353,66 @@ func (r *ExecutionReportingPlugin) getExecutableObservations(ctx context.Context return []ObservedMessage{}, nil } -// destPoolRateLimits returns a map that consists of the rate limits of each destination tokens of the provided reports. +// destPoolRateLimits returns a map that consists of the rate limits of each destination token of the provided reports. // If a token is missing from the returned map it either means that token was not found or token pool is disabled for this token. func (r *ExecutionReportingPlugin) destPoolRateLimits(ctx context.Context, commitReports []commitReportWithSendRequests, sourceToDestToken map[common.Address]common.Address) (map[common.Address]*big.Int, error) { - dstTokens := make(map[common.Address]struct{}) // todo: replace with a set or uniqueSlice data structure + tokenPools, err := r.cachedTokenPools.Get(ctx) + if err != nil { + return nil, fmt.Errorf("get cached token pools: %w", err) + } + + dstTokenToPool := make(map[common.Address]common.Address) + dstPoolToToken := make(map[common.Address]common.Address) + dstPools := make([]common.Address, 0) + for _, msg := range commitReports { for _, req := range msg.sendRequestsWithMeta { for _, tk := range req.TokenAmounts { - if dstToken, exists := sourceToDestToken[tk.Token]; exists { - dstTokens[dstToken] = struct{}{} + dstToken, exists := sourceToDestToken[tk.Token] + if !exists { + r.lggr.Warnw("token not found on destination chain", "sourceToken", tk) continue } - r.lggr.Warnw("token not found on destination chain", "sourceToken", tk) + + // another message with the same token exists in the report + // we skip it since we don't want to query for the rate limit twice + if _, seen := dstTokenToPool[dstToken]; seen { + continue + } + + poolAddress, exists := tokenPools[dstToken] + if !exists { + return nil, fmt.Errorf("pool for token '%s' does not exist", dstToken) + } + + if tokenAddr, seen := dstPoolToToken[poolAddress]; seen { + return nil, fmt.Errorf("pool is already seen for token %s", tokenAddr) + } + + dstTokenToPool[dstToken] = poolAddress + dstPoolToToken[poolAddress] = dstToken + dstPools = append(dstPools, poolAddress) } } } - tokenPools, err := r.cachedTokenPools.Get(ctx) + rateLimits, err := r.config.offRampReader.GetTokenPoolsRateLimits(ctx, dstPools) if err != nil { - return nil, fmt.Errorf("get cached token pools: %w", err) + return nil, fmt.Errorf("fetch pool rate limits: %w", err) } - res := make(map[common.Address]*big.Int, len(dstTokens)) - - for dstToken := range dstTokens { - poolAddress, exists := tokenPools[dstToken] - if !exists { - return nil, fmt.Errorf("pool for token '%s' does not exist", dstToken) - } - - tokenPool, err := r.customTokenPoolFactory(ctx, poolAddress, r.config.destClient) - if err != nil { - return nil, fmt.Errorf("new custom dest token pool %s: %w", poolAddress, err) - } - - offRampAddr := r.config.offRampReader.Address() - rateLimiterState, err := tokenPool.CurrentOffRampRateLimiterState(&bind.CallOpts{Context: ctx}, offRampAddr) - if err != nil { - return nil, fmt.Errorf("get rate off ramp rate limiter state: %w", err) + res := make(map[common.Address]*big.Int, len(dstTokenToPool)) + for i, rateLimit := range rateLimits { + // if the rate limit is disabled for this token pool then we omit it from the result + if !rateLimit.IsEnabled { + continue } - if rateLimiterState.IsEnabled { - res[dstToken] = rateLimiterState.Tokens + tokenAddr, exists := dstPoolToToken[dstPools[i]] + if !exists { + return nil, fmt.Errorf("pool to token mapping does not contain %s", dstPools[i]) } + res[tokenAddr] = rateLimit.Tokens } return res, nil diff --git a/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go b/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go index 8f70521c9c..de03a8f0aa 100644 --- a/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go +++ b/core/services/ocr2/plugins/ccip/execution_reporting_plugin_test.go @@ -12,7 +12,6 @@ import ( "time" "github.com/cometbft/cometbft/libs/rand" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" "github.com/smartcontractkit/libocr/commontypes" @@ -23,7 +22,6 @@ import ( "github.com/stretchr/testify/require" lpMocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" - "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/custom_token_pool" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/cache" @@ -128,6 +126,8 @@ func TestExecutionReportingPlugin_Observation(t *testing.T) { mockOffRampReader.On("CurrentRateLimiterState", mock.Anything).Return(tc.rateLimiterState, nil).Maybe() mockOffRampReader.On("Address").Return(offRamp.Address()).Maybe() mockOffRampReader.On("GetSenderNonce", mock.Anything, mock.Anything).Return(offRamp.GetSenderNonce(nil, utils.RandomAddress())).Maybe() + mockOffRampReader.On("GetTokenPoolsRateLimits", ctx, []common.Address{}). + Return([]ccipdata.TokenBucketRateLimit{}, nil).Maybe() p.config.offRampReader = mockOffRampReader mockOnRampReader := ccipdata.NewMockOnRampReader(t) @@ -765,11 +765,14 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { tk2pool := utils.RandomAddress() testCases := []struct { - name string - tokenAmounts []internal.TokenAmount - sourceToDestToken map[common.Address]common.Address - destPools map[common.Address]common.Address - poolRateLimits map[common.Address]custom_token_pool.RateLimiterTokenBucket + name string + tokenAmounts []internal.TokenAmount + // the order of the following fields: sourceTokens, destTokens and poolRateLimits + // should follow the order of the tokenAmounts + sourceTokens []common.Address + destTokens []common.Address + destPools []common.Address + poolRateLimits []ccipdata.TokenBucketRateLimit destPoolsCacheErr error expRateLimits map[common.Address]*big.Int @@ -783,17 +786,12 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { {Token: tk1}, {Token: tk1}, }, - sourceToDestToken: map[common.Address]common.Address{ - tk1: tk1dest, - tk2: tk2dest, - }, - destPools: map[common.Address]common.Address{ - tk1dest: tk1pool, - tk2dest: tk2pool, - }, - poolRateLimits: map[common.Address]custom_token_pool.RateLimiterTokenBucket{ - tk1pool: {Tokens: big.NewInt(1000), IsEnabled: true}, - tk2pool: {Tokens: big.NewInt(2000), IsEnabled: true}, + sourceTokens: []common.Address{tk1, tk2}, + destTokens: []common.Address{tk1dest, tk2dest}, + destPools: []common.Address{tk1pool, tk2pool}, + poolRateLimits: []ccipdata.TokenBucketRateLimit{ + {Tokens: big.NewInt(1000), IsEnabled: true}, + {Tokens: big.NewInt(2000), IsEnabled: true}, }, expRateLimits: map[common.Address]*big.Int{ tk1dest: big.NewInt(1000), @@ -802,19 +800,16 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { expErr: false, }, { - name: "token missing from source to dest mapping", + name: "missing from source to dest mapping should not return error", tokenAmounts: []internal.TokenAmount{ {Token: tk1}, - {Token: tk2}, // <-- missing form sourceToDestToken - }, - sourceToDestToken: map[common.Address]common.Address{ - tk1: tk1dest, - }, - destPools: map[common.Address]common.Address{ - tk1dest: tk1pool, + {Token: tk2}, // <- missing }, - poolRateLimits: map[common.Address]custom_token_pool.RateLimiterTokenBucket{ - tk1pool: {Tokens: big.NewInt(1000), IsEnabled: true}, + sourceTokens: []common.Address{tk1}, + destTokens: []common.Address{tk1dest}, + destPools: []common.Address{tk1pool}, + poolRateLimits: []ccipdata.TokenBucketRateLimit{ + {Tokens: big.NewInt(1000), IsEnabled: true}, }, expRateLimits: map[common.Address]*big.Int{ tk1dest: big.NewInt(1000), @@ -827,17 +822,12 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { {Token: tk1}, {Token: tk2}, }, - sourceToDestToken: map[common.Address]common.Address{ - tk1: tk1dest, - tk2: tk2dest, - }, - destPools: map[common.Address]common.Address{ - tk1dest: tk1pool, - tk2dest: tk2pool, - }, - poolRateLimits: map[common.Address]custom_token_pool.RateLimiterTokenBucket{ - tk1pool: {Tokens: big.NewInt(1000), IsEnabled: true}, - tk2pool: {Tokens: big.NewInt(2000), IsEnabled: false}, // <--- pool disabled + sourceTokens: []common.Address{tk1, tk2}, + destTokens: []common.Address{tk1dest, tk2dest}, + destPools: []common.Address{tk1pool, tk2pool}, + poolRateLimits: []ccipdata.TokenBucketRateLimit{ + {Tokens: big.NewInt(1000), IsEnabled: true}, + {Tokens: big.NewInt(2000), IsEnabled: false}, }, expRateLimits: map[common.Address]*big.Int{ tk1dest: big.NewInt(1000), @@ -845,18 +835,37 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { expErr: false, }, { - name: "dest pool cache error", - tokenAmounts: []internal.TokenAmount{{Token: tk1}}, - sourceToDestToken: map[common.Address]common.Address{tk1: tk1dest}, - destPoolsCacheErr: errors.New("some random error"), + name: "dest pool cache error", + tokenAmounts: []internal.TokenAmount{ + {Token: tk1}, + }, + sourceTokens: []common.Address{tk1}, + destTokens: []common.Address{tk1dest}, + destPools: []common.Address{tk1pool}, + poolRateLimits: []ccipdata.TokenBucketRateLimit{ + {Tokens: big.NewInt(1000), IsEnabled: true}, + }, + expRateLimits: map[common.Address]*big.Int{ + tk1dest: big.NewInt(1000), + }, + destPoolsCacheErr: errors.New("some err"), expErr: true, }, { - name: "pool for token not found", - tokenAmounts: []internal.TokenAmount{{Token: tk1}}, - sourceToDestToken: map[common.Address]common.Address{tk1: tk1dest}, - destPools: map[common.Address]common.Address{}, - expErr: true, + name: "pool for token not found", + tokenAmounts: []internal.TokenAmount{ + {Token: tk1}, {Token: tk2}, {Token: tk1}, {Token: tk2}, + }, + sourceTokens: []common.Address{tk1, tk2}, + destTokens: []common.Address{tk1dest, tk2dest}, + destPools: []common.Address{tk1pool}, // <-- pool2 not found + poolRateLimits: []ccipdata.TokenBucketRateLimit{ + {Tokens: big.NewInt(1000), IsEnabled: true}, + }, + expRateLimits: map[common.Address]*big.Int{ + tk1dest: big.NewInt(1000), + }, + expErr: true, }, } @@ -864,26 +873,31 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { lggr := logger.TestLogger(t) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + sourceToDestMapping := make(map[common.Address]common.Address) + for i, srcTk := range tc.sourceTokens { + sourceToDestMapping[srcTk] = tc.destTokens[i] + } + + poolsMapping := make(map[common.Address]common.Address) + for i, poolAddr := range tc.destPools { + poolsMapping[tc.destTokens[i]] = poolAddr + } + p := &ExecutionReportingPlugin{} p.lggr = lggr tokenPoolsCache := cache.NewMockAutoSync[map[common.Address]common.Address](t) - tokenPoolsCache.On("Get", ctx).Return(tc.destPools, tc.destPoolsCacheErr).Maybe() + tokenPoolsCache.On("Get", ctx).Return(poolsMapping, tc.destPoolsCacheErr).Maybe() p.cachedTokenPools = tokenPoolsCache - offRamp, offRampAddr := testhelpers.NewFakeOffRamp(t) - offRamp.SetTokenPools(tc.destPools) - + offRampAddr := utils.RandomAddress() mockOffRampReader := ccipdata.NewMockOffRampReader(t) mockOffRampReader.On("Address").Return(offRampAddr, nil).Maybe() + mockOffRampReader.On("GetTokenPoolsRateLimits", ctx, tc.destPools). + Return(tc.poolRateLimits, nil). + Maybe() p.config.offRampReader = mockOffRampReader - p.customTokenPoolFactory = func(ctx context.Context, poolAddress common.Address, _ bind.ContractBackend) (custom_token_pool.CustomTokenPoolInterface, error) { - mp := &mockPool{} - mp.On("CurrentOffRampRateLimiterState", mock.Anything, offRampAddr).Return(tc.poolRateLimits[poolAddress], nil) - return mp, nil - } - rateLimits, err := p.destPoolRateLimits(ctx, []commitReportWithSendRequests{ { sendRequestsWithMeta: []internal.EVM2EVMOnRampCCIPSendRequestedWithMeta{ @@ -894,7 +908,8 @@ func TestExecutionReportingPlugin_destPoolRateLimits(t *testing.T) { }, }, }, - }, tc.sourceToDestToken) + }, sourceToDestMapping) + if tc.expErr { assert.Error(t, err) return @@ -1711,13 +1726,3 @@ func generateExecutionReport(t *testing.T, numMsgs, tokensPerMsg, bytesPerMsg in ProofFlagBits: big.NewInt(rand.Int64()), } } - -type mockPool struct { - custom_token_pool.CustomTokenPoolInterface - mock.Mock -} - -func (mp *mockPool) CurrentOffRampRateLimiterState(opts *bind.CallOpts, offRamp common.Address) (custom_token_pool.RateLimiterTokenBucket, error) { - args := mp.Called(opts, offRamp) - return args.Get(0).(custom_token_pool.RateLimiterTokenBucket), args.Error(1) -} diff --git a/core/services/ocr2/plugins/ccip/internal/cache/tokens.go b/core/services/ocr2/plugins/ccip/internal/cache/tokens.go index 9217209c46..15025a67f0 100644 --- a/core/services/ocr2/plugins/ccip/internal/cache/tokens.go +++ b/core/services/ocr2/plugins/ccip/internal/cache/tokens.go @@ -5,13 +5,10 @@ import ( "fmt" "sync" - "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "golang.org/x/exp/slices" - evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" - "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/link_token_interface" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipdata" ) @@ -66,7 +63,6 @@ func NewTokenToDecimals( lp logpoller.LogPoller, offRamp ccipdata.OffRampReader, priceRegistryReader ccipdata.PriceRegistryReader, - client evmclient.Client, optimisticConfirmations int64, ) *CachedChain[map[common.Address]uint8] { return &CachedChain[map[common.Address]uint8]{ @@ -81,9 +77,6 @@ func NewTokenToDecimals( lggr: lggr, priceRegistryReader: priceRegistryReader, offRamp: offRamp, - tokenFactory: func(token common.Address) (link_token_interface.LinkTokenInterface, error) { - return link_token_interface.NewLinkToken(token, client) - }, }, } } @@ -173,7 +166,6 @@ type tokenToDecimals struct { lggr logger.Logger offRamp ccipdata.OffRampReader priceRegistryReader ccipdata.PriceRegistryReader - tokenFactory func(address common.Address) (link_token_interface.LinkTokenInterface, error) tokenDecimals sync.Map } @@ -190,25 +182,29 @@ func (t *tokenToDecimals) CallOrigin(ctx context.Context) (map[common.Address]ui } mapping := make(map[common.Address]uint8, len(destTokens)) + unknownDecimalsTokens := make([]common.Address, 0, len(destTokens)) + for _, token := range destTokens { if decimals, exists := t.getCachedDecimals(token); exists { mapping[token] = decimals continue } + unknownDecimalsTokens = append(unknownDecimalsTokens, token) + } - tokenContract, err := t.tokenFactory(token) - if err != nil { - return nil, err - } - - decimals, err := tokenContract.Decimals(&bind.CallOpts{Context: ctx}) - if err != nil { - return nil, fmt.Errorf("get token %s decimals: %w", token, err) - } + if len(unknownDecimalsTokens) == 0 { + return mapping, nil + } - t.setCachedDecimals(token, decimals) - mapping[token] = decimals + decimals, err := t.priceRegistryReader.GetTokensDecimals(ctx, unknownDecimalsTokens) + if err != nil { + return nil, fmt.Errorf("get tokens decimals: %w", err) } + for i := range decimals { + t.setCachedDecimals(unknownDecimalsTokens[i], decimals[i]) + mapping[unknownDecimalsTokens[i]] = decimals[i] + } + return mapping, nil } diff --git a/core/services/ocr2/plugins/ccip/internal/cache/tokens_test.go b/core/services/ocr2/plugins/ccip/internal/cache/tokens_test.go index b2a44926f8..2a91e68c7b 100644 --- a/core/services/ocr2/plugins/ccip/internal/cache/tokens_test.go +++ b/core/services/ocr2/plugins/ccip/internal/cache/tokens_test.go @@ -11,9 +11,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" - mock_contracts "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/mocks" - "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/link_token_interface" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipdata" @@ -21,7 +20,7 @@ import ( ) func Test_tokenToDecimals(t *testing.T) { - tokenPriceMappings := map[common.Address]uint8{ + tokenDecimalsMapping := map[common.Address]uint8{ common.HexToAddress("0xA"): 10, common.HexToAddress("0xB"): 5, common.HexToAddress("0xC"): 2, @@ -67,7 +66,7 @@ func Test_tokenToDecimals(t *testing.T) { }, }, { - name: "missing tokens are skipped", + name: "error on invalid token", destTokens: []common.Address{}, feeTokens: []common.Address{common.HexToAddress("0xD")}, want: map[common.Address]uint8{}, @@ -80,14 +79,39 @@ func Test_tokenToDecimals(t *testing.T) { offRampReader := ccipdata.NewMockOffRampReader(t) offRampReader.On("GetDestinationTokens", mock.Anything).Return(tt.destTokens, nil) + decimalsQueryTokens := make([]common.Address, 0) + tokenDecimals := make([]uint8, 0) + var queryErr error + for i := range tt.destTokens { + decimals, exists := tokenDecimalsMapping[tt.destTokens[i]] + if !exists { + queryErr = fmt.Errorf("decimals not found") + } + tokenDecimals = append(tokenDecimals, decimals) + decimalsQueryTokens = append(decimalsQueryTokens, tt.destTokens[i]) + } + for i := range tt.feeTokens { + if slices.Contains(decimalsQueryTokens, tt.feeTokens[i]) { + continue + } + decimals, exists := tokenDecimalsMapping[tt.feeTokens[i]] + if !exists { + queryErr = fmt.Errorf("decimals not found") + } + tokenDecimals = append(tokenDecimals, decimals) + decimalsQueryTokens = append(decimalsQueryTokens, tt.feeTokens[i]) + } + priceRegistryReader := ccipdata.NewMockPriceRegistryReader(t) priceRegistryReader.On("GetFeeTokens", mock.Anything).Return(tt.feeTokens, nil) + if len(decimalsQueryTokens) > 0 { + priceRegistryReader.On("GetTokensDecimals", mock.Anything, decimalsQueryTokens).Return(tokenDecimals, queryErr).Once() + } tokenToDecimal := &tokenToDecimals{ lggr: logger.TestLogger(t), offRamp: offRampReader, priceRegistryReader: priceRegistryReader, - tokenFactory: createTokenFactory(tokenPriceMappings), } got, err := tokenToDecimal.CallOrigin(testutils.Context(t)) @@ -99,11 +123,7 @@ func Test_tokenToDecimals(t *testing.T) { require.NoError(t, err) assert.Equal(t, tt.want, got) - // we set token factory to always return an error - // we don't expect it to be used again, decimals should be in cache. - tokenToDecimal.tokenFactory = func(address common.Address) (link_token_interface.LinkTokenInterface, error) { - return nil, fmt.Errorf("some error") - } + // we don't expect rpc call to be made, decimals should be in cache. got, err = tokenToDecimal.CallOrigin(testutils.Context(t)) require.NoError(t, err) assert.Equal(t, tt.want, got) @@ -217,16 +237,3 @@ func Test_cachedDecimals(t *testing.T) { assert.Equal(t, uint8(123), decimals) assert.True(t, exists) } - -func createTokenFactory(decimalMapping map[common.Address]uint8) func(address common.Address) (link_token_interface.LinkTokenInterface, error) { - return func(address common.Address) (link_token_interface.LinkTokenInterface, error) { - linkToken := &mock_contracts.LinkTokenInterface{} - if decimals, found := decimalMapping[address]; found { - // Make sure each token is fetched only once - linkToken.On("Decimals", mock.Anything).Return(decimals, nil) - } else { - linkToken.On("Decimals", mock.Anything).Return(uint8(0), errors.New("Error")) - } - return linkToken, nil - } -} diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/logpoller.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/logpoller.go index afca3f98c2..256bc0171d 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/logpoller.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/logpoller.go @@ -9,7 +9,6 @@ import ( evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" - "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/commit_store" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_offramp" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_onramp" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/price_registry" @@ -109,21 +108,6 @@ func (c *LogPollerReader) loadOffRamp(addr common.Address) (*evm_2_evm_offramp.E return offRamp, nil } -func (c *LogPollerReader) loadCommitStore(addr common.Address) (*commit_store.CommitStoreFilterer, error) { - commitStore, exists := loadCachedDependency[*commit_store.CommitStoreFilterer](&c.dependencyCache, addr) - if exists { - return commitStore, nil - } - - commitStore, err := commit_store.NewCommitStoreFilterer(addr, c.client) - if err != nil { - return nil, err - } - - c.dependencyCache.Store(addr, commitStore) - return commitStore, nil -} - func loadCachedDependency[T any](cache *sync.Map, addr common.Address) (T, bool) { var empty T diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader.go index 9322ebdf1e..f3747fdcb8 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader.go @@ -94,6 +94,14 @@ type ExecReport struct { ProofFlagBits *big.Int } +type TokenBucketRateLimit struct { + Tokens *big.Int + LastUpdated uint32 + IsEnabled bool + Capacity *big.Int + Rate *big.Int +} + //go:generate mockery --quiet --name OffRampReader --output . --filename offramp_reader_mock.go --inpackage --case=underscore type OffRampReader interface { Closer @@ -107,6 +115,7 @@ type OffRampReader interface { // GetDestinationTokensFromSourceTokens will return an 1:1 mapping of the provided source tokens to dest tokens. // Note that if you provide the same token twice you will get an error, each token should be provided once. GetDestinationTokensFromSourceTokens(ctx context.Context, tokenAddresses []common.Address) ([]common.Address, error) + GetTokenPoolsRateLimits(ctx context.Context, poolAddresses []common.Address) ([]TokenBucketRateLimit, error) GetSupportedTokens(ctx context.Context) ([]common.Address, error) Address() common.Address // TODO Needed for caching, maybe caching should move behind the readers? diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_mock.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_mock.go index d09d97b5f0..e13dbf7084 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_mock.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_mock.go @@ -361,6 +361,32 @@ func (_m *MockOffRampReader) GetSupportedTokens(ctx context.Context) ([]common.A return r0, r1 } +// GetTokenPoolsRateLimits provides a mock function with given fields: ctx, poolAddresses +func (_m *MockOffRampReader) GetTokenPoolsRateLimits(ctx context.Context, poolAddresses []common.Address) ([]TokenBucketRateLimit, error) { + ret := _m.Called(ctx, poolAddresses) + + var r0 []TokenBucketRateLimit + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []common.Address) ([]TokenBucketRateLimit, error)); ok { + return rf(ctx, poolAddresses) + } + if rf, ok := ret.Get(0).(func(context.Context, []common.Address) []TokenBucketRateLimit); ok { + r0 = rf(ctx, poolAddresses) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]TokenBucketRateLimit) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []common.Address) error); ok { + r1 = rf(ctx, poolAddresses) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // OffchainConfig provides a mock function with given fields: func (_m *MockOffRampReader) OffchainConfig() ExecOffchainConfig { ret := _m.Called() diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_test.go index 3e95f5b9f6..f5adb87969 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_test.go @@ -383,4 +383,8 @@ func testOffRampReader(t *testing.T, th offRampReaderTH) { destTokens, err := th.reader.GetDestinationTokensFromSourceTokens(ctx, tokens) require.NoError(t, err) require.Empty(t, destTokens) + + rateLimits, err := th.reader.GetTokenPoolsRateLimits(ctx, []common.Address{}) + require.NoError(t, err) + require.Empty(t, rateLimits) } diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_v1_0_0_unit_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_v1_0_0_unit_test.go index 630f389c52..ea816e9902 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_v1_0_0_unit_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_v1_0_0_unit_test.go @@ -61,9 +61,17 @@ func TestOffRampGetDestinationTokensFromSourceTokens(t *testing.T) { expErr: false, }, { - name: "unexpected output type", + name: "different compatible type", outputChangeFn: func(outputs []rpclib.DataAndErr) []rpclib.DataAndErr { - outputs[0].Outputs = []any{utils.RandomAddress().String()} + outputs[0].Outputs = []any{outputs[0].Outputs[0].(common.Address).String()} + return outputs + }, + expErr: false, + }, + { + name: "different incompatible type", + outputChangeFn: func(outputs []rpclib.DataAndErr) []rpclib.DataAndErr { + outputs[0].Outputs = []any{outputs[0].Outputs[0].(common.Address).Bytes()} return outputs }, expErr: true, diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_v1_0_0.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_v1_0_0.go index 6ce6335081..4aeb96d226 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_v1_0_0.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_v1_0_0.go @@ -16,6 +16,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/custom_token_pool" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_offramp" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/evm_2_evm_offramp_1_0_0" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/router" @@ -37,6 +38,7 @@ const ( var ( abiOffRampV1_0_0 = abihelpers.MustParseABI(evm_2_evm_offramp_1_0_0.EVM2EVMOffRampABI) + abiCustomTokenPool = abihelpers.MustParseABI(custom_token_pool.CustomTokenPoolABI) _ OffRampReader = &OffRampV1_0_0{} ExecutionStateChangedEventV1_0_0 = abihelpers.MustGetEventID("ExecutionStateChanged", abiOffRampV1_0_0) ExecutionStateChangedSeqNrIndexV1_0_0 = 1 @@ -149,29 +151,57 @@ func (o *OffRampV1_0_0) GetDestinationTokensFromSourceTokens(ctx context.Context return nil, fmt.Errorf("batch call limit: %w", err) } + destTokens, err := rpclib.ParseOutputs[common.Address](results, func(d rpclib.DataAndErr) (common.Address, error) { + return rpclib.ParseOutput[common.Address](d, 0) + }) + if err != nil { + return nil, fmt.Errorf("parse outputs: %w", err) + } + seenDestTokens := make(map[common.Address]struct{}) - destTokens := make([]common.Address, 0, len(tokenAddresses)) - for _, res := range results { - if res.Err != nil { - return nil, fmt.Errorf("rpc sub-call: %w", res.Err) + for _, destToken := range destTokens { + if _, exists := seenDestTokens[destToken]; exists { + return nil, fmt.Errorf("offRamp misconfig, destination token %s already exists", destToken) } + seenDestTokens[destToken] = struct{}{} + } - destTokenAddress, err := rpclib.ParseOutput[common.Address](res, 0) - if err != nil { - return nil, err - } - destTokens = append(destTokens, destTokenAddress) + return destTokens, nil +} - if _, exists := seenDestTokens[destTokenAddress]; exists { - return nil, fmt.Errorf("offRamp misconfig, destination token %s already exists", destTokenAddress) - } - seenDestTokens[destTokenAddress] = struct{}{} +func (o *OffRampV1_0_0) GetTokenPoolsRateLimits(ctx context.Context, poolAddresses []common.Address) ([]TokenBucketRateLimit, error) { + if len(poolAddresses) == 0 { + return nil, nil } - if len(destTokens) != len(tokenAddresses) { - return nil, fmt.Errorf("got %d tokens while %d were expected", len(destTokens), len(tokenAddresses)) + evmCalls := make([]rpclib.EvmCall, 0, len(poolAddresses)) + for _, poolAddress := range poolAddresses { + evmCalls = append(evmCalls, rpclib.NewEvmCall( + abiCustomTokenPool, + "currentOffRampRateLimiterState", + poolAddress, + o.addr, + )) } - return destTokens, nil + + latestBlock, err := o.lp.LatestBlock(pg.WithParentCtx(ctx)) + if err != nil { + return nil, fmt.Errorf("get latest block: %w", err) + } + + results, err := o.evmBatchCaller.BatchCall(ctx, uint64(latestBlock), evmCalls) + if err != nil { + return nil, fmt.Errorf("batch call limit: %w", err) + } + + rateLimits, err := rpclib.ParseOutputs[TokenBucketRateLimit](results, func(d rpclib.DataAndErr) (TokenBucketRateLimit, error) { + return rpclib.ParseOutput[TokenBucketRateLimit](d, 0) + }) + if err != nil { + return nil, fmt.Errorf("parse outputs: %w", err) + } + + return rateLimits, nil } func (o *OffRampV1_0_0) GetSupportedTokens(ctx context.Context) ([]common.Address, error) { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader.go index 4b67772d6f..638802bea6 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader.go @@ -56,6 +56,7 @@ type PriceRegistryReader interface { FeeTokenEvents() []common.Hash GetFeeTokens(ctx context.Context) ([]common.Address, error) GetTokenPrices(ctx context.Context, wantedTokens []common.Address) ([]TokenPriceUpdate, error) + GetTokensDecimals(ctx context.Context, tokenAddresses []common.Address) ([]uint8, error) } // NewPriceRegistryReader determines the appropriate version of the price registry and returns a reader for it. diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader_mock.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader_mock.go index 2ebf23e211..50b0489725 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader_mock.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader_mock.go @@ -175,6 +175,32 @@ func (_m *MockPriceRegistryReader) GetTokenPrices(ctx context.Context, wantedTok return r0, r1 } +// GetTokensDecimals provides a mock function with given fields: ctx, tokenAddresses +func (_m *MockPriceRegistryReader) GetTokensDecimals(ctx context.Context, tokenAddresses []common.Address) ([]uint8, error) { + ret := _m.Called(ctx, tokenAddresses) + + var r0 []uint8 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []common.Address) ([]uint8, error)); ok { + return rf(ctx, tokenAddresses) + } + if rf, ok := ret.Get(0).(func(context.Context, []common.Address) []uint8); ok { + r0 = rf(ctx, tokenAddresses) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]uint8) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []common.Address) error); ok { + r1 = rf(ctx, tokenAddresses) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + type mockConstructorTestingTNewMockPriceRegistryReader interface { mock.TestingT Cleanup(func()) diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_v1_0_0.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_v1_0_0.go index 398a3ffc2e..2689be0dfe 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_v1_0_0.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_v1_0_0.go @@ -2,6 +2,7 @@ package ccipdata import ( "context" + "fmt" "math/big" "testing" "time" @@ -14,14 +15,17 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/ccip/generated/price_registry_1_0_0" + "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/shared/generated/erc20" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/abihelpers" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/logpollerutil" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/rpclib" "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) var ( - _ PriceRegistryReader = &PriceRegistryV1_0_0{} + abiERC20 = abihelpers.MustParseABI(erc20.ERC20ABI) + _ PriceRegistryReader = &PriceRegistryV1_0_0{} // Exposed only for backwards compatibility with tests. UsdPerUnitGasUpdatedV1_0_0 = abihelpers.MustGetEventID("UsdPerUnitGasUpdated", abihelpers.MustParseABI(price_registry_1_0_0.PriceRegistryABI)) ) @@ -30,6 +34,7 @@ type PriceRegistryV1_0_0 struct { priceRegistry price_registry_1_0_0.PriceRegistryInterface address common.Address lp logpoller.LogPoller + evmBatchCaller rpclib.EvmBatchCaller lggr logger.Logger filters []logpoller.Filter tokenUpdated common.Hash @@ -136,6 +141,36 @@ func (p *PriceRegistryV1_0_0) GetGasPriceUpdatesCreatedAfter(ctx context.Context ) } +func (p *PriceRegistryV1_0_0) GetTokensDecimals(ctx context.Context, tokenAddresses []common.Address) ([]uint8, error) { + if len(tokenAddresses) == 0 { + return nil, nil + } + + evmCalls := make([]rpclib.EvmCall, 0, len(tokenAddresses)) + for _, tokenAddress := range tokenAddresses { + evmCalls = append(evmCalls, rpclib.NewEvmCall(abiERC20, "decimals", tokenAddress)) + } + + latestBlock, err := p.lp.LatestBlock(pg.WithParentCtx(ctx)) + if err != nil { + return nil, fmt.Errorf("get latest block: %w", err) + } + + results, err := p.evmBatchCaller.BatchCall(ctx, uint64(latestBlock), evmCalls) + if err != nil { + return nil, fmt.Errorf("batch call limit: %w", err) + } + + decimals, err := rpclib.ParseOutputs[uint8](results, func(d rpclib.DataAndErr) (uint8, error) { + return rpclib.ParseOutput[uint8](d, 0) + }) + if err != nil { + return nil, fmt.Errorf("parse outputs: %w", err) + } + + return decimals, nil +} + func NewPriceRegistryV1_0_0(lggr logger.Logger, priceRegistryAddr common.Address, lp logpoller.LogPoller, ec client.Client) (*PriceRegistryV1_0_0, error) { priceRegistry, err := price_registry_1_0_0.NewPriceRegistry(priceRegistryAddr, ec) if err != nil { @@ -166,9 +201,15 @@ func NewPriceRegistryV1_0_0(lggr logger.Logger, priceRegistryAddr common.Address return nil, err } return &PriceRegistryV1_0_0{ - priceRegistry: priceRegistry, - address: priceRegistryAddr, - lp: lp, + priceRegistry: priceRegistry, + address: priceRegistryAddr, + lp: lp, + evmBatchCaller: rpclib.NewDynamicLimitedBatchCaller( + lggr, + ec, + rpclib.DefaultRpcBatchSizeLimit, + rpclib.DefaultRpcBatchBackOffMultiplier, + ), lggr: lggr, gasUpdated: UsdPerUnitGasUpdatedV1_0_0, tokenUpdated: usdPerTokenUpdated, diff --git a/core/services/ocr2/plugins/ccip/internal/rpclib/evm.go b/core/services/ocr2/plugins/ccip/internal/rpclib/evm.go index 07ea313fe0..050815816c 100644 --- a/core/services/ocr2/plugins/ccip/internal/rpclib/evm.go +++ b/core/services/ocr2/plugins/ccip/internal/rpclib/evm.go @@ -2,8 +2,10 @@ package rpclib import ( "context" + "encoding/json" "fmt" "math/big" + "reflect" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -203,20 +205,44 @@ type DataAndErr struct { Err error } +func ParseOutputs[T any](results []DataAndErr, parseFunc func(d DataAndErr) (T, error)) ([]T, error) { + parsed := make([]T, 0, len(results)) + + for _, res := range results { + v, err := parseFunc(res) + if err != nil { + return nil, fmt.Errorf("parse contract output: %w", err) + } + parsed = append(parsed, v) + } + + return parsed, nil +} + func ParseOutput[T any](dataAndErr DataAndErr, idx int) (T, error) { - var empty T + var parsed T if dataAndErr.Err != nil { - return empty, dataAndErr.Err + return parsed, fmt.Errorf("rpc call error: %w", dataAndErr.Err) } if idx < 0 || idx >= len(dataAndErr.Outputs) { - return empty, fmt.Errorf("idx %d is out of bounds for %d outputs", idx, len(dataAndErr.Outputs)) + return parsed, fmt.Errorf("idx %d is out of bounds for %d outputs", idx, len(dataAndErr.Outputs)) } res, is := dataAndErr.Outputs[idx].(T) if !is { - return empty, fmt.Errorf("the result (%T) is not an address", dataAndErr.Outputs[idx]) + // some rpc types are not strictly defined + // for that reason we try to manually map the fields using json encoding + b, err := json.Marshal(dataAndErr.Outputs[idx]) + if err == nil { + var empty T + if err := json.Unmarshal(b, &parsed); err == nil && !reflect.DeepEqual(parsed, empty) { + return parsed, nil + } + } + + return parsed, fmt.Errorf("the result type is: %T, expected: %T", dataAndErr.Outputs[idx], parsed) } return res, nil