From d8c832a01589ef8e3267432bfa9326ea62185ff7 Mon Sep 17 00:00:00 2001 From: Pepper Lebeck-Jobe Date: Tue, 21 May 2024 09:18:32 +0200 Subject: [PATCH] Move the layer data and dirty indices to a struct. This allows us to lock a single Mutex to gain access to both parts of the data structure. This helps us to prevent deadlocks caused by accessing the separate locks in different order from various call sites. --- arbitrator/prover/src/merkle.rs | 80 ++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 37 deletions(-) diff --git a/arbitrator/prover/src/merkle.rs b/arbitrator/prover/src/merkle.rs index 5a3cbadbc7..b37e0140ba 100644 --- a/arbitrator/prover/src/merkle.rs +++ b/arbitrator/prover/src/merkle.rs @@ -149,6 +149,12 @@ impl MerkleType { } } +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct Layers { + data: Vec>, + dirt: Vec>, +} + /// A Merkle tree with a fixed number of layers /// /// https://en.wikipedia.org/wiki/Merkle_tree @@ -168,10 +174,8 @@ impl MerkleType { pub struct Merkle { ty: MerkleType, #[serde(with = "arc_mutex_sedre")] - layers: Arc>>>, + layers: Arc>, min_depth: usize, - #[serde(with = "arc_mutex_sedre")] - dirty_layers: Arc>>>, } fn hash_node(ty: MerkleType, a: impl AsRef<[u8]>, b: impl AsRef<[u8]>) -> Bytes32 { @@ -247,42 +251,43 @@ impl Merkle { layers.push(new_layer); layer_i += 1; } - let dirty_layers = Arc::new(Mutex::new(dirty_indices)); + let layers = Arc::new(Mutex::new(Layers { + data: layers, + dirt: dirty_indices, + })); Merkle { ty, - layers: Arc::new(Mutex::new(layers)), + layers, min_depth, - dirty_layers, } } fn rehash(&self) { - let dirty_layers = &mut self.dirty_layers.lock().unwrap(); - if dirty_layers.is_empty() || dirty_layers[0].is_empty() { + let layers = &mut self.layers.lock().unwrap(); + if layers.dirt.is_empty() || layers.dirt[0].is_empty() { return; } - let layers = &mut self.layers.lock().unwrap(); - for layer_i in 1..layers.len() { + for layer_i in 1..layers.data.len() { let dirty_i = layer_i - 1; - let dirt = dirty_layers[dirty_i].clone(); + let dirt = layers.dirt[dirty_i].clone(); for idx in dirt.iter().sorted() { let left_child_idx = idx << 1; let right_child_idx = left_child_idx + 1; - let left = layers[layer_i - 1][left_child_idx]; - let right = layers[layer_i - 1] + let left = layers.data[layer_i - 1][left_child_idx]; + let right = layers.data[layer_i - 1] .get(right_child_idx) .unwrap_or(empty_hash_at(self.ty, layer_i - 1)); let new_hash = hash_node(self.ty, left, right); - if *idx < layers[layer_i].len() { - layers[layer_i][*idx] = new_hash; + if *idx < layers.data[layer_i].len() { + layers.data[layer_i][*idx] = new_hash; } else { - layers[layer_i].push(new_hash); + layers.data[layer_i].push(new_hash); } - if layer_i < layers.len() - 1 { - dirty_layers[dirty_i + 1].insert(idx >> 1); + if layer_i < layers.data.len() - 1 { + layers.dirt[dirty_i + 1].insert(idx >> 1); } } - dirty_layers[dirty_i].clear(); + layers.dirt[dirty_i].clear(); } } @@ -290,7 +295,7 @@ impl Merkle { #[cfg(feature = "counters")] ROOT_COUNTERS[&self.ty].fetch_add(1, Ordering::Relaxed); self.rehash(); - if let Some(layer) = self.layers.lock().unwrap().last() { + if let Some(layer) = self.layers.lock().unwrap().data.last() { assert_eq!(layer.len(), 1); layer[0] } else { @@ -302,26 +307,27 @@ impl Merkle { #[inline] fn capacity(&self) -> usize { let layers = self.layers.lock().unwrap(); - if layers.is_empty() { + if layers.data.is_empty() { return 0; } let base: usize = 2; - base.pow((layers.len() - 1).try_into().unwrap()) + base.pow((layers.data.len() - 1).try_into().unwrap()) } // Returns the number of leaves in the tree. pub fn len(&self) -> usize { - self.layers.lock().unwrap()[0].len() + self.layers.lock().unwrap().data[0].len() } pub fn is_empty(&self) -> bool { let layers = self.layers.lock().unwrap(); - layers.is_empty() || layers[0].is_empty() + layers.data.is_empty() || layers.data[0].is_empty() } #[must_use] pub fn prove(&self, idx: usize) -> Option> { - if self.layers.lock().unwrap().is_empty() || idx >= self.layers.lock().unwrap()[0].len() { + let layers = self.layers.lock().unwrap(); + if layers.data.is_empty() || idx >= layers.data[0].len() { return None; } Some(self.prove_any(idx)) @@ -332,9 +338,9 @@ impl Merkle { pub fn prove_any(&self, mut idx: usize) -> Vec { self.rehash(); let layers = self.layers.lock().unwrap(); - let mut proof = vec![u8::try_from(layers.len() - 1).unwrap()]; - for (layer_i, layer) in layers.iter().enumerate() { - if layer_i == layers.len() - 1 { + let mut proof = vec![u8::try_from(layers.data.len() - 1).unwrap()]; + for (layer_i, layer) in layers.data.iter().enumerate() { + if layer_i == layers.data.len() - 1 { break; } let counterpart = idx ^ 1; @@ -352,7 +358,7 @@ impl Merkle { /// Adds a new leaf to the merkle /// Currently O(n) in the number of leaves (could be log(n)) pub fn push_leaf(&mut self, leaf: Bytes32) { - let mut leaves = self.layers.lock().unwrap().swap_remove(0); + let mut leaves = self.layers.lock().unwrap().data.swap_remove(0); leaves.push(leaf); *self = Self::new_advanced(self.ty, leaves, self.min_depth); } @@ -360,7 +366,7 @@ impl Merkle { /// Removes the rightmost leaf from the merkle /// Currently O(n) in the number of leaves (could be log(n)) pub fn pop_leaf(&mut self) { - let mut leaves = self.layers.lock().unwrap().swap_remove(0); + let mut leaves = self.layers.lock().unwrap().data.swap_remove(0); leaves.pop(); *self = Self::new_advanced(self.ty, leaves, self.min_depth); } @@ -371,11 +377,11 @@ impl Merkle { #[cfg(feature = "counters")] SET_COUNTERS[&self.ty].fetch_add(1, Ordering::Relaxed); let mut layers = self.layers.lock().unwrap(); - if layers[0][idx] == hash { + if layers.data[0][idx] == hash { return; } - layers[0][idx] = hash; - self.dirty_layers.lock().unwrap()[0].insert(idx >> 1); + layers.data[0][idx] = hash; + layers.dirt[0].insert(idx >> 1); } /// Resizes the number of leaves the tree can hold. @@ -389,15 +395,15 @@ impl Merkle { } let mut layers = self.layers.lock().unwrap(); let mut layer_size = new_len; - for (layer_i, layer) in layers.iter_mut().enumerate() { + for (layer_i, layer) in layers.data.iter_mut().enumerate() { layer.resize(layer_size, *empty_hash_at(self.ty, layer_i)); layer_size = max(layer_size >> 1, 1); } - let start = layers[0].len(); + let start = layers.data[0].len(); for i in start..new_len { - self.dirty_layers.lock().unwrap()[0].insert(i); + layers.dirt[0].insert(i); } - Ok(layers[0].len()) + Ok(layers.data[0].len()) } }