Skip to content

Commit

Permalink
(WIP) Add a draft spec for RMN OffChain Blessing
Browse files Browse the repository at this point in the history
  • Loading branch information
rstout committed Jun 18, 2024
1 parent c651218 commit 0b67563
Showing 1 changed file with 357 additions and 0 deletions.
357 changes: 357 additions & 0 deletions core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,357 @@
# TODO: doc

from typing import List, Dict, Optional
from collections import defaultdict

RmnNodeId = str
ChainSelector = int

MAX_INTERVAL_LENGTH = 256


class Interval:
# TODO: doc, inclusive, exclusive ranges
def __init__(self, min: int, max: int):
# TODO: invariant check min <= max
self.min = min
self.max = max

def is_empty(self) -> bool:
return self.min == self.max


class RmnSig:
def __init__(self, rmn_node_id: RmnNodeId, sig: bytes):
self.rmn_node_id = rmn_node_id
self.sig = sig


class SignedInterval:
def __init__(self, interval: Interval, root: bytes, sigs: List[RmnSig]):
self.interval = interval
self.root = root
self.sigs = sigs


class CcipMessage:
def __init__(self, seq_num: int):
self.seq_num = seq_num
pass


class CommitQuery:
def __init__(
self,
rmn_max_seq_nums: Dict[ChainSelector, int],
signed_intervals: Dict[ChainSelector, SignedInterval]
):
self.rmn_max_seq_nums = rmn_max_seq_nums
self.signed_intervals = signed_intervals


class CommitObservation:
def __init__(
self,
max_seq_nums_on_dest_chain: Dict[ChainSelector, int],
max_seq_nums_from_source_chains: Dict[ChainSelector, int], # TODO: try without
messages: Dict[ChainSelector, List[CcipMessage]]
):
self.max_seq_nums_on_dest_chain = max_seq_nums_on_dest_chain
self.max_seq_nums_from_source_chains = max_seq_nums_from_source_chains
self.messages = messages


class CommitOutcome:
def __init__(
self,
next_intervals: Dict[ChainSelector, Interval],
signed_intervals: Dict[ChainSelector, SignedInterval]
):
self.next_intervals = next_intervals
self.signed_intervals = signed_intervals
return


class RmnNode:
def __init__(self, node_id: RmnNodeId, ip_address: bytes, pub_key: bytes, supported_chains: List[ChainSelector]):
self.node_id = node_id
self.ip_address = ip_address
self.pub_key = pub_key
self.supported_chains = supported_chains


class RmnClientConfig:
def __init__(self, rmn_nodes: List[RmnNode]):
self.rmn_nodes = rmn_nodes


class RmnClient:
def __init__(self, rmn_client_config: RmnClientConfig):
self.rmn_client_config = rmn_client_config

# TODO: doc
def request_max_seq_nums_from_single_node(
self,
rmn_node_id: RmnNodeId,
chains: List[ChainSelector]
) -> Dict[ChainSelector, int]:
pass

# TODO: doc
def request_signed_intervals_from_single_node(
self,
rmn_node_id: RmnNodeId,
intervals: Dict[ChainSelector, Interval]
) -> Dict[ChainSelector, SignedInterval]:
pass

# TODO: doc
def request_max_seq_nums(
self,
chains: List[ChainSelector]
) -> Dict[ChainSelector, int]:
pass

# TODO: doc
def request_signed_intervals(
self,
intervals: Dict[ChainSelector, Interval]
) -> Dict[ChainSelector, SignedInterval]:
pass


class ChainReader:
def __init__(self):
pass


class OffRamp:
def __init__(self):
pass

# TODO: doc
def get_max_seq_nums_on_dest_chain(self) -> Dict[ChainSelector, int]:
pass


class CommitPlugin:
def __init__(
self,
rmn_client: RmnClient,
all_source_chains: List[ChainSelector],
dest_chain: ChainSelector,
chain_readers: Dict[ChainSelector, ChainReader],
off_ramp: OffRamp,
f: int
):
self.rmn_client = rmn_client
self.all_source_chains = all_source_chains
self.dest_chain = dest_chain
self.chain_readers = chain_readers
self.off_ramp = off_ramp
self.f = f
return

# TODO: doc
def can_read_from_dest_chain(self) -> bool:
return self.dest_chain in self.chain_readers

# TODO: doc
def get_ccip_messages_from_source_chains(
self,
intervals: Dict[ChainSelector, Interval]
) -> Dict[ChainSelector, List[CcipMessage]]:
pass

# TODO: doc
def query(self, previous_outcome: CommitOutcome) -> CommitQuery:
max_seq_nums = self.rmn_client.request_max_seq_nums(self.all_source_chains)
signed_intervals = self.rmn_client.request_signed_intervals(previous_outcome.next_intervals)
return CommitQuery(max_seq_nums, signed_intervals)

# TODO: doc
def observation(self, previous_outcome: CommitOutcome) -> CommitObservation:
# Get persisted min seq nums from dest chain
max_seq_nums_on_dest_chain = {}
if self.can_read_from_dest_chain():
max_seq_nums_on_dest_chain = self.off_ramp.get_max_seq_nums_on_dest_chain()

# Get max seq nums from all chains that can be read from
max_seq_nums_from_source_chains = {}
for (chain_selector, chain_reader) in self.chain_readers.items():
max_seq_nums_from_source_chains[chain_selector] = chain_reader.get_max_seq_num(self.dest_chain)

