From a77bd4c577163c9cc8c47d0cc96b65266b245de3 Mon Sep 17 00:00:00 2001 From: Dr Maxim Orlovsky Date: Tue, 20 Aug 2024 18:16:54 +0200 Subject: [PATCH] mpc: clearly distinguish factored width from the tree width --- commit_verify/src/mpc/block.rs | 54 +++++++++++++++++++--------------- commit_verify/src/mpc/tree.rs | 32 ++++++++++++-------- 2 files changed, 51 insertions(+), 35 deletions(-) diff --git a/commit_verify/src/mpc/block.rs b/commit_verify/src/mpc/block.rs index 6a7cdfae..2f293a52 100644 --- a/commit_verify/src/mpc/block.rs +++ b/commit_verify/src/mpc/block.rs @@ -206,7 +206,7 @@ impl From<&MerkleTree> for MerkleBlock { fn from(tree: &MerkleTree) -> Self { let map = &tree.map; - let iter = (0..tree.width()).map(|pos| { + let iter = (0..tree.width_limit()).map(|pos| { map.get(&pos) .map(|(protocol_id, message)| TreeNode::CommitmentLeaf { protocol_id: *protocol_id, @@ -242,23 +242,23 @@ impl MerkleBlock { ) -> Result { let path = proof.as_path(); let mut pos = proof.pos; - let mut width = proof.width(); + let mut width_limit = proof.width_limit(); - let expected = protocol_id_pos(protocol_id, proof.cofactor, width); + let expected = protocol_id_pos(protocol_id, proof.cofactor, proof.depth()); if expected != pos { return Err(InvalidProof { protocol_id, expected, actual: pos, - width, + width: width_limit, }); } let mut dir = Vec::with_capacity(path.len()); let mut rev = Vec::with_capacity(path.len()); for (depth, hash) in path.iter().enumerate() { - let list = if pos >= width / 2 { - pos -= width / 2; + let list = if pos >= width_limit / 2 { + pos -= width_limit / 2; &mut dir } else { &mut rev @@ -267,7 +267,7 @@ impl MerkleBlock { depth: u5::with(depth as u8) + 1, hash: *hash, }); - width /= 2; + width_limit /= 2; } let mut cross_section = Vec::with_capacity(path.len() + 1); @@ -364,9 +364,9 @@ impl MerkleBlock { hash: hash1, }, Some(TreeNode::ConcealedNode { - depth: depth2, - hash: hash2, - }), + depth: depth2, + hash: hash2, + }), ) if depth1 == depth2 => { let depth = depth1 - 1; let height = self.depth.to_u8() as u32 - depth.to_u8() as u32; @@ -375,7 +375,7 @@ impl MerkleBlock { offset += 2u32.pow(self.depth.to_u8() as u32 - depth1.to_u8() as u32); } else { self.cross_section[pos] = - TreeNode::with(hash1, hash2, depth, self.width()); + TreeNode::with(hash1, hash2, depth, self.width_limit()); self.cross_section .remove(pos + 1) .expect("we allow 0 elements"); @@ -419,7 +419,7 @@ impl MerkleBlock { if count == prev_count { break; } - debug_assert_eq!(offset, self.width()); + debug_assert_eq!(offset, self.width_limit()); } Ok(count) @@ -533,7 +533,7 @@ impl MerkleBlock { .map(|n| self.depth.to_u8() - n.depth_or(self.depth).to_u8()) .map(|height| 2u32.pow(height as u32)) .sum::(), - self.width(), + self.width_limit(), "LNPBP-4 merge-reveal procedure is broken; please report the below data to the LNP/BP \ Standards Association Original block: {orig:#?} @@ -591,11 +591,15 @@ Changed commitment id: {}", /// Computes position for a given `protocol_id` within the tree leaves. pub fn protocol_id_pos(&self, protocol_id: ProtocolId) -> u32 { - protocol_id_pos(protocol_id, self.cofactor, self.width()) + protocol_id_pos(protocol_id, self.cofactor, self.depth) } - /// Computes the width of the merkle tree. - pub fn width(&self) -> u32 { 2u32.pow(self.depth.to_u8() as u32) } + /// Computes the maximum possible width of the merkle tree. + pub fn width_limit(&self) -> u32 { 2u32.pow(self.depth.to_u8() as u32) } + + /// Computes the factored width of the merkle tree according to the formula + /// `2 ^ depth - cofactor`. + pub fn factored_width(&self) -> u32 { self.width_limit() - self.cofactor as u32 } /// Constructs [`MessageMap`] for revealed protocols and messages. pub fn to_known_message_map(&self) -> MessageMap { @@ -611,7 +615,7 @@ Changed commitment id: {}", } => Some((protocol_id, message)), }), ) - .expect("same collection size") + .expect("same collection size") } } @@ -626,9 +630,9 @@ impl Conceal for MerkleBlock { .expect("broken internal MerkleBlock structure"); debug_assert_eq!(concealed.cross_section.len(), 1); let Some(TreeNode::ConcealedNode { - depth: u5::ZERO, - hash, - }) = concealed.cross_section.first() + depth: u5::ZERO, + hash, + }) = concealed.cross_section.first() else { panic!("broken MerkleBlock conceal procedure") }; @@ -670,10 +674,14 @@ impl Proof for MerkleProof { impl MerkleProof { /// Computes the depth of the merkle tree. - pub fn depth(&self) -> u8 { self.path.len() as u8 } + pub fn depth(&self) -> u5 { u5::with(self.path.len() as u8) } + + /// Computes the maximum width of the merkle tree. + pub fn width_limit(&self) -> u32 { 2u32.pow(self.depth().to_u8() as u32) } - /// Computes the width of the merkle tree. - pub fn width(&self) -> u32 { 2u32.pow(self.depth() as u32) } + /// Computes the factored width of the merkle tree according to the formula + /// `2 ^ depth - cofactor`. + pub fn factored_width(&self) -> u32 { self.width_limit() - self.cofactor as u32 } /// Converts the proof into inner merkle path representation pub fn into_path(self) -> Confined, 0, 32> { self.path } diff --git a/commit_verify/src/mpc/tree.rs b/commit_verify/src/mpc/tree.rs index b559f02d..26c6d267 100644 --- a/commit_verify/src/mpc/tree.rs +++ b/commit_verify/src/mpc/tree.rs @@ -66,14 +66,14 @@ impl Proof for MerkleTree { impl MerkleTree { pub fn root(&self) -> MerkleHash { - let iter = (0..self.width()).map(|pos| { + let iter = (0..self.width_limit()).map(|pos| { self.map .get(&pos) .map(|(protocol, msg)| Leaf::inhabited(*protocol, *msg)) .unwrap_or_else(|| Leaf::entropy(self.entropy, pos)) }); let leaves = LargeVec::try_from_iter(iter).expect("tree width has u32-bound size"); - debug_assert_eq!(leaves.len_u32(), self.width()); + debug_assert_eq!(leaves.len_u32(), self.width_limit()); MerkleHash::merklize(&leaves) } } @@ -146,12 +146,12 @@ mod commit { let mut depth = source.min_depth; let mut prev_width = 1u32; loop { - let width = 2u32.pow(depth.to_u8() as u32); - if width as usize >= msg_count { + let width_limit = 2u32.pow(depth.to_u8() as u32); + if width_limit as usize >= msg_count { for cofactor in 0..=(prev_width.min(COFACTOR_ATTEMPTS as u32) as u16) { map.clear(); if source.messages.iter().all(|(protocol, message)| { - let pos = protocol_id_pos(*protocol, cofactor, width); + let pos = protocol_id_pos(*protocol, cofactor, depth); map.insert(pos, (*protocol, *message)).is_none() }) { return Ok(MerkleTree { @@ -165,7 +165,7 @@ mod commit { } } - prev_width = width; + prev_width = width_limit; depth = depth .checked_add(1) .ok_or(Error::CantFitInMaxSlots(msg_count))?; @@ -174,7 +174,8 @@ mod commit { } } -pub(super) fn protocol_id_pos(protocol_id: ProtocolId, cofactor: u16, width: u32) -> u32 { +pub(super) fn protocol_id_pos(protocol_id: ProtocolId, cofactor: u16, depth: u5) -> u32 { + let width = 2u32.pow(depth.to_u8() as u32); debug_assert_ne!(width, 0); let rem = u256::from_le_bytes((*protocol_id).into_inner()) % u256::from(width.saturating_sub(cofactor as u32).max(1) as u64); @@ -184,14 +185,21 @@ pub(super) fn protocol_id_pos(protocol_id: ProtocolId, cofactor: u16, width: u32 impl MerkleTree { /// Computes position for a given `protocol_id` within the tree leaves. pub fn protocol_id_pos(&self, protocol_id: ProtocolId) -> u32 { - protocol_id_pos(protocol_id, self.cofactor, self.width()) + protocol_id_pos(protocol_id, self.cofactor, self.depth) } - /// Computes the width of the merkle tree. - pub fn width(&self) -> u32 { 2u32.pow(self.depth.to_u8() as u32) } + /// Computes the maximum possible width of the merkle tree, equal to `2 ^ + /// depth`. + pub fn width_limit(&self) -> u32 { 2u32.pow(self.depth.to_u8() as u32) } + + /// Computes the factored width of the merkle tree, equal to `2 ^ depth - + /// cofactor`. + pub fn factored_width(&self) -> u32 { self.width_limit() - self.cofactor as u32 } pub fn depth(&self) -> u5 { self.depth } + pub fn cofactor(&self) -> u16 { self.cofactor } + pub fn entropy(&self) -> u64 { self.entropy } } @@ -307,7 +315,7 @@ mod test { length {} bytes.\nTakes {} msecs to generate", tree.depth, tree.cofactor, - tree.width(), + tree.factored_width(), counter.unconfine().count, elapsed_gen.as_millis(), ); @@ -323,7 +331,7 @@ mod test { let msgs = make_random_messages(9); let tree = make_random_tree(&msgs); assert!(tree.depth() > u5::with(3)); - assert!(tree.width() > 9); + assert!(tree.factored_width() > 9); let mut set = BTreeSet::::new(); for (pid, msg) in msgs { let pos = tree.protocol_id_pos(pid);