Skip to content

Commit

Permalink
Audit fixes - Part 1 (#421)
Browse files Browse the repository at this point in the history
`P2` `H3.1` `H3.2` `H3.3`
  • Loading branch information
dimkouv authored Jan 8, 2025
1 parent 17b5e61 commit e46f3a7
Show file tree
Hide file tree
Showing 12 changed files with 82 additions and 378 deletions.
1 change: 0 additions & 1 deletion .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ packages:
github.com/smartcontractkit/chainlink-ccip/internal/reader:
interfaces:
HomeChain:
RMNRemote:
CCIP:
github.com/smartcontractkit/chainlink-ccip/internal/plugincommon:
interfaces:
Expand Down
39 changes: 22 additions & 17 deletions commit/merkleroot/observation.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,17 +380,12 @@ func (o observerImpl) ObserveOffRampNextSeqNums(ctx context.Context) []plugintyp
return nil
}

if len(offRampNextSeqNums) != len(sourceChains) {
o.lggr.Errorf("call to NextSeqNum returned unexpected number of seq nums, got %d, expected %d",
len(offRampNextSeqNums), len(sourceChains))
return nil
}

result := make([]plugintypes.SeqNumChain, len(sourceChains))
for i := range sourceChains {
result[i] = plugintypes.SeqNumChain{ChainSel: sourceChains[i], SeqNum: offRampNextSeqNums[i]}
result := make([]plugintypes.SeqNumChain, 0, len(sourceChains))
for chainSelector, seqNum := range offRampNextSeqNums {
result = append(result, plugintypes.NewSeqNumChain(chainSelector, seqNum))
}

sort.Slice(result, func(i, j int) bool { return result[i].ChainSel < result[j].ChainSel })
return result
}

