-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
(WIP) Add a draft spec for RMN OffChain Blessing
- Loading branch information
Showing
1 changed file
with
357 additions
and
0 deletions.
There are no files selected for viewing
357 changes: 357 additions & 0 deletions
357
core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,357 @@ | ||
# TODO: doc | ||
|
||
from typing import List, Dict, Optional | ||
from collections import defaultdict | ||
|
||
RmnNodeId = str | ||
ChainSelector = int | ||
|
||
MAX_INTERVAL_LENGTH = 256 | ||
|
||
|
||
class Interval: | ||
# TODO: doc, inclusive, exclusive ranges | ||
def __init__(self, min: int, max: int): | ||
# TODO: invariant check min <= max | ||
self.min = min | ||
self.max = max | ||
|
||
def is_empty(self) -> bool: | ||
return self.min == self.max | ||
|
||
|
||
class RmnSig: | ||
def __init__(self, rmn_node_id: RmnNodeId, sig: bytes): | ||
self.rmn_node_id = rmn_node_id | ||
self.sig = sig | ||
|
||
|
||
class SignedInterval: | ||
def __init__(self, interval: Interval, root: bytes, sigs: List[RmnSig]): | ||
self.interval = interval | ||
self.root = root | ||
self.sigs = sigs | ||
|
||
|
||
class CcipMessage: | ||
def __init__(self, seq_num: int): | ||
self.seq_num = seq_num | ||
pass | ||
|
||
|
||
class CommitQuery: | ||
def __init__( | ||
self, | ||
rmn_max_seq_nums: Dict[ChainSelector, int], | ||
signed_intervals: Dict[ChainSelector, SignedInterval] | ||
): | ||
self.rmn_max_seq_nums = rmn_max_seq_nums | ||
self.signed_intervals = signed_intervals | ||
|
||
|
||
class CommitObservation: | ||
def __init__( | ||
self, | ||
max_seq_nums_on_dest_chain: Dict[ChainSelector, int], | ||
max_seq_nums_from_source_chains: Dict[ChainSelector, int], # TODO: try without | ||
messages: Dict[ChainSelector, List[CcipMessage]] | ||
): | ||
self.max_seq_nums_on_dest_chain = max_seq_nums_on_dest_chain | ||
self.max_seq_nums_from_source_chains = max_seq_nums_from_source_chains | ||
self.messages = messages | ||
|
||
|
||
class CommitOutcome: | ||
def __init__( | ||
self, | ||
next_intervals: Dict[ChainSelector, Interval], | ||
signed_intervals: Dict[ChainSelector, SignedInterval] | ||
): | ||
self.next_intervals = next_intervals | ||
self.signed_intervals = signed_intervals | ||
return | ||
|
||
|
||
class RmnNode: | ||
def __init__(self, node_id: RmnNodeId, ip_address: bytes, pub_key: bytes, supported_chains: List[ChainSelector]): | ||
self.node_id = node_id | ||
self.ip_address = ip_address | ||
self.pub_key = pub_key | ||
self.supported_chains = supported_chains | ||
|
||
|
||
class RmnClientConfig: | ||
def __init__(self, rmn_nodes: List[RmnNode]): | ||
self.rmn_nodes = rmn_nodes | ||
|
||
|
||
class RmnClient: | ||
def __init__(self, rmn_client_config: RmnClientConfig): | ||
self.rmn_client_config = rmn_client_config | ||
|
||
# TODO: doc | ||
def request_max_seq_nums_from_single_node( | ||
self, | ||
rmn_node_id: RmnNodeId, | ||
chains: List[ChainSelector] | ||
) -> Dict[ChainSelector, int]: | ||
pass | ||
|
||
# TODO: doc | ||
def request_signed_intervals_from_single_node( | ||
self, | ||
rmn_node_id: RmnNodeId, | ||
intervals: Dict[ChainSelector, Interval] | ||
) -> Dict[ChainSelector, SignedInterval]: | ||
pass | ||
|
||
# TODO: doc | ||
def request_max_seq_nums( | ||
self, | ||
chains: List[ChainSelector] | ||
) -> Dict[ChainSelector, int]: | ||
pass | ||
|
||
# TODO: doc | ||
def request_signed_intervals( | ||
self, | ||
intervals: Dict[ChainSelector, Interval] | ||
) -> Dict[ChainSelector, SignedInterval]: | ||
pass | ||
|
||
|
||
class ChainReader: | ||
def __init__(self): | ||
pass | ||
|
||
|
||
class OffRamp: | ||
def __init__(self): | ||
pass | ||
|
||
# TODO: doc | ||
def get_max_seq_nums_on_dest_chain(self) -> Dict[ChainSelector, int]: | ||
pass | ||
|
||
|
||
class CommitPlugin: | ||
def __init__( | ||
self, | ||
rmn_client: RmnClient, | ||
all_source_chains: List[ChainSelector], | ||
dest_chain: ChainSelector, | ||
chain_readers: Dict[ChainSelector, ChainReader], | ||
off_ramp: OffRamp, | ||
f: int | ||
): | ||
self.rmn_client = rmn_client | ||
self.all_source_chains = all_source_chains | ||
self.dest_chain = dest_chain | ||
self.chain_readers = chain_readers | ||
self.off_ramp = off_ramp | ||
self.f = f | ||
return | ||
|
||
# TODO: doc | ||
def can_read_from_dest_chain(self) -> bool: | ||
return self.dest_chain in self.chain_readers | ||
|
||
# TODO: doc | ||
def get_ccip_messages_from_source_chains( | ||
self, | ||
intervals: Dict[ChainSelector, Interval] | ||
) -> Dict[ChainSelector, List[CcipMessage]]: | ||
pass | ||
|
||
# TODO: doc | ||
def query(self, previous_outcome: CommitOutcome) -> CommitQuery: | ||
max_seq_nums = self.rmn_client.request_max_seq_nums(self.all_source_chains) | ||
signed_intervals = self.rmn_client.request_signed_intervals(previous_outcome.next_intervals) | ||
return CommitQuery(max_seq_nums, signed_intervals) | ||
|
||
# TODO: doc | ||
def observation(self, previous_outcome: CommitOutcome) -> CommitObservation: | ||
# Get persisted min seq nums from dest chain | ||
max_seq_nums_on_dest_chain = {} | ||
if self.can_read_from_dest_chain(): | ||
max_seq_nums_on_dest_chain = self.off_ramp.get_max_seq_nums_on_dest_chain() | ||
|
||
# Get max seq nums from all chains that can be read from | ||
max_seq_nums_from_source_chains = {} | ||
for (chain_selector, chain_reader) in self.chain_readers.items(): | ||
max_seq_nums_from_source_chains[chain_selector] = chain_reader.get_max_seq_num(self.dest_chain) | ||
|
||
# Get messages in previous_outcome.next_interval | ||
messages = self.get_ccip_messages_from_source_chains(previous_outcome.next_intervals) | ||
|
||
return CommitObservation(max_seq_nums_on_dest_chain, max_seq_nums_from_source_chains, messages) | ||
|
||
# TODO: doc | ||
def aggregate_observations( | ||
self, | ||
observations: List[CommitObservation] | ||
) -> CommitObservation: | ||
pass | ||
|
||
# TODO: doc | ||
def get_consensus_max_seq_nums_on_dest_chain( | ||
self, | ||
observations: List[CommitObservation] | ||
) -> Dict[ChainSelector, int]: | ||
counts = defaultdict(int) | ||
for observation in observations: | ||
# Convert the dictionary to a frozenset of items so it can be used as a key | ||
if len(observation.max_seq_nums_on_dest_chain) > 0: | ||
frozen_dict = frozenset(observation.max_seq_nums_on_dest_chain.items()) | ||
counts[frozen_dict] += 1 | ||
|
||
# Consensus on the onchain state is reached if there is only one max_seq_nums_on_dest_chain dict that is | ||
# observed by more than f nodes | ||
# TODO: more doc | ||
candidates = [] | ||
for (candidate, count) in counts.items(): | ||
if count > self.f: | ||
candidates.append(candidate) | ||
|
||
if len(candidates) == 1: | ||
return dict(candidates[0]) | ||
else: | ||
return {} | ||
|
||
# TODO: doc | ||
# the interval mins in previous outcome should be one more than what's onchain | ||
def chains_with_unexpected_onchain_state( | ||
self, | ||
max_seq_nums_on_dest_chain: Dict[ChainSelector, int], | ||
previous_outcome: CommitOutcome, | ||
) -> List[ChainSelector]: | ||
pass | ||
|
||
# TODO: doc | ||
def compute_merkle_root( | ||
self, | ||
chain_selector: ChainSelector, | ||
current_interval: Interval, | ||
observations: List[CommitObservation] | ||
) -> bytes: | ||
pass | ||
|
||
# TODO: doc | ||
def compute_merkle_root2( | ||
self, | ||
messages: List[CcipMessage] | ||
) -> bytes: | ||
pass | ||
|
||
# TODO: doc | ||
def verify_rmn_sigs(self, sigs: List[RmnSig]) -> bool: | ||
pass | ||
|
||
# TODO: doc | ||
def get_rmn_threshold_for_chain(self, chain_selector: ChainSelector) -> int: | ||
pass | ||
|
||
# TODO: doc | ||
def get_verified_signed_interval( | ||
self, | ||
chain_selector: ChainSelector, | ||
current_interval: Interval, | ||
max_seq_num_on_dest_chain: int, | ||
signed_intervals: Dict[ChainSelector, SignedInterval], | ||
messages: List[CcipMessage] | ||
) -> Optional[SignedInterval]: | ||
if current_interval.min != max_seq_num_on_dest_chain + 1: | ||
return None | ||
|
||
merkle_root = self.compute_merkle_root2(messages) | ||
|
||
if len(merkle_root) != 32: | ||
return None | ||
|
||
rmn_threshold = self.get_rmn_threshold_for_chain(chain_selector) | ||
if rmn_threshold > 0: | ||
if chain_selector not in signed_intervals: | ||
return None | ||
else: | ||
signed_interval = signed_intervals[chain_selector] | ||
if (current_interval != signed_interval.interval or | ||
merkle_root != signed_interval.root or | ||
not self.verify_rmn_sigs(signed_interval.sigs)): | ||
return None | ||
else: | ||
return signed_interval | ||
else: | ||
return SignedInterval(current_interval, merkle_root, []) | ||
|
||
# TODO: doc | ||
def get_messages_consensus( | ||
self, | ||
previous_outcome: CommitOutcome, | ||
observations: List[CommitObservation] | ||
) -> Dict[ChainSelector, List[CcipMessage]]: | ||
# build Dict[ChainSelector, Dict[int, Dict[CcipMessage, int]]] | ||
# map of chains to maps of seq nums to maps of CcipMessage to occurrence | ||
pass | ||
|
||
# TODO: doc | ||
def build_next_interval(self, current_interval: Interval, rmn_max_seq_num: Optional[int]) -> Interval: | ||
interval_min = current_interval.max | ||
if rmn_max_seq_num is None: | ||
pass | ||
else: | ||
interval_max = max(current_interval.max, rmn_max_seq_num) | ||
if interval_max - interval_min > MAX_INTERVAL_LENGTH: | ||
interval_max = interval_min + MAX_INTERVAL_LENGTH | ||
return Interval(interval_min, interval_max) | ||
|
||
# TODO: doc | ||
def rebuild_current_interval( | ||
self, | ||
current_interval: Interval, | ||
max_seq_num_on_dest_chain: int, | ||
rmn_max_seq_num: int | ||
) -> Interval: | ||
pass | ||
|
||
# TODO: doc | ||
def outcome( | ||
self, | ||
previous_outcome: CommitOutcome, | ||
query: CommitQuery, | ||
observations: List[CommitObservation] | ||
) -> CommitOutcome: | ||
max_seq_nums_on_dest_chain = self.get_consensus_max_seq_nums_on_dest_chain(observations) | ||
if len(max_seq_nums_on_dest_chain) == 0: | ||
# TODO: doc | ||
return CommitOutcome(previous_outcome.next_intervals, {}) | ||
|
||
messages = self.get_messages_consensus(previous_outcome, observations) | ||
|
||
next_intervals = {} | ||
signed_intervals = {} | ||
for chain_selector in self.all_source_chains: | ||
# handle key-missing errors | ||
current_interval = previous_outcome.next_intervals[chain_selector] | ||
max_seq_num_on_dest_chain = max_seq_nums_on_dest_chain[chain_selector] | ||
|
||
signed_interval = self.get_verified_signed_interval( | ||
chain_selector, | ||
current_interval, | ||
max_seq_num_on_dest_chain, | ||
query.signed_intervals, | ||
messages[chain_selector] | ||
) | ||
|
||
rmn_max_seq_num = query.rmn_max_seq_nums.get(chain_selector) | ||
|
||
# TODO: doc | ||
if signed_interval is None: | ||
# self.rebuild_current_interval() | ||
next_interval = self.rebuild_current_interval(current_interval, max_seq_num_on_dest_chain, rmn_max_seq_num) | ||
pass | ||
else: | ||
next_interval = self.build_next_interval(current_interval, rmn_max_seq_num) | ||
next_intervals[chain_selector] = next_interval | ||
signed_intervals[chain_selector] = signed_interval | ||
|
||
return CommitOutcome(next_intervals, signed_intervals) |