From 68935c8677a8e25dd713d0bae846dd4a1369e926 Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Wed, 26 Jun 2024 15:33:00 -0600 Subject: [PATCH] shardtree: Reduce dependence upon `LocatedTree::max_position` --- shardtree/src/lib.rs | 24 +++++--------- shardtree/src/prunable.rs | 70 ++++++++++++++++++++++++++++++++++++--- shardtree/src/tree.rs | 30 +++++++++++------ 3 files changed, 92 insertions(+), 32 deletions(-) diff --git a/shardtree/src/lib.rs b/shardtree/src/lib.rs index 011f37b..6749c65 100644 --- a/shardtree/src/lib.rs +++ b/shardtree/src/lib.rs @@ -235,17 +235,15 @@ impl< let (append_result, position, checkpoint_id) = if let Some(subtree) = self.store.last_shard().map_err(ShardTreeError::Storage)? { - match subtree.max_position() { - // If the subtree is full, then construct a successor tree. - Some(pos) if pos == subtree.root_addr.max_position() => { - let addr = subtree.root_addr; - if subtree.root_addr.index() < Self::max_subtree_index() { - LocatedTree::empty(addr.next_at_level()).append(value, retention)? - } else { - return Err(InsertionError::TreeFull.into()); - } + if subtree.root().is_full() { + let addr = subtree.root_addr; + if subtree.root_addr.index() < Self::max_subtree_index() { + LocatedTree::empty(addr.next_at_level()).append(value, retention)? + } else { + return Err(InsertionError::TreeFull.into()); } - _ => subtree.append(value, retention)?, + } else { + subtree.append(value, retention)? } } else { let root_addr = Address::from_parts(Self::subtree_level(), 0); @@ -573,13 +571,7 @@ impl< .map_err(ShardTreeError::Storage)?; if let Some(to_clear) = to_clear { - let pre_clearing_max_position = to_clear.max_position(); let cleared = to_clear.clear_flags(positions); - - // Clearing flags should not modify the max position of leaves represented - // in the shard. - assert!(cleared.max_position() == pre_clearing_max_position); - self.store .put_shard(cleared) .map_err(ShardTreeError::Storage)?; diff --git a/shardtree/src/prunable.rs b/shardtree/src/prunable.rs index 9130a7b..834b5c4 100644 --- a/shardtree/src/prunable.rs +++ b/shardtree/src/prunable.rs @@ -362,10 +362,11 @@ impl LocatedPrunableTree { /// If the tree contains any [`Node::Nil`] nodes that are to the left of filled nodes in the /// tree, this will return an error containing the addresses of those nodes. pub fn right_filled_root(&self) -> Result> { - self.root_hash( - self.max_position() - .map_or_else(|| self.root_addr.position_range_start(), |pos| pos + 1), - ) + let truncate_at = self + .max_position() + .map_or_else(|| self.root_addr.position_range_start(), |pos| pos + 1); + + self.root_hash(truncate_at) } /// Returns the positions of marked leaves in the tree. @@ -949,6 +950,33 @@ impl LocatedPrunableTree { root: go(&to_clear, self.root_addr, &self.root), } } + + #[cfg(test)] + pub(crate) fn flag_positions(&self) -> BTreeMap { + fn go( + root: &PrunableTree, + root_addr: Address, + acc: &mut BTreeMap, + ) { + match &root.0 { + Node::Parent { left, right, .. } => { + let (l_addr, r_addr) = root_addr + .children() + .expect("A parent node cannot appear at level 0"); + go(&left, l_addr, acc); + go(&right, r_addr, acc); + } + Node::Leaf { value } if value.1 != RetentionFlags::EPHEMERAL => { + acc.insert(root_addr.max_position(), value.1); + } + _ => (), + } + } + + let mut result = BTreeMap::new(); + go(&self.root, self.root_addr, &mut result); + result + } } // We need an applicative functor for Result for this function so that we can correctly @@ -971,13 +999,15 @@ fn accumulate_result_with( #[cfg(test)] mod tests { - use std::collections::BTreeSet; + use std::collections::{BTreeMap, BTreeSet}; use incrementalmerkletree::{Address, Level, Position}; + use proptest::proptest; use super::{LocatedPrunableTree, PrunableTree, RetentionFlags}; use crate::{ error::{InsertionError, QueryError}, + testing::{arb_char_str, arb_prunable_tree}, tree::{ tests::{leaf, nil, parent}, LocatedTree, @@ -1197,4 +1227,34 @@ mod tests { )])) ); } + + proptest! { + #[test] + fn clear_flags( + root in arb_prunable_tree(arb_char_str(), 8, 2^6) + ) { + let root_addr = Address::from_parts(Level::from(7), 0); + let tree = LocatedTree::from_parts(root_addr, root); + + let (to_clear, to_retain) = tree.flag_positions().into_iter().enumerate().fold( + (BTreeMap::new(), BTreeMap::new()), + |(mut to_clear, mut to_retain), (i, (pos, flags))| { + if i % 2 == 0 { + to_clear.insert(pos, flags); + } else { + to_retain.insert(pos, flags); + } + (to_clear, to_retain) + } + ); + + let pre_clearing_max_position = tree.max_position(); + let cleared = tree.clear_flags(to_clear); + + // Clearing flags should not modify the max position of leaves represented + // in the shard. + assert!(cleared.max_position() == pre_clearing_max_position); + assert_eq!(to_retain, cleared.flag_positions()); + } + } } diff --git a/shardtree/src/tree.rs b/shardtree/src/tree.rs index e92fd6c..1cfe2be 100644 --- a/shardtree/src/tree.rs +++ b/shardtree/src/tree.rs @@ -148,6 +148,15 @@ impl Tree { matches!(&self.0, Node::Leaf { .. }) } + /// Returns `true` if no additional nodes can be appended to this tree. + pub fn is_full(&self) -> bool { + match &self.0 { + Node::Nil => false, + Node::Leaf { .. } | Node::Pruned => true, + Node::Parent { right, .. } => right.is_full(), + } + } + /// Returns a vector of the addresses of [`Node::Nil`] and [`Node::Pruned`] subtree roots /// within this tree. /// @@ -260,19 +269,18 @@ impl LocatedTree { /// Note that no actual leaf value may exist at this position, as it may have previously been /// pruned. pub fn max_position(&self) -> Option { - Self::max_position_internal(self.root_addr, &self.root) - } - - pub(crate) fn max_position_internal(addr: Address, root: &Tree) -> Option { - match &root.0 { - Node::Nil => None, - Node::Leaf { .. } | Node::Pruned => Some(addr.position_range_end() - 1), - Node::Parent { left, right, .. } => { - let (l_addr, r_addr) = addr.children().unwrap(); - Self::max_position_internal(r_addr, right.as_ref()) - .or_else(|| Self::max_position_internal(l_addr, left.as_ref())) + fn go(addr: Address, root: &Tree) -> Option { + match &root.0 { + Node::Nil => None, + Node::Leaf { .. } | Node::Pruned => Some(addr.position_range_end() - 1), + Node::Parent { left, right, .. } => { + let (l_addr, r_addr) = addr.children().unwrap(); + go(r_addr, right.as_ref()).or_else(|| go(l_addr, left.as_ref())) + } } } + + go(self.root_addr, &self.root) } /// Returns the value at the specified position, if any.