diff --git a/go.mod b/go.mod index 7e2a6454..a297202a 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/prometheus/client_golang v1.20.0 github.com/prometheus/client_model v0.6.1 github.com/smartcontractkit/chain-selectors v1.0.34 - github.com/smartcontractkit/chainlink-common v0.3.1-0.20241127162636-07aa781ee1f4 + github.com/smartcontractkit/chainlink-common v0.3.1-0.20241212163958-6a43e61b9d49 github.com/smartcontractkit/libocr v0.0.0-20241007185508-adbe57025f12 github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 diff --git a/go.sum b/go.sum index 84713ae1..4fc29736 100644 --- a/go.sum +++ b/go.sum @@ -71,8 +71,8 @@ github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/smartcontractkit/chain-selectors v1.0.34 h1:MJ17OGu8+jjl426pcKrJkCf3fePb3eCreuAnUA3RBj4= github.com/smartcontractkit/chain-selectors v1.0.34/go.mod h1:xsKM0aN3YGcQKTPRPDDtPx2l4mlTN1Djmg0VVXV40b8= -github.com/smartcontractkit/chainlink-common v0.3.1-0.20241127162636-07aa781ee1f4 h1:atCZ1jol7a+tdtgU/wNqXgliBun5H7BjGBicGL8Tj6o= -github.com/smartcontractkit/chainlink-common v0.3.1-0.20241127162636-07aa781ee1f4/go.mod h1:bQktEJf7sJ0U3SmIcXvbGUox7SmXcnSEZ4kUbT8R5Nk= +github.com/smartcontractkit/chainlink-common v0.3.1-0.20241212163958-6a43e61b9d49 h1:ZA92CTX9JtEArrxgZw7PNctVxFS+/DmSXumkwf1WiMY= +github.com/smartcontractkit/chainlink-common v0.3.1-0.20241212163958-6a43e61b9d49/go.mod h1:bQktEJf7sJ0U3SmIcXvbGUox7SmXcnSEZ4kUbT8R5Nk= github.com/smartcontractkit/libocr v0.0.0-20241007185508-adbe57025f12 h1:NzZGjaqez21I3DU7objl3xExTH4fxYvzTqar8DC6360= github.com/smartcontractkit/libocr v0.0.0-20241007185508-adbe57025f12/go.mod h1:fb1ZDVXACvu4frX3APHZaEBp0xi1DIm34DcA0CwTsZM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pkg/reader/price_reader.go b/pkg/reader/price_reader.go index a14f0311..e9359d31 100644 --- a/pkg/reader/price_reader.go +++ b/pkg/reader/price_reader.go @@ -5,8 +5,6 @@ import ( "fmt" "math/big" - "golang.org/x/sync/errgroup" - "github.com/smartcontractkit/chainlink-common/pkg/logger" commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" @@ -71,6 +69,12 @@ type LatestRoundData struct { AnsweredInRound *big.Int } +// ContractTokenMap maps contracts to their token indices +type ContractTokenMap map[commontypes.BoundContract][]int + +// Number of batch operations performed (getLatestRoundData and getDecimals) +const priceReaderOperationCount = 2 + func (pr *priceReader) GetFeeQuoterTokenUpdates( ctx context.Context, tokens []ccipocr3.UnknownEncodedAddress, @@ -142,99 +146,133 @@ func (pr *priceReader) GetFeeQuoterTokenUpdates( return updateMap, nil } +// GetFeedPricesUSD gets USD prices for multiple tokens using batch requests func (pr *priceReader) GetFeedPricesUSD( - ctx context.Context, tokens []ccipocr3.UnknownEncodedAddress, + ctx context.Context, + tokens []ccipocr3.UnknownEncodedAddress, ) ([]*big.Int, error) { prices := make([]*big.Int, len(tokens)) if pr.feedChainReader() == nil { pr.lggr.Debug("node does not support feed chain") return prices, nil } - eg := new(errgroup.Group) - for idx, token := range tokens { - eg.Go(func() error { - boundContract := commontypes.BoundContract{ - Address: string(pr.tokenInfo[token].AggregatorAddress), - Name: consts.ContractNamePriceAggregator, - } - rawTokenPrice, err := pr.getRawTokenPriceE18Normalized(ctx, token, boundContract, pr.feedChainReader()) - if err != nil { - return fmt.Errorf("token price for %s: %w", token, err) - } - tokenInfo, ok := pr.tokenInfo[token] - if !ok { - return fmt.Errorf("get tokenInfo for %s: %w", token, err) - } - prices[idx] = calculateUsdPer1e18TokenAmount(rawTokenPrice, tokenInfo.Decimals) - return nil - }) + // Create batch request grouped by contract + batchRequest, contractTokenMap, err := pr.prepareBatchRequest(tokens) + if err != nil { + return nil, fmt.Errorf("prepare batch request: %w", err) } - if err := eg.Wait(); err != nil { - return nil, fmt.Errorf("failed to get all token prices successfully: %w", err) + // Execute batch request + results, err := pr.feedChainReader().BatchGetLatestValues(ctx, batchRequest) + if err != nil { + return nil, fmt.Errorf("batch request failed: %w", err) } - for _, price := range prices { - if price == nil { - return nil, fmt.Errorf("failed to get all token prices successfully, some prices are nil") + // Process results by contract + for boundContract, tokenIndices := range contractTokenMap { + contractResults, ok := results[boundContract] + if !ok || len(contractResults) != priceReaderOperationCount { + return nil, fmt.Errorf("invalid results for contract %s", boundContract.Address) } + + // Get price data + priceResult, err := contractResults[0].GetResult() + if err != nil { + return nil, fmt.Errorf("get price for contract %s: %w", boundContract.Address, err) + } + latestRoundData, ok := priceResult.(*LatestRoundData) + if !ok { + return nil, fmt.Errorf("invalid price data type for contract %s", boundContract.Address) + } + + // Get decimals + decimalResult, err := contractResults[1].GetResult() + if err != nil { + return nil, fmt.Errorf("get decimals for contract %s: %w", boundContract.Address, err) + } + decimals, ok := decimalResult.(*uint8) + if !ok { + return nil, fmt.Errorf("invalid decimals data type for contract %s", boundContract.Address) + } + + // Normalize price for this contract + normalizedContractPrice := pr.normalizePrice(latestRoundData.Answer, *decimals) + + // Apply the normalized price to all tokens using this contract + for _, tokenIdx := range tokenIndices { + token := tokens[tokenIdx] + tokenInfo := pr.tokenInfo[token] + prices[tokenIdx] = calculateUsdPer1e18TokenAmount(normalizedContractPrice, tokenInfo.Decimals) + } + } + + // Verify no nil prices + if err := pr.validatePrices(prices); err != nil { + return nil, err } return prices, nil } -func (pr *priceReader) getFeedDecimals( - ctx context.Context, - token ccipocr3.UnknownEncodedAddress, - boundContract commontypes.BoundContract, - feedChainReader contractreader.ContractReaderFacade, -) (uint8, error) { - var decimals uint8 - if err := - feedChainReader.GetLatestValue( - ctx, - boundContract.ReadIdentifier(consts.MethodNameGetDecimals), - primitives.Unconfirmed, - nil, - &decimals, - ); err != nil { - return 0, fmt.Errorf("decimals call failed for token %s: %w", token, err) +// prepareBatchRequest creates a batch request grouped by contract and returns the mapping of contracts to token indices +func (pr *priceReader) prepareBatchRequest( + tokens []ccipocr3.UnknownEncodedAddress, +) (commontypes.BatchGetLatestValuesRequest, ContractTokenMap, error) { + batchRequest := make(commontypes.BatchGetLatestValuesRequest) + contractTokenMap := make(ContractTokenMap) + + for i, token := range tokens { + tokenInfo, ok := pr.tokenInfo[token] + if !ok { + return nil, nil, fmt.Errorf("get tokenInfo for %s: missing token info", token) + } + + boundContract := commontypes.BoundContract{ + Address: string(tokenInfo.AggregatorAddress), + Name: consts.ContractNamePriceAggregator, + } + + // Initialize contract batch if it doesn't exist + if _, exists := batchRequest[boundContract]; !exists { + batchRequest[boundContract] = make(commontypes.ContractBatch, priceReaderOperationCount) + batchRequest[boundContract][0] = commontypes.BatchRead{ + ReadName: consts.MethodNameGetLatestRoundData, + Params: nil, + ReturnVal: &LatestRoundData{}, + } + batchRequest[boundContract][1] = commontypes.BatchRead{ + ReadName: consts.MethodNameGetDecimals, + Params: nil, + ReturnVal: new(uint8), + } + } + + // Track which tokens use this contract + contractTokenMap[boundContract] = append(contractTokenMap[boundContract], i) } - return decimals, nil + return batchRequest, contractTokenMap, nil } -func (pr *priceReader) getRawTokenPriceE18Normalized( - ctx context.Context, - token ccipocr3.UnknownEncodedAddress, - boundContract commontypes.BoundContract, - feedChainReader contractreader.ContractReaderFacade, -) (*big.Int, error) { - var latestRoundData LatestRoundData - identifier := boundContract.ReadIdentifier(consts.MethodNameGetLatestRoundData) - if err := - feedChainReader.GetLatestValue( - ctx, - identifier, - primitives.Unconfirmed, - nil, - &latestRoundData, - ); err != nil { - return nil, fmt.Errorf("latestRoundData call failed for token %s: %w", token, err) +func (pr *priceReader) normalizePrice(price *big.Int, decimals uint8) *big.Int { + answer := new(big.Int).Set(price) + if decimals < 18 { + return answer.Mul(answer, big.NewInt(0).Exp(big.NewInt(10), big.NewInt(18-int64(decimals)), nil)) } - - decimals, err1 := pr.getFeedDecimals(ctx, token, boundContract, feedChainReader) - if err1 != nil { - return nil, fmt.Errorf("failed to get decimals for token %s: %w", token, err1) + if decimals > 18 { + return answer.Div(answer, big.NewInt(0).Exp(big.NewInt(10), big.NewInt(int64(decimals)-18), nil)) } - answer := latestRoundData.Answer - if decimals < 18 { - answer.Mul(answer, big.NewInt(0).Exp(big.NewInt(10), big.NewInt(18-int64(decimals)), nil)) - } else if decimals > 18 { - answer.Div(answer, big.NewInt(0).Exp(big.NewInt(10), big.NewInt(int64(decimals)-18), nil)) + return answer +} + +func (pr *priceReader) validatePrices(prices []*big.Int) error { + for _, price := range prices { + if price == nil { + return fmt.Errorf("failed to get all token prices successfully, some prices are nil") + } } - return answer, nil + return nil } func (pr *priceReader) feedChainReader() contractreader.ContractReaderFacade { diff --git a/pkg/reader/price_reader_test.go b/pkg/reader/price_reader_test.go index 0e35edaa..2c4bfb98 100644 --- a/pkg/reader/price_reader_test.go +++ b/pkg/reader/price_reader_test.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/mock" commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" - "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" readermock "github.com/smartcontractkit/chainlink-ccip/mocks/pkg/contractreader" "github.com/smartcontractkit/chainlink-ccip/pkg/consts" @@ -27,6 +26,9 @@ const ( EthAddr = cciptypes.UnknownEncodedAddress("0xe100000000000000000000000000000000000000") EthAggregatorAddr = cciptypes.UnknownEncodedAddress("0xe200000000000000000000000000000000000000") + + BtcAddr = cciptypes.UnknownEncodedAddress("0xb100000000000000000000000000000000000000") + BtcAgregatorAddr = cciptypes.UnknownEncodedAddress("0xb200000000000000000000000000000000000000") ) var ( @@ -44,6 +46,11 @@ var ( DeviationPPB: cciptypes.NewBigInt(big.NewInt(1e5)), Decimals: Decimals18, } + BtcInfo = pluginconfig.TokenInfo{ + AggregatorAddress: BtcAgregatorAddr, + DeviationPPB: cciptypes.NewBigInt(big.NewInt(1e5)), + Decimals: Decimals18, + } ) func TestOnchainTokenPricesReader_GetTokenPricesUSD(t *testing.T) { @@ -87,6 +94,46 @@ func TestOnchainTokenPricesReader_GetTokenPricesUSD(t *testing.T) { want: nil, wantErr: true, }, + { + name: "Empty input tokens list", + tokenInfo: map[cciptypes.UnknownEncodedAddress]pluginconfig.TokenInfo{ + ArbAddr: ArbInfo, + }, + inputTokens: []cciptypes.UnknownEncodedAddress{}, + mockPrices: map[cciptypes.UnknownEncodedAddress]*big.Int{}, + want: []*big.Int{}, + }, + { + name: "Repeated token in input", + tokenInfo: map[cciptypes.UnknownEncodedAddress]pluginconfig.TokenInfo{ + ArbAddr: ArbInfo, + }, + inputTokens: []cciptypes.UnknownEncodedAddress{ArbAddr, ArbAddr}, + mockPrices: map[cciptypes.UnknownEncodedAddress]*big.Int{ArbAddr: ArbPrice}, + want: []*big.Int{ArbPrice, ArbPrice}, + }, + { + name: "Zero price should succeed", + tokenInfo: map[cciptypes.UnknownEncodedAddress]pluginconfig.TokenInfo{ + ArbAddr: ArbInfo, + }, + inputTokens: []cciptypes.UnknownEncodedAddress{ArbAddr}, + mockPrices: map[cciptypes.UnknownEncodedAddress]*big.Int{ArbAddr: big.NewInt(0)}, + want: []*big.Int{big.NewInt(0)}, + }, + { + name: "Multiple error accounts", + tokenInfo: map[cciptypes.UnknownEncodedAddress]pluginconfig.TokenInfo{ + ArbAddr: ArbInfo, + EthAddr: EthInfo, + BtcAddr: BtcInfo, + }, + inputTokens: []cciptypes.UnknownEncodedAddress{ArbAddr, EthAddr, BtcAddr}, + mockPrices: map[cciptypes.UnknownEncodedAddress]*big.Int{ArbAddr: ArbPrice}, + errorAccounts: []cciptypes.UnknownEncodedAddress{EthAddr, BtcAddr}, + want: nil, + wantErr: true, + }, } for _, tc := range testCases { @@ -161,6 +208,11 @@ func createMockReader( ) *readermock.MockContractReaderFacade { reader := readermock.NewMockContractReaderFacade(t) + // Create the expected batch request and results + expectedRequest := make(commontypes.BatchGetLatestValuesRequest) + expectedResults := make(commontypes.BatchGetLatestValuesResult) + + // Handle successful cases for token, price := range mockPrices { info := tokenInfo[token] boundContract := commontypes.BoundContract{ @@ -168,43 +220,84 @@ func createMockReader( Name: consts.ContractNamePriceAggregator, } - identifier := boundContract.ReadIdentifier(consts.MethodNameGetLatestRoundData) - reader.On("GetLatestValue", - mock.Anything, - identifier, - primitives.Unconfirmed, - nil, - mock.Anything).Run( - func(args mock.Arguments) { - arg := args.Get(4).(*LatestRoundData) - arg.Answer = big.NewInt(price.Int64()) - }).Return(nil).Once() - - reader.On("GetLatestValue", - mock.Anything, - boundContract.ReadIdentifier(consts.MethodNameGetDecimals), - primitives.Unconfirmed, - nil, - mock.Anything).Run( - func(args mock.Arguments) { - arg := args.Get(4).(*uint8) - *arg = info.Decimals - }).Return(nil) + // Add to expected request + if _, exists := expectedRequest[boundContract]; !exists { + expectedRequest[boundContract] = make(commontypes.ContractBatch, 0) + } + expectedRequest[boundContract] = append(expectedRequest[boundContract], + commontypes.BatchRead{ + ReadName: consts.MethodNameGetLatestRoundData, + Params: nil, + ReturnVal: &LatestRoundData{}, + }, + commontypes.BatchRead{ + ReadName: consts.MethodNameGetDecimals, + Params: nil, + ReturnVal: new(uint8), + }, + ) + + // Create results + results := make(commontypes.ContractBatchResults, 2) + // Price result + priceResult := commontypes.BatchReadResult{ReadName: consts.MethodNameGetLatestRoundData} + priceResult.SetResult(&LatestRoundData{Answer: big.NewInt(price.Int64())}, nil) + results[0] = priceResult + + // Decimals result + decimalsResult := commontypes.BatchReadResult{ReadName: consts.MethodNameGetDecimals} + decimals := info.Decimals + decimalsResult.SetResult(&decimals, nil) + results[1] = decimalsResult + + expectedResults[boundContract] = results } + // Handle error cases for _, account := range errorAccounts { info := tokenInfo[account] boundContract := commontypes.BoundContract{ Address: string(info.AggregatorAddress), Name: consts.ContractNamePriceAggregator, } - reader.On("GetLatestValue", - mock.Anything, - boundContract.ReadIdentifier(consts.MethodNameGetLatestRoundData), - primitives.Unconfirmed, - nil, - mock.Anything).Return(fmt.Errorf("error")).Once() + + results := make(commontypes.ContractBatchResults, 2) + // Price result with error + priceResult := commontypes.BatchReadResult{ReadName: consts.MethodNameGetLatestRoundData} + priceResult.SetResult(nil, fmt.Errorf("error")) + results[0] = priceResult + + // Decimals result + decimalsResult := commontypes.BatchReadResult{ReadName: consts.MethodNameGetDecimals} + decimalsResult.SetResult(nil, nil) + results[1] = decimalsResult + + expectedResults[boundContract] = results } + // Set up the mock expectation for BatchGetLatestValues + reader.On("BatchGetLatestValues", + mock.Anything, + mock.MatchedBy(func(req commontypes.BatchGetLatestValuesRequest) bool { + // Validate request structure + for boundContract, batch := range req { + // Verify contract has exactly two reads (price and decimals) + if len(batch) != 2 { + return false + } + // Verify read names + if batch[0].ReadName != consts.MethodNameGetLatestRoundData || + batch[1].ReadName != consts.MethodNameGetDecimals { + return false + } + // Verify contract exists in our expected results + if _, exists := expectedResults[boundContract]; !exists { + return false + } + } + return true + }), + ).Return(expectedResults, nil).Once() + return reader }