diff --git a/arbitrator/prover/src/merkle.rs b/arbitrator/prover/src/merkle.rs index 5a3cbadbc7..b37e0140ba 100644 --- a/arbitrator/prover/src/merkle.rs +++ b/arbitrator/prover/src/merkle.rs @@ -149,6 +149,12 @@ impl MerkleType { } } +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct Layers { + data: Vec>, + dirt: Vec>, +} + /// A Merkle tree with a fixed number of layers /// /// https://en.wikipedia.org/wiki/Merkle_tree @@ -168,10 +174,8 @@ impl MerkleType { pub struct Merkle { ty: MerkleType, #[serde(with = "arc_mutex_sedre")] - layers: Arc>>>, + layers: Arc>, min_depth: usize, - #[serde(with = "arc_mutex_sedre")] - dirty_layers: Arc>>>, } fn hash_node(ty: MerkleType, a: impl AsRef<[u8]>, b: impl AsRef<[u8]>) -> Bytes32 { @@ -247,42 +251,43 @@ impl Merkle { layers.push(new_layer); layer_i += 1; } - let dirty_layers = Arc::new(Mutex::new(dirty_indices)); + let layers = Arc::new(Mutex::new(Layers { + data: layers, + dirt: dirty_indices, + })); Merkle { ty, - layers: Arc::new(Mutex::new(layers)), + layers, min_depth, - dirty_layers, } } fn rehash(&self) { - let dirty_layers = &mut self.dirty_layers.lock().unwrap(); - if dirty_layers.is_empty() || dirty_layers[0].is_empty() { + let layers = &mut self.layers.lock().unwrap(); + if layers.dirt.is_empty() || layers.dirt[0].is_empty() { return; } - let layers = &mut self.layers.lock().unwrap(); - for layer_i in 1..layers.len() { + for layer_i in 1..layers.data.len() { let dirty_i = layer_i - 1; - let dirt = dirty_layers[dirty_i].clone(); + let dirt = layers.dirt[dirty_i].clone(); for idx in dirt.iter().sorted() { let left_child_idx = idx << 1; let right_child_idx = left_child_idx + 1; - let left = layers[layer_i - 1][left_child_idx]; - let right = layers[layer_i - 1] + let left = layers.data[layer_i - 1][left_child_idx]; + let right = layers.data[layer_i - 1] .get(right_child_idx) .unwrap_or(empty_hash_at(self.ty, layer_i - 1)); let new_hash = hash_node(self.ty, left, right); - if *idx < layers[layer_i].len() { - layers[layer_i][*idx] = new_hash; + if *idx < layers.data[layer_i].len() { + layers.data[layer_i][*idx] = new_hash; } else { - layers[layer_i].push(new_hash); + layers.data[layer_i].push(new_hash); } - if layer_i < layers.len() - 1 { - dirty_layers[dirty_i + 1].insert(idx >> 1); + if layer_i < layers.data.len() - 1 { + layers.dirt[dirty_i + 1].insert(idx >> 1); } } - dirty_layers[dirty_i].clear(); + layers.dirt[dirty_i].clear(); } } @@ -290,7 +295,7 @@ impl Merkle { #[cfg(feature = "counters")] ROOT_COUNTERS[&self.ty].fetch_add(1, Ordering::Relaxed); self.rehash(); - if let Some(layer) = self.layers.lock().unwrap().last() { + if let Some(layer) = self.layers.lock().unwrap().data.last() { assert_eq!(layer.len(), 1); layer[0] } else { @@ -302,26 +307,27 @@ impl Merkle { #[inline] fn capacity(&self) -> usize { let layers = self.layers.lock().unwrap(); - if layers.is_empty() { + if layers.data.is_empty() { return 0; } let base: usize = 2; - base.pow((layers.len() - 1).try_into().unwrap()) + base.pow((layers.data.len() - 1).try_into().unwrap()) } // Returns the number of leaves in the tree. pub fn len(&self) -> usize { - self.layers.lock().unwrap()[0].len() + self.layers.lock().unwrap().data[0].len() } pub fn is_empty(&self) -> bool { let layers = self.layers.lock().unwrap(); - layers.is_empty() || layers[0].is_empty() + layers.data.is_empty() || layers.data[0].is_empty() } #[must_use] pub fn prove(&self, idx: usize) -> Option> { - if self.layers.lock().unwrap().is_empty() || idx >= self.layers.lock().unwrap()[0].len() { + let layers = self.layers.lock().unwrap(); + if layers.data.is_empty() || idx >= layers.data[0].len() { return None; } Some(self.prove_any(idx)) @@ -332,9 +338,9 @@ impl Merkle { pub fn prove_any(&self, mut idx: usize) -> Vec { self.rehash(); let layers = self.layers.lock().unwrap(); - let mut proof = vec![u8::try_from(layers.len() - 1).unwrap()]; - for (layer_i, layer) in layers.iter().enumerate() { - if layer_i == layers.len() - 1 { + let mut proof = vec![u8::try_from(layers.data.len() - 1).unwrap()]; + for (layer_i, layer) in layers.data.iter().enumerate() { + if layer_i == layers.data.len() - 1 { break; } let counterpart = idx ^ 1; @@ -352,7 +358,7 @@ impl Merkle { /// Adds a new leaf to the merkle /// Currently O(n) in the number of leaves (could be log(n)) pub fn push_leaf(&mut self, leaf: Bytes32) { - let mut leaves = self.layers.lock().unwrap().swap_remove(0); + let mut leaves = self.layers.lock().unwrap().data.swap_remove(0); leaves.push(leaf); *self = Self::new_advanced(self.ty, leaves, self.min_depth); } @@ -360,7 +366,7 @@ impl Merkle { /// Removes the rightmost leaf from the merkle /// Currently O(n) in the number of leaves (could be log(n)) pub fn pop_leaf(&mut self) { - let mut leaves = self.layers.lock().unwrap().swap_remove(0); + let mut leaves = self.layers.lock().unwrap().data.swap_remove(0); leaves.pop(); *self = Self::new_advanced(self.ty, leaves, self.min_depth); } @@ -371,11 +377,11 @@ impl Merkle { #[cfg(feature = "counters")] SET_COUNTERS[&self.ty].fetch_add(1, Ordering::Relaxed); let mut layers = self.layers.lock().unwrap(); - if layers[0][idx] == hash { + if layers.data[0][idx] == hash { return; } - layers[0][idx] = hash; - self.dirty_layers.lock().unwrap()[0].insert(idx >> 1); + layers.data[0][idx] = hash; + layers.dirt[0].insert(idx >> 1); } /// Resizes the number of leaves the tree can hold. @@ -389,15 +395,15 @@ impl Merkle { } let mut layers = self.layers.lock().unwrap(); let mut layer_size = new_len; - for (layer_i, layer) in layers.iter_mut().enumerate() { + for (layer_i, layer) in layers.data.iter_mut().enumerate() { layer.resize(layer_size, *empty_hash_at(self.ty, layer_i)); layer_size = max(layer_size >> 1, 1); } - let start = layers[0].len(); + let start = layers.data[0].len(); for i in start..new_len { - self.dirty_layers.lock().unwrap()[0].insert(i); + layers.dirt[0].insert(i); } - Ok(layers[0].len()) + Ok(layers.data[0].len()) } }