Skip to content

Commit

Permalink
ocr3 - Only destination chain readers can observe seq nums. (#1020)
Browse files Browse the repository at this point in the history
## Motivation

Only destination chain readers can observe sequence numbers otherwise
the majority will observe outdated ones which will lead to no progress
between rounds.

## Solution

Include sequence numbers only if they were synced with the on chain
state.
  • Loading branch information
dimkouv authored Jun 14, 2024
1 parent 50e10d0 commit 574f164
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 50 deletions.
13 changes: 9 additions & 4 deletions core/services/ocr3/plugins/ccip/commit/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (p *Plugin) Query(_ context.Context, _ ocr3types.OutcomeContext) (types.Que
// We discover the token prices only for the tokens that are used to pay for ccip fees.
// The fee tokens are configured in the plugin config.
func (p *Plugin) Observation(ctx context.Context, outctx ocr3types.OutcomeContext, _ types.Query) (types.Observation, error) {
maxSeqNumsPerChain, err := observeMaxSeqNums(
maxSeqNumsPerChain, seqNumsInSync, err := observeMaxSeqNums(
ctx,
p.lggr,
p.ccipReader,
Expand Down Expand Up @@ -152,8 +152,13 @@ func (p *Plugin) Observation(ctx context.Context, outctx ocr3types.OutcomeContex
for _, msg := range newMsgs {
msgBaseDetails = append(msgBaseDetails, msg.CCIPMsgBaseDetails)
}
return cciptypes.NewCommitPluginObservation(msgBaseDetails, gasPrices, tokenPrices, maxSeqNumsPerChain, p.cfg).Encode()

if !seqNumsInSync {
// If the node was not able to sync the max sequence numbers we don't want to transmit
// the potentially outdated ones. We expect that a sufficient number of nodes will be able to observe them.
maxSeqNumsPerChain = nil
}
return cciptypes.NewCommitPluginObservation(msgBaseDetails, gasPrices, tokenPrices, maxSeqNumsPerChain, p.cfg).Encode()
}

func (p *Plugin) ValidateObservation(_ ocr3types.OutcomeContext, _ types.Query, ao types.AttributedObservation) error {
Expand All @@ -166,7 +171,7 @@ func (p *Plugin) ValidateObservation(_ ocr3types.OutcomeContext, _ types.Query,
return fmt.Errorf("validate sequence numbers: %w", err)
}

if err := validateObserverReadingEligibility(ao.Observer, obs.NewMsgs, p.cfg.ObserverInfo); err != nil {
if err := validateObserverReadingEligibility(ao.Observer, obs.NewMsgs, obs.MaxSeqNums, p.cfg.ObserverInfo); err != nil {
return fmt.Errorf("validate observer %d reading eligibility: %w", ao.Observer, err)
}

Expand Down Expand Up @@ -299,7 +304,7 @@ func (p *Plugin) ShouldTransmitAcceptedReport(ctx context.Context, u uint64, r o
"gasPriceUpdates", len(decodedReport.PriceUpdates.GasPriceUpdates),
)

// todo: if report is stale -> do not transmit
// todo: if report is stale -> do not transmit (check the spec for the exact condition)
return true, nil
}

Expand Down
6 changes: 5 additions & 1 deletion core/services/ocr3/plugins/ccip/commit/plugin_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,11 @@ func setupEmptyOutcome(ctx context.Context, t *testing.T, lggr logger.Logger) []
FChain: map[cciptypes.ChainSelector]int{
chainC: 1,
},
ObserverInfo: map[commontypes.OracleID]cciptypes.ObserverInfo{},
ObserverInfo: map[commontypes.OracleID]cciptypes.ObserverInfo{
1: {Writer: false, Reads: []cciptypes.ChainSelector{}},
2: {Writer: false, Reads: []cciptypes.ChainSelector{}},
3: {Writer: false, Reads: []cciptypes.ChainSelector{}},
},
PricedTokens: []types.Account{tokenX},
TokenPricesObserver: false,
NewMsgScanBatchSize: 256,
Expand Down
31 changes: 22 additions & 9 deletions core/services/ocr3/plugins/ccip/commit/plugin_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ func observeMaxSeqNums(
readableChains mapset.Set[cciptypes.ChainSelector],
destChain cciptypes.ChainSelector,
knownSourceChains []cciptypes.ChainSelector,
) ([]cciptypes.SeqNumChain, error) {
) ([]cciptypes.SeqNumChain, bool, error) {
seqNumsInSync := false

// If there is a previous outcome, start with the sequence numbers of it.
seqNumPerChain := make(map[cciptypes.ChainSelector]cciptypes.SeqNum)
if previousOutcomeBytes != nil {
lggr.Debugw("observing based on previous outcome")
prevOutcome, err := cciptypes.DecodeCommitPluginOutcome(previousOutcomeBytes)
if err != nil {
return nil, fmt.Errorf("decode commit plugin previous outcome: %w", err)
return nil, false, fmt.Errorf("decode commit plugin previous outcome: %w", err)
}
lggr.Debugw("previous outcome decoded", "outcome", prevOutcome.String())

Expand All @@ -52,7 +54,7 @@ func observeMaxSeqNums(
lggr.Debugw("reading sequence numbers from destination")
onChainSeqNums, err := ccipReader.NextSeqNum(ctx, knownSourceChains)
if err != nil {
return nil, fmt.Errorf("get next seq nums: %w", err)
return nil, false, fmt.Errorf("get next seq nums: %w", err)
}
lggr.Debugw("discovered sequence numbers from destination", "onChainSeqNums", onChainSeqNums)

Expand All @@ -63,6 +65,7 @@ func observeMaxSeqNums(
lggr.Debugw("updated sequence number", "chain", ch, "seqNum", onChainSeqNums[i])
}
}
seqNumsInSync = true
}

maxChainSeqNums := make([]cciptypes.SeqNumChain, 0)
Expand All @@ -71,7 +74,7 @@ func observeMaxSeqNums(
}

sort.Slice(maxChainSeqNums, func(i, j int) bool { return maxChainSeqNums[i].ChainSel < maxChainSeqNums[j].ChainSel })
return maxChainSeqNums, nil
return maxChainSeqNums, seqNumsInSync, nil
}

// observeNewMsgs finds the new messages for each supported chain based on the provided max sequence numbers.
Expand Down Expand Up @@ -550,6 +553,12 @@ func pluginConfigConsensus(
// validateObservedSequenceNumbers checks if the sequence numbers of the provided messages are unique for each chain and
// that they match the observed max sequence numbers.
func validateObservedSequenceNumbers(msgs []cciptypes.CCIPMsgBaseDetails, maxSeqNums []cciptypes.SeqNumChain) error {
// If the observer did not include sequence numbers it means that it's not a destination chain reader.
// In that case we cannot do any msg sequence number validations.
if len(maxSeqNums) == 0 {
return nil
}

// MaxSeqNums must be unique for each chain.
maxSeqNumsMap := make(map[cciptypes.ChainSelector]cciptypes.SeqNum)
for _, maxSeqNum := range maxSeqNums {
Expand Down Expand Up @@ -590,19 +599,23 @@ func validateObservedSequenceNumbers(msgs []cciptypes.CCIPMsgBaseDetails, maxSeq
func validateObserverReadingEligibility(
observer commontypes.OracleID,
msgs []cciptypes.CCIPMsgBaseDetails,
seqNums []cciptypes.SeqNumChain,
observerCfg map[commontypes.OracleID]cciptypes.ObserverInfo,
) error {
if len(msgs) == 0 {
return nil
}

observerInfo, exists := observerCfg[observer]
if !exists {
return fmt.Errorf("observer not found in config")
}

observerReadChains := mapset.NewSet(observerInfo.Reads...)

if len(seqNums) > 0 && !observerInfo.Writer {
return fmt.Errorf("observer must be a writer if it observes sequence numbers")
}

if len(msgs) == 0 {
return nil
}

for _, msg := range msgs {
// Observer must be able to read the chain that the message is coming from.
if !observerReadChains.Contains(msg.SourceChain) {
Expand Down
101 changes: 70 additions & 31 deletions core/services/ocr3/plugins/ccip/commit/plugin_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ import (

func Test_observeMaxSeqNumsPerChain(t *testing.T) {
testCases := []struct {
name string
prevOutcome cciptypes.CommitPluginOutcome
onChainSeqNums map[cciptypes.ChainSelector]cciptypes.SeqNum
readChains []cciptypes.ChainSelector
destChain cciptypes.ChainSelector
expErr bool
expMaxSeqNums []cciptypes.SeqNumChain
name string
prevOutcome cciptypes.CommitPluginOutcome
onChainSeqNums map[cciptypes.ChainSelector]cciptypes.SeqNum
readChains []cciptypes.ChainSelector
destChain cciptypes.ChainSelector
expErr bool
expSeqNumsInSync bool
expMaxSeqNums []cciptypes.SeqNumChain
}{
{
name: "report on chain seq num when no previous outcome and can read dest",
Expand All @@ -38,9 +39,10 @@ func Test_observeMaxSeqNumsPerChain(t *testing.T) {
1: 10,
2: 20,
},
readChains: []cciptypes.ChainSelector{1, 2, 3},
destChain: 3,
expErr: false,
readChains: []cciptypes.ChainSelector{1, 2, 3},
destChain: 3,
expErr: false,
expSeqNumsInSync: true,
expMaxSeqNums: []cciptypes.SeqNumChain{
{ChainSel: 1, SeqNum: 10},
{ChainSel: 2, SeqNum: 20},
Expand All @@ -53,10 +55,11 @@ func Test_observeMaxSeqNumsPerChain(t *testing.T) {
1: 10,
2: 20,
},
readChains: []cciptypes.ChainSelector{1, 2},
destChain: 3,
expErr: false,
expMaxSeqNums: []cciptypes.SeqNumChain{},
readChains: []cciptypes.ChainSelector{1, 2},
destChain: 3,
expErr: false,
expSeqNumsInSync: false,
expMaxSeqNums: []cciptypes.SeqNumChain{},
},
{
name: "report previous outcome seq nums and override when on chain is higher if can read dest",
Expand All @@ -70,14 +73,36 @@ func Test_observeMaxSeqNumsPerChain(t *testing.T) {
1: 10,
2: 20,
},
readChains: []cciptypes.ChainSelector{1, 2, 3},
destChain: 3,
expErr: false,
readChains: []cciptypes.ChainSelector{1, 2, 3},
destChain: 3,
expErr: false,
expSeqNumsInSync: true,
expMaxSeqNums: []cciptypes.SeqNumChain{
{ChainSel: 1, SeqNum: 11},
{ChainSel: 2, SeqNum: 20},
},
},
{
name: "report previous outcome seq nums and mark as non synced if cannot read dest",
prevOutcome: cciptypes.CommitPluginOutcome{
MaxSeqNums: []cciptypes.SeqNumChain{
{ChainSel: 1, SeqNum: 11}, // for chain 1 previous outcome is higher than on-chain state
{ChainSel: 2, SeqNum: 19}, // for chain 2 previous outcome is behind on-chain state
},
},
onChainSeqNums: map[cciptypes.ChainSelector]cciptypes.SeqNum{
1: 10,
2: 20,
},
readChains: []cciptypes.ChainSelector{1, 2},
destChain: 3,
expErr: false,
expSeqNumsInSync: false,
expMaxSeqNums: []cciptypes.SeqNumChain{
{ChainSel: 1, SeqNum: 11},
{ChainSel: 2, SeqNum: 19},
},
},
}

for _, tc := range testCases {
Expand All @@ -104,7 +129,7 @@ func Test_observeMaxSeqNumsPerChain(t *testing.T) {
}
mockReader.On("NextSeqNum", ctx, knownSourceChains).Return(onChainSeqNums, nil)

seqNums, err := observeMaxSeqNums(
seqNums, synced, err := observeMaxSeqNums(
ctx,
lggr,
mockReader,
Expand All @@ -120,6 +145,7 @@ func Test_observeMaxSeqNumsPerChain(t *testing.T) {
}
assert.NoError(t, err)
assert.Equal(t, tc.expMaxSeqNums, seqNums)
assert.Equal(t, tc.expSeqNumsInSync, synced)
})
}
}
Expand Down Expand Up @@ -466,6 +492,7 @@ func Test_validateObserverReadingEligibility(t *testing.T) {
name string
observer commontypes.OracleID
msgs []cciptypes.CCIPMsgBaseDetails
seqNums []cciptypes.SeqNumChain
observerInfo map[commontypes.OracleID]cciptypes.ObserverInfo
expErr bool
}{
Expand All @@ -484,16 +511,26 @@ func Test_validateObserverReadingEligibility(t *testing.T) {
expErr: false,
},
{
name: "observer cannot read one chain",
name: "observer is a writer so can observe seq nums",
observer: commontypes.OracleID(10),
msgs: []cciptypes.CCIPMsgBaseDetails{
{ID: cciptypes.Bytes32{1}, SourceChain: 1, SeqNum: 12},
{ID: cciptypes.Bytes32{3}, SourceChain: 2, SeqNum: 12},
{ID: cciptypes.Bytes32{1}, SourceChain: 3, SeqNum: 12},
{ID: cciptypes.Bytes32{2}, SourceChain: 3, SeqNum: 12},
msgs: []cciptypes.CCIPMsgBaseDetails{},
seqNums: []cciptypes.SeqNumChain{
{ChainSel: 1, SeqNum: 12},
},
observerInfo: map[commontypes.OracleID]cciptypes.ObserverInfo{
10: {Reads: []cciptypes.ChainSelector{1, 3}},
10: {Reads: []cciptypes.ChainSelector{1, 3}, Writer: true},
},
expErr: false,
},
{
name: "observer is not a writer so cannot observe seq nums",
observer: commontypes.OracleID(10),
msgs: []cciptypes.CCIPMsgBaseDetails{},
seqNums: []cciptypes.SeqNumChain{
{ChainSel: 1, SeqNum: 12},
},
observerInfo: map[commontypes.OracleID]cciptypes.ObserverInfo{
10: {Reads: []cciptypes.ChainSelector{1, 3}, Writer: false},
},
expErr: true,
},
Expand All @@ -512,17 +549,19 @@ func Test_validateObserverReadingEligibility(t *testing.T) {
expErr: true,
},
{
name: "no msgs",
observer: commontypes.OracleID(10),
msgs: []cciptypes.CCIPMsgBaseDetails{},
observerInfo: map[commontypes.OracleID]cciptypes.ObserverInfo{},
expErr: false,
name: "no msgs",
observer: commontypes.OracleID(10),
msgs: []cciptypes.CCIPMsgBaseDetails{},
observerInfo: map[commontypes.OracleID]cciptypes.ObserverInfo{
10: {Reads: []cciptypes.ChainSelector{1, 3}},
},
expErr: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := validateObserverReadingEligibility(tc.observer, tc.msgs, tc.observerInfo)
err := validateObserverReadingEligibility(tc.observer, tc.msgs, tc.seqNums, tc.observerInfo)
if tc.expErr {
assert.Error(t, err)
return
Expand Down
19 changes: 14 additions & 5 deletions core/services/ocr3/plugins/ccip/spec/commit_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,28 @@ def observation(self, previous_outcome):
# Observe fChain for each chain. {chain: f_chain}
f_chain = self.cfg["f_chain"]

if not self.can_read_dest():
# If node is not able to read updated sequence numbers from the destination,
# it should not observe the outdated ones that are coming from the previous outcome.
observed_seq_nums = {}

return (observed_seq_nums, new_msgs, token_prices, gas_prices, f_chain)


def validate_observation(self, attributed_observation):
observation = attributed_observation.observation
observer = attributed_observation.observer

if "seq_nums" in observation:
assert observer.can_read_dest()

observer_supported_chains = self.cfg["observer_info"][observer]["supported_chains"]
for (chain, msgs) in observation["new_msgs"].items():
assert(chain in observer_supported_chains)

for msg in msgs:
assert(msg.seq_num > observation["observed_seq_nums"][msg.source_chain])
if "seq_nums" in observation:
for msg in msgs:
assert(msg.seq_num > observation["observed_seq_nums"][msg.source_chain])

assert(len(msgs) == len(set([msg.seq_num for msg in msgs])))
assert(len(msgs) == len(set([msg.id for msg in msgs])))
Expand All @@ -82,8 +91,8 @@ def outcome(self, observations):

trees = {} # { chain: (root, min_seq_num, max_seq_num) }
for (chain, msgs) in all_msgs:
# filter out msgs with seq nums not matching consensus seq nums
msgs = [msg for msg in msgs if msg.seq_num >= observed_seq_nums[chain]]
# keep only msgs with seq nums greater than the consensus max commited seq nums
msgs = [msg for msg in msgs if msg.seq_num > seq_nums[chain]]

msgs_by_seq_num = msgs.group_by_seq_num() # { 423: [0x1, 0x1, 0x2] }
# 2 nodes say that msg id is 0x1 and 1 node says it's 0x2
Expand Down Expand Up @@ -123,7 +132,7 @@ def should_transmit(self, report):

on_chain_seq_nums = self.offRamp.get_sequence_numbers()
for (chain, tree) in report.trees():
if on_chain_seq_nums[chain] >= tree.min_seq_num:
if not (on_chain_seq_nums[chain]+1 == tree.min_seq_num):
return False

return True
Expand Down

0 comments on commit 574f164

Please sign in to comment.