Expand All @@ -413,23 +408,30 @@ func (o observerImpl) ObserveLatestOnRampSeqNums(
sourceChains := mapset.NewSet(allSourceChains...).Intersect(supportedChains).ToSlice()
sort.Slice(sourceChains, func(i, j int) bool { return sourceChains[i] < sourceChains[j] })

latestOnRampSeqNums := make([]plugintypes.SeqNumChain, len(sourceChains))
mu := &sync.Mutex{}
latestOnRampSeqNums := make([]plugintypes.SeqNumChain, 0, len(sourceChains))
eg := &errgroup.Group{}

for i, sourceChain := range sourceChains {
for _, sourceChain := range sourceChains {
eg.Go(func() error {
nextOnRampSeqNum, err := o.ccipReader.GetExpectedNextSequenceNumber(ctx, sourceChain, destChain)
if err != nil {
return fmt.Errorf("failed to get expected next sequence number for source chain %d: %w", sourceChain, err)
o.lggr.Errorf("failed to get expected next seq num for source chain %d: %s", sourceChain, err)
return nil
}

if nextOnRampSeqNum == 0 {
return fmt.Errorf("expected next sequence number for source chain %d is 0", sourceChain)
o.lggr.Errorf("unexpected next seq num for source chain %d, it is 0", sourceChain)
return nil
}

latestOnRampSeqNums[i] = plugintypes.SeqNumChain{
ChainSel: sourceChain,
SeqNum: nextOnRampSeqNum - 1, // Latest is the next one minus one.
}
mu.Lock()
latestOnRampSeqNums = append(
latestOnRampSeqNums,
plugintypes.NewSeqNumChain(sourceChain, nextOnRampSeqNum-1), // Latest is the next one minus one.
)
mu.Unlock()

return nil
})
}
Expand All @@ -439,6 +441,9 @@ func (o observerImpl) ObserveLatestOnRampSeqNums(
return nil
}

sort.Slice(latestOnRampSeqNums, func(i, j int) bool {
return latestOnRampSeqNums[i].ChainSel < latestOnRampSeqNums[j].ChainSel
})
return latestOnRampSeqNums
}

Expand Down
18 changes: 13 additions & 5 deletions commit/merkleroot/observation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"

"github.com/smartcontractkit/libocr/commontypes"
"github.com/smartcontractkit/libocr/ragep2p/types"
Expand Down Expand Up @@ -176,7 +177,7 @@ func TestObservation(t *testing.T) {
func Test_ObserveOffRampNextSeqNums(t *testing.T) {
const nodeID commontypes.OracleID = 1
knownSourceChains := []cciptypes.ChainSelector{4, 7, 19}
nextSeqNums := []cciptypes.SeqNum{345, 608, 7713}
nextSeqNums := map[cciptypes.ChainSelector]cciptypes.SeqNum{4: 345, 7: 608, 19: 7713}

testCases := []struct {
name string
Expand Down Expand Up @@ -233,20 +234,27 @@ func Test_ObserveOffRampNextSeqNums(t *testing.T) {
expResult: nil,
},
{
name: "nil is returned when nextSeqNums returns incorrect number of seq nums",
name: "nextSeqNums returns incorrect number of seq nums, other chains should be processed correctly",
getDeps: func(t *testing.T) (*common_mock.MockChainSupport, *reader_mock.MockCCIPReader) {
chainSupport := common_mock.NewMockChainSupport(t)
chainSupport.EXPECT().SupportsDestChain(nodeID).Return(true, nil)
chainSupport.EXPECT().DestChain().Return(1)
chainSupport.EXPECT().KnownSourceChainsSlice().Return(knownSourceChains, nil)
ccipReader := reader_mock.NewMockCCIPReader(t)
// return a smaller slice, should trigger validation condition
ccipReader.EXPECT().NextSeqNum(mock.Anything, knownSourceChains).Return(nextSeqNums[1:], nil)

nextSeqNumsCp := maps.Clone(nextSeqNums)
delete(nextSeqNumsCp, cciptypes.ChainSelector(4))

ccipReader.EXPECT().NextSeqNum(mock.Anything, knownSourceChains).Return(nextSeqNumsCp, nil)
ccipReader.EXPECT().GetRmnCurseInfo(mock.Anything, mock.Anything, mock.Anything).
Return(&reader.CurseInfo{}, nil)
return chainSupport, ccipReader
},
expResult: nil,
expResult: []plugintypes.SeqNumChain{
plugintypes.NewSeqNumChain(7, 608),
plugintypes.NewSeqNumChain(19, 7713),
},
},
{
name: "dest chain is cursed sequence numbers not observed",
Expand Down Expand Up @@ -286,7 +294,7 @@ func Test_ObserveOffRampNextSeqNums(t *testing.T) {
knownSourceChains := []cciptypes.ChainSelector{4, 7, 19}
cursedSourceChains := map[cciptypes.ChainSelector]bool{7: true, 4: false}
knownSourceChainsExcludingCursed := []cciptypes.ChainSelector{4, 19}
nextSeqNumsExcludingCursed := []cciptypes.SeqNum{345, 7713}
nextSeqNumsExcludingCursed := map[cciptypes.ChainSelector]cciptypes.SeqNum{4: 345, 19: 7713}

chainSupport := common_mock.NewMockChainSupport(t)
chainSupport.EXPECT().SupportsDestChain(nodeID).Return(true, nil)
Expand Down
14 changes: 5 additions & 9 deletions commit/merkleroot/validate_observation.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,12 @@ func ValidateMerkleRootsState(
return fmt.Errorf("get next sequence numbers: %w", err)
}

if len(offRampExpNextSeqNums) != len(chainSlice) {
return fmt.Errorf("critical reader error: seq nums length mismatch")
}

for i, offRampExpNextSeqNum := range offRampExpNextSeqNums {
chain := chainSlice[i]

newNextOnRampSeqNum, ok := newNextOnRampSeqNums[chain]
for chain, newNextOnRampSeqNum := range newNextOnRampSeqNums {
offRampExpNextSeqNum, ok := offRampExpNextSeqNums[chain]
if !ok {
return fmt.Errorf("critical unexpected error: newOnRampSeqNum not found")
// Due to some chain being disabled while the sequence numbers were already observed.
// Report should not be considered valid in that case.
return fmt.Errorf("offRamp expected next sequence number for chain %d was not found", chain)
}

if newNextOnRampSeqNum != offRampExpNextSeqNum {
Expand Down
17 changes: 9 additions & 8 deletions commit/merkleroot/validate_observation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ func Test_validateMerkleRootsState(t *testing.T) {
testCases := []struct {
name string
onRampNextSeqNum []plugintypes.SeqNumChain
offRampExpNextSeqNum []cciptypes.SeqNum
offRampExpNextSeqNum map[cciptypes.ChainSelector]cciptypes.SeqNum
readerErr error
expErr bool
}{
Expand All @@ -355,7 +355,7 @@ func Test_validateMerkleRootsState(t *testing.T) {
plugintypes.NewSeqNumChain(10, 100),
plugintypes.NewSeqNumChain(20, 200),
},
offRampExpNextSeqNum: []cciptypes.SeqNum{100, 200},
offRampExpNextSeqNum: map[cciptypes.ChainSelector]cciptypes.SeqNum{10: 100, 20: 200},
expErr: false,
},
{
Expand All @@ -364,7 +364,8 @@ func Test_validateMerkleRootsState(t *testing.T) {
plugintypes.NewSeqNumChain(10, 100),
plugintypes.NewSeqNumChain(20, 200),
},
offRampExpNextSeqNum: []cciptypes.SeqNum{100, 201}, // <- 200 is already on chain
// <- 200 is already on chain
offRampExpNextSeqNum: map[cciptypes.ChainSelector]cciptypes.SeqNum{10: 100, 20: 201},
expErr: true,
},
{
Expand All @@ -373,25 +374,25 @@ func Test_validateMerkleRootsState(t *testing.T) {
plugintypes.NewSeqNumChain(10, 101), // <- onchain 99 but we submit 101 instead of 100
plugintypes.NewSeqNumChain(20, 200),
},
offRampExpNextSeqNum: []cciptypes.SeqNum{100, 200},
offRampExpNextSeqNum: map[cciptypes.ChainSelector]cciptypes.SeqNum{10: 100, 20: 200},
expErr: true,
},
{
name: "reader returned wrong number of seq nums",
name: "reader returned wrong number of seq nums, should be ok",
onRampNextSeqNum: []plugintypes.SeqNumChain{
plugintypes.NewSeqNumChain(10, 100),
plugintypes.NewSeqNumChain(20, 200),
},
offRampExpNextSeqNum: []cciptypes.SeqNum{100, 200, 300},
expErr: true,
offRampExpNextSeqNum: map[cciptypes.ChainSelector]cciptypes.SeqNum{10: 100, 20: 200, 30: 300},
expErr: false,
},
{
name: "reader error",
onRampNextSeqNum: []plugintypes.SeqNumChain{
plugintypes.NewSeqNumChain(10, 100),
plugintypes.NewSeqNumChain(20, 200),
},
offRampExpNextSeqNum: []cciptypes.SeqNum{100, 200},
offRampExpNextSeqNum: map[cciptypes.ChainSelector]cciptypes.SeqNum{10: 100, 20: 200},
readerErr: fmt.Errorf("reader error"),
expErr: true,
},
Expand Down
23 changes: 13 additions & 10 deletions commit/plugin_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func TestPlugin_E2E_AllNodesAgree_MerkleRoots(t *testing.T) {
expTransmittedReports []ccipocr3.CommitPluginReport

offRampNextSeqNumDefaultOverrideKeys []ccipocr3.ChainSelector
offRampNextSeqNumDefaultOverrideValues []ccipocr3.SeqNum
offRampNextSeqNumDefaultOverrideValues map[ccipocr3.ChainSelector]ccipocr3.SeqNum

enableDiscovery bool
}{
Expand Down Expand Up @@ -201,10 +201,13 @@ func TestPlugin_E2E_AllNodesAgree_MerkleRoots(t *testing.T) {
},
},
{
name: "report generated in previous outcome, transmitted with success",
prevOutcome: outcomeReportGenerated,
offRampNextSeqNumDefaultOverrideKeys: []ccipocr3.ChainSelector{sourceChain1, sourceChain2},
offRampNextSeqNumDefaultOverrideValues: []ccipocr3.SeqNum{11, 20},
name: "report generated in previous outcome, transmitted with success",
prevOutcome: outcomeReportGenerated,
offRampNextSeqNumDefaultOverrideKeys: []ccipocr3.ChainSelector{sourceChain1, sourceChain2},
offRampNextSeqNumDefaultOverrideValues: map[ccipocr3.ChainSelector]ccipocr3.SeqNum{
sourceChain1: 11,
sourceChain2: 20,
},
expOutcome: committypes.Outcome{
MerkleRootOutcome: merkleroot.Outcome{
OutcomeType: merkleroot.ReportTransmitted,
Expand Down Expand Up @@ -690,7 +693,7 @@ func prepareCcipReaderMock(

if mockEmptySeqNrs {
ccipReader.EXPECT().NextSeqNum(ctx, mock.Anything).Unset()
ccipReader.EXPECT().NextSeqNum(ctx, mock.Anything).Return([]ccipocr3.SeqNum{}, nil).
ccipReader.EXPECT().NextSeqNum(ctx, mock.Anything).Return(map[ccipocr3.ChainSelector]ccipocr3.SeqNum{}, nil).
Maybe()
}

Expand Down Expand Up @@ -787,12 +790,12 @@ func setupNode(params SetupNodeParams) nodeSetup {
}
sort.Slice(sourceChains, func(i, j int) bool { return sourceChains[i] < sourceChains[j] })

offRampNextSeqNums := make([]ccipocr3.SeqNum, 0)
offRampNextSeqNums := make(map[ccipocr3.ChainSelector]ccipocr3.SeqNum, 0)
chainsWithNewMsgs := make([]ccipocr3.ChainSelector, 0)
for _, sourceChain := range sourceChains {
offRampNextSeqNum, ok := params.offRampNextSeqNum[sourceChain]
assert.True(params.t, ok)
offRampNextSeqNums = append(offRampNextSeqNums, offRampNextSeqNum)
offRampNextSeqNums[sourceChain] = offRampNextSeqNum

newMsgs := make([]ccipocr3.Message, 0)
numNewMsgs := (params.onRampLastSeqNum[sourceChain] - offRampNextSeqNum) + 1
Expand All @@ -815,9 +818,9 @@ func setupNode(params SetupNodeParams) nodeSetup {
}
}

seqNumsOfChainsWithNewMsgs := make([]ccipocr3.SeqNum, 0)
seqNumsOfChainsWithNewMsgs := map[ccipocr3.ChainSelector]ccipocr3.SeqNum{}
for _, chainSel := range chainsWithNewMsgs {
seqNumsOfChainsWithNewMsgs = append(seqNumsOfChainsWithNewMsgs, params.offRampNextSeqNum[chainSel])
seqNumsOfChainsWithNewMsgs[chainSel] = params.offRampNextSeqNum[chainSel]
}
if len(chainsWithNewMsgs) > 0 {
ccipReader.EXPECT().NextSeqNum(params.ctx, chainsWithNewMsgs).Return(seqNumsOfChainsWithNewMsgs, nil).Maybe()
Expand Down
2 changes: 1 addition & 1 deletion internal/mocks/inmem/ccipreader_inmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (r InMemoryCCIPReader) MsgsBetweenSeqNums(

func (r InMemoryCCIPReader) NextSeqNum(
ctx context.Context, chains []cciptypes.ChainSelector,
) (seqNum []cciptypes.SeqNum, err error) {
) (seqNum map[cciptypes.ChainSelector]cciptypes.SeqNum, err error) {
panic("implement me")
}

Expand Down
47 changes: 0 additions & 47 deletions internal/reader/rmn_remote.go

This file was deleted.

Loading

0 comments on commit e46f3a7

Please sign in to comment.