# Get messages in previous_outcome.next_interval
messages = self.get_ccip_messages_from_source_chains(previous_outcome.next_intervals)

return CommitObservation(max_seq_nums_on_dest_chain, max_seq_nums_from_source_chains, messages)

# TODO: doc
def aggregate_observations(
self,
observations: List[CommitObservation]
) -> CommitObservation:
pass

# TODO: doc
def get_consensus_max_seq_nums_on_dest_chain(
self,
observations: List[CommitObservation]
) -> Dict[ChainSelector, int]:
counts = defaultdict(int)
for observation in observations:
# Convert the dictionary to a frozenset of items so it can be used as a key
if len(observation.max_seq_nums_on_dest_chain) > 0:
frozen_dict = frozenset(observation.max_seq_nums_on_dest_chain.items())
counts[frozen_dict] += 1

# Consensus on the onchain state is reached if there is only one max_seq_nums_on_dest_chain dict that is
# observed by more than f nodes
# TODO: more doc
candidates = []
for (candidate, count) in counts.items():
if count > self.f:
candidates.append(candidate)

if len(candidates) == 1:
return dict(candidates[0])
else:
return {}

# TODO: doc
# the interval mins in previous outcome should be one more than what's onchain
def chains_with_unexpected_onchain_state(
self,
max_seq_nums_on_dest_chain: Dict[ChainSelector, int],
previous_outcome: CommitOutcome,
) -> List[ChainSelector]:
pass

# TODO: doc
def compute_merkle_root(
self,
chain_selector: ChainSelector,
current_interval: Interval,
observations: List[CommitObservation]
) -> bytes:
pass

# TODO: doc
def compute_merkle_root2(
self,
messages: List[CcipMessage]
) -> bytes:
pass

# TODO: doc
def verify_rmn_sigs(self, sigs: List[RmnSig]) -> bool:
pass

# TODO: doc
def get_rmn_threshold_for_chain(self, chain_selector: ChainSelector) -> int:
pass

# TODO: doc
def get_verified_signed_interval(
self,
chain_selector: ChainSelector,
current_interval: Interval,
max_seq_num_on_dest_chain: int,
signed_intervals: Dict[ChainSelector, SignedInterval],
messages: List[CcipMessage]
) -> Optional[SignedInterval]:
if current_interval.min != max_seq_num_on_dest_chain + 1:
return None

merkle_root = self.compute_merkle_root2(messages)

if len(merkle_root) != 32:
return None

rmn_threshold = self.get_rmn_threshold_for_chain(chain_selector)
if rmn_threshold > 0:
if chain_selector not in signed_intervals:
return None
else:
signed_interval = signed_intervals[chain_selector]
if (current_interval != signed_interval.interval or
merkle_root != signed_interval.root or
not self.verify_rmn_sigs(signed_interval.sigs)):
return None
else:
return signed_interval
else:
return SignedInterval(current_interval, merkle_root, [])

# TODO: doc
def get_messages_consensus(
self,
previous_outcome: CommitOutcome,
observations: List[CommitObservation]
) -> Dict[ChainSelector, List[CcipMessage]]:
# build Dict[ChainSelector, Dict[int, Dict[CcipMessage, int]]]
# map of chains to maps of seq nums to maps of CcipMessage to occurrence
pass

# TODO: doc
def build_next_interval(self, current_interval: Interval, rmn_max_seq_num: Optional[int]) -> Interval:
interval_min = current_interval.max
if rmn_max_seq_num is None:
pass
else:
interval_max = max(current_interval.max, rmn_max_seq_num)
if interval_max - interval_min > MAX_INTERVAL_LENGTH:
interval_max = interval_min + MAX_INTERVAL_LENGTH
return Interval(interval_min, interval_max)

# TODO: doc
def rebuild_current_interval(
self,
current_interval: Interval,
max_seq_num_on_dest_chain: int,
rmn_max_seq_num: int
) -> Interval:
pass

# TODO: doc
def outcome(
self,
previous_outcome: CommitOutcome,
query: CommitQuery,
observations: List[CommitObservation]
) -> CommitOutcome:
max_seq_nums_on_dest_chain = self.get_consensus_max_seq_nums_on_dest_chain(observations)
if len(max_seq_nums_on_dest_chain) == 0:
# TODO: doc
return CommitOutcome(previous_outcome.next_intervals, {})

messages = self.get_messages_consensus(previous_outcome, observations)

next_intervals = {}
signed_intervals = {}
for chain_selector in self.all_source_chains:
# handle key-missing errors
current_interval = previous_outcome.next_intervals[chain_selector]
max_seq_num_on_dest_chain = max_seq_nums_on_dest_chain[chain_selector]

signed_interval = self.get_verified_signed_interval(
chain_selector,
current_interval,
max_seq_num_on_dest_chain,
query.signed_intervals,
messages[chain_selector]
)

rmn_max_seq_num = query.rmn_max_seq_nums.get(chain_selector)

# TODO: doc
if signed_interval is None:
# self.rebuild_current_interval()
next_interval = self.rebuild_current_interval(current_interval, max_seq_num_on_dest_chain, rmn_max_seq_num)
pass
else:
next_interval = self.build_next_interval(current_interval, rmn_max_seq_num)
next_intervals[chain_selector] = next_interval
signed_intervals[chain_selector] = signed_interval

return CommitOutcome(next_intervals, signed_intervals)

0 comments on commit 0b67563

Please sign in to comment.