Skip to content

Commit

Permalink
Merge branch 'main' into solana/specify-writable-pool-accounts
Browse files Browse the repository at this point in the history
  • Loading branch information
aalu1418 committed Jan 9, 2025
2 parents b381e85 + 9cc1076 commit 4eba5bd
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 67 deletions.
20 changes: 13 additions & 7 deletions commit/plugin_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,10 @@ func TestPlugin_E2E_AllNodesAgree_TokenPrices(t *testing.T) {
m.EXPECT().
// tokens need to be ordered, plugin checks all tokens from commit offchain config
GetFeedPricesUSD(params.ctx, []ccipocr3.UnknownEncodedAddress{arbAddr, ethAddr}).
Return([]*big.Int{arbPrice, ethPrice}, nil).
Maybe()
Return(map[ccipocr3.UnknownEncodedAddress]*big.Int{
arbAddr: arbPrice,
ethAddr: ethPrice,
}, nil).Maybe()

m.EXPECT().
GetFeeQuoterTokenUpdates(params.ctx, mock.Anything, mock.Anything).
Expand Down Expand Up @@ -331,7 +333,10 @@ func TestPlugin_E2E_AllNodesAgree_TokenPrices(t *testing.T) {
m.EXPECT().
// tokens need to be ordered, plugin checks all tokens from commit offchain config
GetFeedPricesUSD(params.ctx, []ccipocr3.UnknownEncodedAddress{arbAddr, ethAddr}).
Return([]*big.Int{arbPrice, ethPrice}, nil).
Return(map[ccipocr3.UnknownEncodedAddress]*big.Int{
arbAddr: arbPrice,
ethAddr: ethPrice,
}, nil).
Maybe()

// Arb is fresh, will not be updated
Expand All @@ -357,8 +362,10 @@ func TestPlugin_E2E_AllNodesAgree_TokenPrices(t *testing.T) {
m.EXPECT().
// tokens need to be ordered, plugin checks all tokens from commit offchain config
GetFeedPricesUSD(params.ctx, []ccipocr3.UnknownEncodedAddress{arbAddr, ethAddr}).
Return([]*big.Int{arbPrice, ethPrice}, nil).
Maybe()
Return(map[ccipocr3.UnknownEncodedAddress]*big.Int{
arbAddr: arbPrice,
ethAddr: ethPrice,
}, nil).Maybe()

m.EXPECT().
GetFeeQuoterTokenUpdates(params.ctx, mock.Anything, mock.Anything).
Expand Down Expand Up @@ -713,8 +720,7 @@ func preparePriceReaderMock(ctx context.Context, priceReader *readerpkg_mock.Moc

priceReader.EXPECT().
GetFeedPricesUSD(ctx, mock.Anything).
Return([]*big.Int{}, nil).
Maybe()
Return(map[ccipocr3.UnknownEncodedAddress]*big.Int{}, nil).Maybe()
}

type nodeSetup struct {
Expand Down
12 changes: 4 additions & 8 deletions commit/tokenprice/observation.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,11 @@ func (p *processor) ObserveFeedTokenPrices(ctx context.Context) []cciptypes.Toke
return []cciptypes.TokenPrice{}
}

// If we couldn't fetch all prices log and return only the ones we could fetch
if len(tokenPrices) != len(tokensToQuery) {
p.lggr.Errorw("token prices length mismatch", "got", tokenPrices, "want", tokensToQuery)
return []cciptypes.TokenPrice{}
}

tokenPricesUSD := make([]cciptypes.TokenPrice, 0, len(tokenPrices))
for i, token := range tokensToQuery {
tokenPricesUSD = append(tokenPricesUSD, cciptypes.NewTokenPrice(token, tokenPrices[i]))
for _, token := range tokensToQuery {
if tokenPrices[token] != nil {
tokenPricesUSD = append(tokenPricesUSD, cciptypes.NewTokenPrice(token, tokenPrices[token]))
}
}

return tokenPricesUSD
Expand Down
4 changes: 3 additions & 1 deletion commit/tokenprice/observation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ func Test_Observation(t *testing.T) {

tokenPriceReader := readerpkg_mock.NewMockPriceReader(t)
tokenPriceReader.EXPECT().GetFeedPricesUSD(mock.Anything, []cciptypes.UnknownEncodedAddress{tokenA, tokenB}).
Return([]*big.Int{bi100, bi200}, nil)
Return(map[cciptypes.UnknownEncodedAddress]*big.Int{
tokenA: bi100,
tokenB: bi200}, nil)

tokenPriceReader.EXPECT().GetFeeQuoterTokenUpdates(mock.Anything, mock.Anything, mock.Anything).Return(
map[cciptypes.UnknownEncodedAddress]plugintypes.TimestampedBig{
Expand Down
14 changes: 7 additions & 7 deletions mocks/pkg/reader/price_reader.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 28 additions & 6 deletions pkg/contractreader/extended.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"sync"
"time"

"github.com/smartcontractkit/chainlink-ccip/pkg/consts"

"github.com/smartcontractkit/chainlink-common/pkg/services"
"github.com/smartcontractkit/chainlink-common/pkg/types"
clcommontypes "github.com/smartcontractkit/chainlink-common/pkg/types"
Expand Down Expand Up @@ -95,7 +97,9 @@ type ExtendedBoundContract struct {
type extendedContractReader struct {
reader ContractReaderFacade
contractBindingsByName map[string][]ExtendedBoundContract
mu *sync.RWMutex
// contract names that allow multiple bindings
multiBindAllowed map[string]bool
mu *sync.RWMutex
}

func NewExtendedContractReader(baseContractReader ContractReaderFacade) Extended {
Expand All @@ -106,7 +110,10 @@ func NewExtendedContractReader(baseContractReader ContractReaderFacade) Extended
return &extendedContractReader{
reader: baseContractReader,
contractBindingsByName: make(map[string][]ExtendedBoundContract),
mu: &sync.RWMutex{},
// so far this is the only contract that allows multiple bindings
// if more contracts are added, this should be moved to a config
multiBindAllowed: map[string]bool{consts.ContractNamePriceAggregator: true},
mu: &sync.RWMutex{},
}
}

Expand Down Expand Up @@ -261,10 +268,25 @@ func (e *extendedContractReader) Bind(ctx context.Context, allBindings []types.B
e.mu.Lock()
defer e.mu.Unlock()
for _, binding := range validBindings {
e.contractBindingsByName[binding.Name] = append(e.contractBindingsByName[binding.Name], ExtendedBoundContract{
BoundAt: time.Now(),
Binding: binding,
})
if e.multiBindAllowed[binding.Name] {
e.contractBindingsByName[binding.Name] = append(e.contractBindingsByName[binding.Name], ExtendedBoundContract{
BoundAt: time.Now(),
Binding: binding,
})
} else {
if len(e.contractBindingsByName[binding.Name]) > 0 {
// Unbind the previous binding
err := e.reader.Unbind(ctx, []types.BoundContract{e.contractBindingsByName[binding.Name][0].Binding})
if err != nil {
return fmt.Errorf("failed to unbind previous binding: %w", err)
}
}
// Override the previous binding
e.contractBindingsByName[binding.Name] = []ExtendedBoundContract{{
BoundAt: time.Now(),
Binding: binding,
}}
}
}

return nil
Expand Down
38 changes: 38 additions & 0 deletions pkg/contractreader/extended_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"testing"

"github.com/smartcontractkit/chainlink-ccip/pkg/consts"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand All @@ -26,6 +28,42 @@ func TestExtendedContractReader(t *testing.T) {
bindings := extCr.GetBindings(contractName)
assert.Len(t, bindings, 0)

cr.On("Bind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x123"}}).Return(nil)
cr.On("Unbind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x123"}}).Return(nil)
cr.On("Bind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x124"}}).Return(nil)
cr.On("Bind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x125"}}).Return(fmt.Errorf("some err"))

err := extCr.Bind(context.Background(), []types.BoundContract{{Name: contractName, Address: "0x123"}})
assert.NoError(t, err)

// ignored since 0x123 already exists
err = extCr.Bind(context.Background(), []types.BoundContract{{Name: contractName, Address: "0x123"}})
assert.NoError(t, err)

err = extCr.Bind(context.Background(), []types.BoundContract{{Name: contractName, Address: "0x124"}})
assert.NoError(t, err)

// Bind fails
err = extCr.Bind(context.Background(), []types.BoundContract{{Name: contractName, Address: "0x125"}})
assert.Error(t, err)

bindings = extCr.GetBindings(contractName)
assert.Len(t, bindings, 1)
assert.Equal(t, "0x124", bindings[0].Binding.Address)
}

func TestExtendedContractReader_AllowMultiBindingForAggregator(t *testing.T) {
const contractName = consts.ContractNamePriceAggregator
cr := chainreadermocks.NewMockContractReaderFacade(t)
extCr := contractreader.NewExtendedContractReader(cr)

bindings := extCr.GetBindings(contractName)
assert.Len(t, bindings, 0)

cr.On("Bind", context.Background(),
[]types.BoundContract{{Name: contractName, Address: "0x123"}}).Return(nil)
cr.On("Bind", context.Background(),
Expand Down
48 changes: 21 additions & 27 deletions pkg/reader/price_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ type PriceReader interface {
// 1 ETH = 2,000 USD per full token, each full token is 1e18 units -> 2000 * 1e18 * 1e18 / 1e18 = 2_000e18
// 1 LINK = 5.00 USD per full token, each full token is 1e18 units -> 5 * 1e18 * 1e18 / 1e18 = 5e18
// The order of the returned prices corresponds to the order of the provided tokens.
GetFeedPricesUSD(ctx context.Context, tokens []ccipocr3.UnknownEncodedAddress) ([]*big.Int, error)
GetFeedPricesUSD(ctx context.Context,
tokens []ccipocr3.UnknownEncodedAddress) (map[ccipocr3.UnknownEncodedAddress]*big.Int, error)

// GetFeeQuoterTokenUpdates returns the latest token prices from the FeeQuoter on the specified chain
GetFeeQuoterTokenUpdates(
Expand Down Expand Up @@ -70,7 +71,7 @@ type LatestRoundData struct {
}

// ContractTokenMap maps contracts to their token indices
type ContractTokenMap map[commontypes.BoundContract][]int
type ContractTokenMap map[commontypes.BoundContract][]ccipocr3.UnknownEncodedAddress

// Number of batch operations performed (getLatestRoundData and getDecimals)
const priceReaderOperationCount = 2
Expand Down Expand Up @@ -150,8 +151,8 @@ func (pr *priceReader) GetFeeQuoterTokenUpdates(
func (pr *priceReader) GetFeedPricesUSD(
ctx context.Context,
tokens []ccipocr3.UnknownEncodedAddress,
) ([]*big.Int, error) {
prices := make([]*big.Int, len(tokens))
) (map[ccipocr3.UnknownEncodedAddress]*big.Int, error) {
prices := make(map[ccipocr3.UnknownEncodedAddress]*big.Int, len(tokens))
if pr.feedChainReader() == nil {
pr.lggr.Debug("node does not support feed chain")
return prices, nil
Expand All @@ -170,40 +171,42 @@ func (pr *priceReader) GetFeedPricesUSD(
}

// Process results by contract
for boundContract, tokenIndices := range contractTokenMap {
for boundContract, tokens := range contractTokenMap {
contractResults, ok := results[boundContract]
if !ok || len(contractResults) != priceReaderOperationCount {
return nil, fmt.Errorf("invalid results for contract %s", boundContract.Address)
pr.lggr.Errorf("invalid results for contract %s", boundContract.Address)
continue
}

// Get price data
latestRoundData, err := pr.getPriceData(contractResults[0], boundContract)
if err != nil {
return nil, err
pr.lggr.Errorw("calling getPriceData", err)
continue
}

// Get decimals
decimals, err := pr.getDecimals(contractResults[1], boundContract)
if err != nil {
return nil, err
pr.lggr.Errorw("calling getPriceData", err)
continue
}

// Normalize price for this contract
normalizedContractPrice := pr.normalizePrice(latestRoundData.Answer, *decimals)

// Apply the normalized price to all tokens using this contract
for _, tokenIdx := range tokenIndices {
token := tokens[tokenIdx]
for _, token := range tokens {
tokenInfo := pr.tokenInfo[token]
prices[tokenIdx] = calculateUsdPer1e18TokenAmount(normalizedContractPrice, tokenInfo.Decimals)
price := calculateUsdPer1e18TokenAmount(normalizedContractPrice, tokenInfo.Decimals)
if price == nil {
pr.lggr.Errorw("failed to calculate price", "token", token)
continue
}
prices[token] = price
}
}

// Verify no nil prices
if err := pr.validatePrices(prices); err != nil {
return nil, err
}

return prices, nil
}

Expand Down Expand Up @@ -250,7 +253,7 @@ func (pr *priceReader) prepareBatchRequest(
batchRequest := make(commontypes.BatchGetLatestValuesRequest)
contractTokenMap := make(ContractTokenMap)

for i, token := range tokens {
for _, token := range tokens {
tokenInfo, ok := pr.tokenInfo[token]
if !ok {
return nil, nil, fmt.Errorf("get tokenInfo for %s: missing token info", token)
Expand All @@ -277,7 +280,7 @@ func (pr *priceReader) prepareBatchRequest(
}

// Track which tokens use this contract
contractTokenMap[boundContract] = append(contractTokenMap[boundContract], i)
contractTokenMap[boundContract] = append(contractTokenMap[boundContract], token)
}

return batchRequest, contractTokenMap, nil
Expand All @@ -294,15 +297,6 @@ func (pr *priceReader) normalizePrice(price *big.Int, decimals uint8) *big.Int {
return answer
}

func (pr *priceReader) validatePrices(prices []*big.Int) error {
for _, price := range prices {
if price == nil {
return fmt.Errorf("failed to get all token prices successfully, some prices are nil")
}
}
return nil
}

func (pr *priceReader) feedChainReader() contractreader.ContractReaderFacade {
return pr.chainReaders[pr.feedChain]
}
Expand Down
Loading

0 comments on commit 4eba5bd

Please sign in to comment.