Skip to content

Commit

Permalink
Merge branch 'main' into CCIP-4618-use-encoded-sizes-hash
Browse files Browse the repository at this point in the history
  • Loading branch information
asoliman92 authored Dec 18, 2024
2 parents ae2ce38 + ec2d169 commit 872096e
Show file tree
Hide file tree
Showing 9 changed files with 317 additions and 105 deletions.
2 changes: 1 addition & 1 deletion commit/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func (p *Plugin) decodeReport(ctx context.Context, report []byte) (cciptypes.Com
}

func (p *Plugin) isStaleReport(seqNr, latestSeqNr uint64, decodedReport cciptypes.CommitPluginReport) bool {
if seqNr < latestSeqNr && len(decodedReport.MerkleRoots) == 0 {
if seqNr <= latestSeqNr && len(decodedReport.MerkleRoots) == 0 {
p.lggr.Infow("skipping stale report", "seqNr", seqNr, "latestSeqNr", latestSeqNr)
return true
}
Expand Down
53 changes: 53 additions & 0 deletions commit/report_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package commit

import (
rand2 "math/rand"
"testing"

"github.com/smartcontractkit/libocr/commontypes"
Expand Down Expand Up @@ -199,3 +200,55 @@ func TestPluginReports_InvalidOutcome(t *testing.T) {
_, err := p.Reports(tests.Context(t), 0, []byte("invalid json"))
require.Error(t, err)
}

func Test_Plugin_isStaleReport(t *testing.T) {
testCases := []struct {
name string
onChainSeqNum uint64
reportSeqNum uint64
lenMerkleRoots int
shouldBeStale bool
}{
{
name: "report is not stale when merkle roots exist no matter the seq nums",
onChainSeqNum: rand2.Uint64(),
reportSeqNum: rand2.Uint64(),
lenMerkleRoots: 1,
shouldBeStale: false,
},
{
name: "report is stale when onChainSeqNum is equal to report seq num",
onChainSeqNum: 33,
reportSeqNum: 33,
lenMerkleRoots: 0,
shouldBeStale: true,
},
{
name: "report is stale when onChainSeqNum is greater than report seq num",
onChainSeqNum: 34,
reportSeqNum: 33,
lenMerkleRoots: 0,
shouldBeStale: true,
},
{
name: "report is not stale when onChainSeqNum is less than report seq num",
onChainSeqNum: 32,
reportSeqNum: 33,
lenMerkleRoots: 0,
shouldBeStale: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
p := Plugin{
lggr: logger.Test(t),
}
report := ccipocr3.CommitPluginReport{
MerkleRoots: make([]ccipocr3.MerkleRootChain, tc.lenMerkleRoots),
}
stale := p.isStaleReport(tc.reportSeqNum, tc.onChainSeqNum, report)
require.Equal(t, tc.shouldBeStale, stale)
})
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/prometheus/client_golang v1.20.0
github.com/prometheus/client_model v0.6.1
github.com/smartcontractkit/chain-selectors v1.0.34
github.com/smartcontractkit/chainlink-common v0.3.1-0.20241127162636-07aa781ee1f4
github.com/smartcontractkit/chainlink-common v0.3.1-0.20241212163958-6a43e61b9d49
github.com/smartcontractkit/libocr v0.0.0-20241007185508-adbe57025f12
github.com/stretchr/testify v1.9.0
go.uber.org/zap v1.27.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/smartcontractkit/chain-selectors v1.0.34 h1:MJ17OGu8+jjl426pcKrJkCf3fePb3eCreuAnUA3RBj4=
github.com/smartcontractkit/chain-selectors v1.0.34/go.mod h1:xsKM0aN3YGcQKTPRPDDtPx2l4mlTN1Djmg0VVXV40b8=
github.com/smartcontractkit/chainlink-common v0.3.1-0.20241127162636-07aa781ee1f4 h1:atCZ1jol7a+tdtgU/wNqXgliBun5H7BjGBicGL8Tj6o=
github.com/smartcontractkit/chainlink-common v0.3.1-0.20241127162636-07aa781ee1f4/go.mod h1:bQktEJf7sJ0U3SmIcXvbGUox7SmXcnSEZ4kUbT8R5Nk=
github.com/smartcontractkit/chainlink-common v0.3.1-0.20241212163958-6a43e61b9d49 h1:ZA92CTX9JtEArrxgZw7PNctVxFS+/DmSXumkwf1WiMY=
github.com/smartcontractkit/chainlink-common v0.3.1-0.20241212163958-6a43e61b9d49/go.mod h1:bQktEJf7sJ0U3SmIcXvbGUox7SmXcnSEZ4kUbT8R5Nk=
github.com/smartcontractkit/libocr v0.0.0-20241007185508-adbe57025f12 h1:NzZGjaqez21I3DU7objl3xExTH4fxYvzTqar8DC6360=
github.com/smartcontractkit/libocr v0.0.0-20241007185508-adbe57025f12/go.mod h1:fb1ZDVXACvu4frX3APHZaEBp0xi1DIm34DcA0CwTsZM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
4 changes: 4 additions & 0 deletions internal/plugincommon/transmitters.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package plugincommon

import (
"fmt"
"sort"
"time"

"github.com/smartcontractkit/libocr/commontypes"
Expand Down Expand Up @@ -33,6 +34,9 @@ func GetTransmissionSchedule(
}
}

// transmissionSchedule must be deterministic
sort.Slice(transmitters, func(i, j int) bool { return transmitters[i] < transmitters[j] })

transmissionDelays := make([]time.Duration, len(transmitters))

for i := range transmissionDelays {
Expand Down
9 changes: 9 additions & 0 deletions internal/plugincommon/transmitters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ func TestGetTransmissionSchedule(t *testing.T) {
expectedError: true,
chainSupportReturnsError: true,
},
{
name: "determinism check",
allTheOracles: []commontypes.OracleID{3, 1, 2}, // <------ not ordered
oraclesSupportingDest: []commontypes.OracleID{1, 3},
transmissionDelayMultiplier: 5 * time.Second,
expectedTransmitters: []commontypes.OracleID{1, 3},
expectedTransmissionDelays: []time.Duration{5 * time.Second, 10 * time.Second},
expectedError: false,
},
}

for _, tc := range testCases {
Expand Down
178 changes: 108 additions & 70 deletions pkg/reader/price_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"fmt"
"math/big"

"golang.org/x/sync/errgroup"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
commontypes "github.com/smartcontractkit/chainlink-common/pkg/types"
"github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives"
Expand Down Expand Up @@ -71,6 +69,12 @@ type LatestRoundData struct {
AnsweredInRound *big.Int
}

// ContractTokenMap maps contracts to their token indices
type ContractTokenMap map[commontypes.BoundContract][]int

// Number of batch operations performed (getLatestRoundData and getDecimals)
const priceReaderOperationCount = 2

func (pr *priceReader) GetFeeQuoterTokenUpdates(
ctx context.Context,
tokens []ccipocr3.UnknownEncodedAddress,
Expand Down Expand Up @@ -142,99 +146,133 @@ func (pr *priceReader) GetFeeQuoterTokenUpdates(
return updateMap, nil
}

// GetFeedPricesUSD gets USD prices for multiple tokens using batch requests
func (pr *priceReader) GetFeedPricesUSD(
ctx context.Context, tokens []ccipocr3.UnknownEncodedAddress,
ctx context.Context,
tokens []ccipocr3.UnknownEncodedAddress,
) ([]*big.Int, error) {
prices := make([]*big.Int, len(tokens))
if pr.feedChainReader() == nil {
pr.lggr.Debug("node does not support feed chain")
return prices, nil
}
eg := new(errgroup.Group)
for idx, token := range tokens {
eg.Go(func() error {
boundContract := commontypes.BoundContract{
Address: string(pr.tokenInfo[token].AggregatorAddress),
Name: consts.ContractNamePriceAggregator,
}
rawTokenPrice, err := pr.getRawTokenPriceE18Normalized(ctx, token, boundContract, pr.feedChainReader())
if err != nil {
return fmt.Errorf("token price for %s: %w", token, err)
}
tokenInfo, ok := pr.tokenInfo[token]
if !ok {
return fmt.Errorf("get tokenInfo for %s: %w", token, err)
}

prices[idx] = calculateUsdPer1e18TokenAmount(rawTokenPrice, tokenInfo.Decimals)
return nil
})
// Create batch request grouped by contract
batchRequest, contractTokenMap, err := pr.prepareBatchRequest(tokens)
if err != nil {
return nil, fmt.Errorf("prepare batch request: %w", err)
}

if err := eg.Wait(); err != nil {
return nil, fmt.Errorf("failed to get all token prices successfully: %w", err)
// Execute batch request
results, err := pr.feedChainReader().BatchGetLatestValues(ctx, batchRequest)
if err != nil {
return nil, fmt.Errorf("batch request failed: %w", err)
}

for _, price := range prices {
if price == nil {
return nil, fmt.Errorf("failed to get all token prices successfully, some prices are nil")
// Process results by contract
for boundContract, tokenIndices := range contractTokenMap {
contractResults, ok := results[boundContract]
if !ok || len(contractResults) != priceReaderOperationCount {
return nil, fmt.Errorf("invalid results for contract %s", boundContract.Address)
}

// Get price data
priceResult, err := contractResults[0].GetResult()
if err != nil {
return nil, fmt.Errorf("get price for contract %s: %w", boundContract.Address, err)
}
latestRoundData, ok := priceResult.(*LatestRoundData)
if !ok {
return nil, fmt.Errorf("invalid price data type for contract %s", boundContract.Address)
}

// Get decimals
decimalResult, err := contractResults[1].GetResult()
if err != nil {
return nil, fmt.Errorf("get decimals for contract %s: %w", boundContract.Address, err)
}
decimals, ok := decimalResult.(*uint8)
if !ok {
return nil, fmt.Errorf("invalid decimals data type for contract %s", boundContract.Address)
}

// Normalize price for this contract
normalizedContractPrice := pr.normalizePrice(latestRoundData.Answer, *decimals)

// Apply the normalized price to all tokens using this contract
for _, tokenIdx := range tokenIndices {
token := tokens[tokenIdx]
tokenInfo := pr.tokenInfo[token]
prices[tokenIdx] = calculateUsdPer1e18TokenAmount(normalizedContractPrice, tokenInfo.Decimals)
}
}

// Verify no nil prices
if err := pr.validatePrices(prices); err != nil {
return nil, err
}

return prices, nil
}

func (pr *priceReader) getFeedDecimals(
ctx context.Context,
token ccipocr3.UnknownEncodedAddress,
boundContract commontypes.BoundContract,
feedChainReader contractreader.ContractReaderFacade,
) (uint8, error) {
var decimals uint8
if err :=
feedChainReader.GetLatestValue(
ctx,
boundContract.ReadIdentifier(consts.MethodNameGetDecimals),
primitives.Unconfirmed,
nil,
&decimals,
); err != nil {
return 0, fmt.Errorf("decimals call failed for token %s: %w", token, err)
// prepareBatchRequest creates a batch request grouped by contract and returns the mapping of contracts to token indices
func (pr *priceReader) prepareBatchRequest(
tokens []ccipocr3.UnknownEncodedAddress,
) (commontypes.BatchGetLatestValuesRequest, ContractTokenMap, error) {
batchRequest := make(commontypes.BatchGetLatestValuesRequest)
contractTokenMap := make(ContractTokenMap)

for i, token := range tokens {
tokenInfo, ok := pr.tokenInfo[token]
if !ok {
return nil, nil, fmt.Errorf("get tokenInfo for %s: missing token info", token)
}

boundContract := commontypes.BoundContract{
Address: string(tokenInfo.AggregatorAddress),
Name: consts.ContractNamePriceAggregator,
}

// Initialize contract batch if it doesn't exist
if _, exists := batchRequest[boundContract]; !exists {
batchRequest[boundContract] = make(commontypes.ContractBatch, priceReaderOperationCount)
batchRequest[boundContract][0] = commontypes.BatchRead{
ReadName: consts.MethodNameGetLatestRoundData,
Params: nil,
ReturnVal: &LatestRoundData{},
}
batchRequest[boundContract][1] = commontypes.BatchRead{
ReadName: consts.MethodNameGetDecimals,
Params: nil,
ReturnVal: new(uint8),
}
}

// Track which tokens use this contract
contractTokenMap[boundContract] = append(contractTokenMap[boundContract], i)
}

return decimals, nil
return batchRequest, contractTokenMap, nil
}

func (pr *priceReader) getRawTokenPriceE18Normalized(
ctx context.Context,
token ccipocr3.UnknownEncodedAddress,
boundContract commontypes.BoundContract,
feedChainReader contractreader.ContractReaderFacade,
) (*big.Int, error) {
var latestRoundData LatestRoundData
identifier := boundContract.ReadIdentifier(consts.MethodNameGetLatestRoundData)
if err :=
feedChainReader.GetLatestValue(
ctx,
identifier,
primitives.Unconfirmed,
nil,
&latestRoundData,
); err != nil {
return nil, fmt.Errorf("latestRoundData call failed for token %s: %w", token, err)
func (pr *priceReader) normalizePrice(price *big.Int, decimals uint8) *big.Int {
answer := new(big.Int).Set(price)
if decimals < 18 {
return answer.Mul(answer, big.NewInt(0).Exp(big.NewInt(10), big.NewInt(18-int64(decimals)), nil))
}

decimals, err1 := pr.getFeedDecimals(ctx, token, boundContract, feedChainReader)
if err1 != nil {
return nil, fmt.Errorf("failed to get decimals for token %s: %w", token, err1)
if decimals > 18 {
return answer.Div(answer, big.NewInt(0).Exp(big.NewInt(10), big.NewInt(int64(decimals)-18), nil))
}
answer := latestRoundData.Answer
if decimals < 18 {
answer.Mul(answer, big.NewInt(0).Exp(big.NewInt(10), big.NewInt(18-int64(decimals)), nil))
} else if decimals > 18 {
answer.Div(answer, big.NewInt(0).Exp(big.NewInt(10), big.NewInt(int64(decimals)-18), nil))
return answer
}

func (pr *priceReader) validatePrices(prices []*big.Int) error {
for _, price := range prices {
if price == nil {
return fmt.Errorf("failed to get all token prices successfully, some prices are nil")
}
}
return answer, nil
return nil
}

func (pr *priceReader) feedChainReader() contractreader.ContractReaderFacade {
Expand Down
Loading

0 comments on commit 872096e

Please sign in to comment.