diff --git a/arbitrator/prover/benches/merkle_bench.rs b/arbitrator/prover/benches/merkle_bench.rs index 98ac09a20e..20ac4eed2c 100644 --- a/arbitrator/prover/benches/merkle_bench.rs +++ b/arbitrator/prover/benches/merkle_bench.rs @@ -1,9 +1,9 @@ use arbutil::Bytes32; use criterion::{criterion_group, criterion_main, Criterion}; -use prover::merkle::{Merkle, MerkleType}; +use prover::merkle::{DirtyMerkle, MerkleType}; use rand::Rng; -fn resize_and_set_leaves(merkle: Merkle, rng: &mut rand::rngs::ThreadRng) { +fn resize_and_set_leaves(merkle: DirtyMerkle, rng: &mut rand::rngs::ThreadRng) { for _ in 0..100 { merkle.resize(merkle.len() + 5).expect("resize failed"); for _ in 0..(merkle.len() / 10) { @@ -27,7 +27,7 @@ fn merkle_benchmark(c: &mut Criterion) { // Perform many calls to set leaves to new values c.bench_function("resize_set_leaves_and_root", |b| { b.iter(|| { - let merkle = Merkle::new_advanced(MerkleType::Memory, leaves.clone(), 20); + let merkle = DirtyMerkle::new_advanced(MerkleType::Memory, leaves.clone(), 20); resize_and_set_leaves(merkle.clone(), &mut rng); }) }); @@ -42,7 +42,7 @@ fn merkle_construction(c: &mut Criterion) { c.bench_function("merkle_construction", |b| { b.iter(|| { - let merkle = Merkle::new_advanced(MerkleType::Memory, leaves.clone(), 21); + let merkle = DirtyMerkle::new_advanced(MerkleType::Memory, leaves.clone(), 21); merkle.root(); }) }); diff --git a/arbitrator/prover/src/machine.rs b/arbitrator/prover/src/machine.rs index 6b0e1df3e1..dfa0be81a4 100644 --- a/arbitrator/prover/src/machine.rs +++ b/arbitrator/prover/src/machine.rs @@ -9,7 +9,7 @@ use crate::{ }, host, memory::Memory, - merkle::{Merkle, MerkleType}, + merkle::{CleanMerkle, DirtyMerkle, MerkleType}, programs::{config::CompileConfig, meter::MeteredMachine, ModuleMod, StylusData}, reinterpret::{ReinterpretAsSigned, ReinterpretAsUnsigned}, utils::{file_bytes, CBytes, RemoteTableType}, @@ -33,7 +33,7 @@ use serde_with::serde_as; use sha3::Keccak256; use smallvec::SmallVec; use std::{ - borrow::Cow, + borrow::{Borrow, Cow}, convert::{TryFrom, TryInto}, fmt::{self, Display}, fs::File, @@ -77,7 +77,7 @@ pub struct Function { pub code: Vec, pub ty: FunctionType, #[serde(skip)] - code_merkle: Merkle, + code_merkle: CleanMerkle, pub local_types: Vec, } @@ -101,7 +101,7 @@ impl Function { insts.push(Instruction { opcode: Opcode::InitFrame, argument_data: 0, - proving_argument_data: Some(Merkle::new(MerkleType::Value, empty_local_hashes).root()), + proving_argument_data: Some(CleanMerkle::new(MerkleType::Value, empty_local_hashes).root()), }); // Fill in parameters for i in (0..func_ty.inputs.len()).rev() { @@ -139,7 +139,7 @@ impl Function { let mut func = Function { code, ty, - code_merkle: Merkle::default(), // TODO: make an option + code_merkle: Default::default(), // TODO: make an option local_types, }; func.set_code_merkle(); @@ -159,7 +159,7 @@ impl Function { #[cfg(not(feature = "rayon"))] let code_hashes = (0..chunks).map(crunch).collect(); - self.code_merkle = Merkle::new(MerkleType::Instruction, code_hashes); + self.code_merkle = CleanMerkle::new(MerkleType::Instruction, code_hashes); } fn serialize_body_for_proof(&self, pc: ProgramCounter) -> Vec { @@ -190,7 +190,7 @@ impl StackFrame { h.update("Stack frame:"); h.update(self.return_ref.hash()); h.update( - Merkle::new( + CleanMerkle::new( MerkleType::Value, self.locals.iter().map(|v| v.hash()).collect(), ) @@ -205,7 +205,7 @@ impl StackFrame { let mut data = Vec::new(); data.extend(self.return_ref.serialize_for_proof()); data.extend( - Merkle::new( + CleanMerkle::new( MerkleType::Value, self.locals.iter().map(|v| v.hash()).collect(), ) @@ -249,7 +249,7 @@ pub(crate) struct Table { pub ty: TableType, pub elems: Vec, #[serde(skip)] - elems_merkle: Merkle, + elems_merkle: CleanMerkle, } impl Table { @@ -289,10 +289,10 @@ pub struct Module { pub(crate) memory: Memory, pub(crate) tables: Vec, #[serde(skip)] - pub(crate) tables_merkle: Merkle, + pub(crate) tables_merkle: CleanMerkle, pub(crate) funcs: Arc>, #[serde(skip)] - pub(crate) funcs_merkle: Arc, + pub(crate) funcs_merkle: Arc, pub(crate) types: Arc>, pub(crate) internals_offset: u32, pub(crate) names: Arc, @@ -511,7 +511,7 @@ impl Module { tables.push(Table { elems: vec![TableElement::default(); usize::try_from(table.initial).unwrap()], ty: *table, - elems_merkle: Merkle::default(), + elems_merkle: Default::default(), }); } @@ -566,9 +566,9 @@ impl Module { Ok(Module { memory, globals: bin.globals.clone(), - tables_merkle: Merkle::new(MerkleType::Table, tables_hashes?), + tables_merkle: CleanMerkle::new(MerkleType::Table, tables_hashes?), tables, - funcs_merkle: Arc::new(Merkle::new( + funcs_merkle: Arc::new(CleanMerkle::new( MerkleType::Function, code.iter().map(|f| f.hash()).collect(), )), @@ -615,7 +615,7 @@ impl Module { let mut h = Keccak256::new(); h.update("Module:"); h.update( - Merkle::new( + CleanMerkle::new( MerkleType::Value, self.globals.iter().map(|v| v.hash()).collect(), ) @@ -629,11 +629,11 @@ impl Module { h.finalize().into() } - fn serialize_for_proof(&self, mem_merkle: &Merkle) -> Vec { + fn serialize_for_proof>>>(&self, mem_merkle: &CleanMerkle) -> Vec { let mut data = Vec::new(); data.extend( - Merkle::new( + CleanMerkle::new( MerkleType::Value, self.globals.iter().map(|v| v.hash()).collect(), ) @@ -682,9 +682,9 @@ pub struct ModuleSerdeAll { globals: Vec, memory: Memory, tables: Vec
, - tables_merkle: Merkle, + tables_merkle: CleanMerkle, funcs: Vec, - funcs_merkle: Arc, + funcs_merkle: Arc, types: Arc>, internals_offset: u32, names: Arc, @@ -747,7 +747,7 @@ impl From<&Module> for ModuleSerdeAll { pub struct FunctionSerdeAll { code: Vec, ty: FunctionType, - code_merkle: Merkle, + code_merkle: CleanMerkle, local_types: Vec, } @@ -957,7 +957,7 @@ pub struct Machine { internal_stack: Vec, frame_stacks: Vec>, modules: Vec, - modules_merkle: Option, + modules_merkle: Option, global_state: GlobalState, pc: ProgramCounter, stdio_output: Vec, @@ -1472,8 +1472,8 @@ impl Machine { globals: Vec::new(), memory: Memory::default(), tables: Vec::new(), - tables_merkle: Merkle::default(), - funcs_merkle: Arc::new(Merkle::new( + tables_merkle: CleanMerkle::default(), + funcs_merkle: Arc::new(CleanMerkle::new( MerkleType::Function, entrypoint_funcs.iter().map(Function::hash).collect(), )), @@ -1498,14 +1498,14 @@ impl Machine { // Merkleize things if requested for module in &mut modules { for table in module.tables.iter_mut() { - table.elems_merkle = Merkle::new( + table.elems_merkle = CleanMerkle::new( MerkleType::TableElement, table.elems.iter().map(TableElement::hash).collect(), ); } let tables_hashes: Result<_, _> = module.tables.iter().map(Table::hash).collect(); - module.tables_merkle = Merkle::new(MerkleType::Table, tables_hashes?); + module.tables_merkle = CleanMerkle::new(MerkleType::Table, tables_hashes?); if always_merkleize { module.memory.cache_merkle_tree(); @@ -1513,7 +1513,7 @@ impl Machine { } let mut modules_merkle = None; if always_merkleize { - modules_merkle = Some(Merkle::new( + modules_merkle = Some(DirtyMerkle::new( MerkleType::Module, modules.iter().map(Module::hash).collect(), )); @@ -1562,18 +1562,18 @@ impl Machine { for module in modules.iter_mut() { for table in module.tables.iter_mut() { - table.elems_merkle = Merkle::new( + table.elems_merkle = CleanMerkle::new( MerkleType::TableElement, table.elems.iter().map(TableElement::hash).collect(), ); } let tables: Result<_> = module.tables.iter().map(Table::hash).collect(); - module.tables_merkle = Merkle::new(MerkleType::Table, tables?); + module.tables_merkle = CleanMerkle::new(MerkleType::Table, tables?); let funcs = Arc::get_mut(&mut module.funcs).expect("Multiple copies of module funcs"); funcs.iter_mut().for_each(Function::set_code_merkle); - module.funcs_merkle = Arc::new(Merkle::new( + module.funcs_merkle = Arc::new(CleanMerkle::new( MerkleType::Function, module.funcs.iter().map(Function::hash).collect(), )); @@ -1581,9 +1581,9 @@ impl Machine { module.memory.cache_merkle_tree(); } } - let mut modules_merkle: Option = None; + let mut modules_merkle: Option = None; if always_merkleize { - modules_merkle = Some(Merkle::new( + modules_merkle = Some(DirtyMerkle::new( MerkleType::Module, modules.iter().map(Module::hash).collect(), )); @@ -1692,7 +1692,7 @@ impl Machine { for module in &mut self.modules { module.memory.cache_merkle_tree(); } - self.modules_merkle = Some(Merkle::new( + self.modules_merkle = Some(DirtyMerkle::new( MerkleType::Module, self.modules.iter().map(Module::hash).collect(), )); @@ -2695,14 +2695,14 @@ impl Machine { self.status } - fn get_modules_merkle(&self) -> Cow { + fn get_modules_merkle(&self) -> CleanMerkle>>> { if let Some(merkle) = &self.modules_merkle { - Cow::Borrowed(merkle) + merkle.clean().to_cow() } else { - Cow::Owned(Merkle::new( + CleanMerkle::new( MerkleType::Module, self.modules.iter().map(Module::hash).collect(), - )) + ).to_cow() } } @@ -2914,14 +2914,14 @@ impl Machine { let idx = arg as usize; out!(locals[idx].serialize_for_proof()); let merkle = - Merkle::new(MerkleType::Value, locals.iter().map(|v| v.hash()).collect()); + CleanMerkle::new(MerkleType::Value, locals.iter().map(|v| v.hash()).collect()); out!(merkle.prove(idx).expect("Out of bounds local access")); } GlobalGet | GlobalSet => { let idx = arg as usize; out!(module.globals[idx].serialize_for_proof()); let globals_merkle = module.globals.iter().map(|v| v.hash()).collect(); - let merkle = Merkle::new(MerkleType::Value, globals_merkle); + let merkle = CleanMerkle::new(MerkleType::Value, globals_merkle); out!(merkle.prove(idx).expect("Out of bounds global access")); } MemoryLoad { .. } | MemoryStore { .. } => { @@ -2958,9 +2958,9 @@ impl Machine { copy.modules[self.pc.module()] .memory .merkelize() - .into_owned() + .to_owned() } else { - mem_merkle.into_owned() + mem_merkle.to_owned() }; out!(second_mem_merkle.prove(next_leaf_idx).unwrap_or_default()); } diff --git a/arbitrator/prover/src/memory.rs b/arbitrator/prover/src/memory.rs index bba8e4124f..13781c5099 100644 --- a/arbitrator/prover/src/memory.rs +++ b/arbitrator/prover/src/memory.rs @@ -2,7 +2,7 @@ // For license information, see https://github.com/nitro/blob/master/LICENSE use crate::{ - merkle::{Merkle, MerkleType}, + merkle::{CleanMerkle, DirtyMerkle, MerkleType}, value::{ArbValueType, Value}, }; use arbutil::Bytes32; @@ -44,14 +44,22 @@ impl TryFrom<&wasmparser::MemoryType> for MemoryType { } } -#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct Memory { buffer: Vec, #[serde(skip)] - pub merkle: Option, + pub merkle: Option, pub max_size: u64, } +impl PartialEq for Memory { + fn eq(&self, other: &Self) -> bool { + self.buffer == other.buffer && self.max_size == other.max_size + } +} + +impl Eq for Memory {} + fn hash_leaf(bytes: [u8; Memory::LEAF_SIZE]) -> Bytes32 { let mut h = Keccak256::new(); h.update("Memory leaf:"); @@ -98,10 +106,7 @@ impl Memory { self.buffer.len() as u64 } - pub fn merkelize(&self) -> Cow<'_, Merkle> { - if let Some(m) = &self.merkle { - return Cow::Borrowed(m); - } + fn make_dirty_merkle(&self) -> DirtyMerkle { // Round the size up to 8 byte long leaves, then round up to the next power of two number of leaves let leaves = round_up_to_power_of_two(div_round_up(self.buffer.len(), Self::LEAF_SIZE)); @@ -119,11 +124,18 @@ impl Memory { }) .collect(); let size = leaf_hashes.len(); - let m = Merkle::new_advanced(MerkleType::Memory, leaf_hashes, Self::MEMORY_LAYERS); + let mut m = DirtyMerkle::new_advanced(MerkleType::Memory, leaf_hashes, Self::MEMORY_LAYERS); if size < leaves { m.resize(leaves).expect("Couldn't resize merkle tree"); } - Cow::Owned(m) + m + } + + pub fn merkelize(&self) -> CleanMerkle>>> { + if let Some(merkle) = &self.merkle { + return merkle.clean().to_cow(); + } + self.make_dirty_merkle().into_clean().to_cow() } pub fn get_leaf_data(&self, leaf_idx: usize) -> [u8; Self::LEAF_SIZE] { @@ -230,7 +242,7 @@ impl Memory { let buf = value.to_le_bytes(); self.buffer[idx..end_idx].copy_from_slice(&buf[..bytes.into()]); - if let Some(merkle) = self.merkle.take() { + if let Some(mut merkle) = self.merkle.take() { let start_leaf = idx / Self::LEAF_SIZE; merkle.set(start_leaf, hash_leaf(self.get_leaf_data(start_leaf))); let end_leaf = (end_idx - 1) / Self::LEAF_SIZE; @@ -258,12 +270,11 @@ impl Memory { let end_idx = end_idx as usize; self.buffer[idx..end_idx].copy_from_slice(value); - if let Some(merkle) = self.merkle.take() { + if let Some(mut merkle) = self.merkle.take() { let start_leaf = idx / Self::LEAF_SIZE; merkle.set(start_leaf, hash_leaf(self.get_leaf_data(start_leaf))); // No need for second merkle assert!(value.len() <= Self::LEAF_SIZE); - self.merkle = Some(merkle); } true @@ -302,12 +313,14 @@ impl Memory { } pub fn cache_merkle_tree(&mut self) { - self.merkle = Some(self.merkelize().into_owned()); + if self.merkle.is_none() { + self.merkle = Some(self.make_dirty_merkle()); + } } pub fn resize(&mut self, new_size: usize) { self.buffer.resize(new_size, 0); - if let Some(merkle) = self.merkle.take() { + if let Some(mut merkle) = self.merkle.take() { merkle .resize(new_size) .expect("Couldn't resize merkle tree"); @@ -330,7 +343,7 @@ mod test { 86u8, 177, 192, 60, 217, 123, 221, 153, 118, 79, 229, 122, 210, 48, 187, 104, 40, 84, 112, 63, 137, 86, 54, 2, 56, 118, 72, 158, 242, 225, 65, 80, ]); - let memory = Memory::new(65536, 1); + let mut memory = Memory::new(65536, 1); assert_eq!(memory.hash(), module_memory_hash); } diff --git a/arbitrator/prover/src/merkle.rs b/arbitrator/prover/src/merkle.rs index 5a3cbadbc7..4a61b9b153 100644 --- a/arbitrator/prover/src/merkle.rs +++ b/arbitrator/prover/src/merkle.rs @@ -9,9 +9,15 @@ use enum_iterator::Sequence; #[cfg(feature = "counters")] use enum_iterator::all; use itertools::Itertools; +use parking_lot::Once; +use std::borrow::Borrow; +use std::borrow::Cow; +use std::cell::UnsafeCell; use std::cmp::max; +use std::ops::Deref; +use std::sync::atomic::AtomicBool; #[cfg(feature = "counters")] use std::sync::atomic::AtomicUsize; @@ -30,8 +36,8 @@ use sha3::Keccak256; use std::{ collections::HashSet, convert::{TryFrom, TryInto}, - sync::{Arc, Mutex}, }; +use parking_lot::Mutex; #[cfg(feature = "rayon")] use rayon::prelude::*; @@ -149,6 +155,42 @@ impl MerkleType { } } +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct DirtyMerkleData { + layers: Vec>, + dirty_layers: Vec>, +} + +impl DirtyMerkleData { + fn rehash(&mut self, ty: MerkleType) { + if self.dirty_layers.is_empty() || self.dirty_layers[0].is_empty() { + return; + } + for layer_i in 1..self.layers.len() { + let dirty_i = layer_i - 1; + let dirt = self.dirty_layers[dirty_i].clone(); + for idx in dirt.iter().sorted() { + let left_child_idx = idx << 1; + let right_child_idx = left_child_idx + 1; + let left = self.layers[layer_i - 1][left_child_idx]; + let right = self.layers[layer_i - 1] + .get(right_child_idx) + .unwrap_or(empty_hash_at(ty, layer_i - 1)); + let new_hash = hash_node(ty, left, right); + if *idx < self.layers[layer_i].len() { + self.layers[layer_i][*idx] = new_hash; + } else { + self.layers[layer_i].push(new_hash); + } + if layer_i < self.layers.len() - 1 { + self.dirty_layers[dirty_i + 1].insert(idx >> 1); + } + } + self.dirty_layers[dirty_i].clear(); + } + } +} + /// A Merkle tree with a fixed number of layers /// /// https://en.wikipedia.org/wiki/Merkle_tree @@ -164,14 +206,40 @@ impl MerkleType { /// and passing a minimum depth. /// /// This structure does not contain the data itself, only the hashes. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct Merkle { +#[derive(Debug, Default)] +pub struct DirtyMerkle { ty: MerkleType, - #[serde(with = "arc_mutex_sedre")] - layers: Arc>>>, min_depth: usize, - #[serde(with = "arc_mutex_sedre")] - dirty_layers: Arc>>>, + clean: Once, + data: UnsafeCell, +} + +fn done_once() -> Once { + let once = Once::new(); + once.call_once(|| {}); + once +} + +impl Clone for DirtyMerkle { + fn clone(&self) -> Self { + self.rehash(); + // SAFETY: It's safe to read data with an immutable reference after a rehash + let data = unsafe { + (*self.data.get()).clone() + }; + DirtyMerkle { + ty: self.ty, + min_depth: self.min_depth, + clean: done_once(), + data: UnsafeCell::new(data), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct CleanMerkle>> = Vec>> { + ty: MerkleType, + layers: L, } fn hash_node(ty: MerkleType, a: impl AsRef<[u8]>, b: impl AsRef<[u8]>) -> Bytes32 { @@ -216,143 +284,156 @@ fn new_layer(ty: MerkleType, layer: &Vec, empty_hash: &'static Bytes32) new_layer } -impl Merkle { - /// Creates a new Merkle tree with the given type and leaf hashes. - /// The tree is built up to the minimum depth necessary to hold all the - /// leaves. - pub fn new(ty: MerkleType, hashes: Vec) -> Merkle { +impl CleanMerkle>> { + pub fn new(ty: MerkleType, hashes: Vec) -> CleanMerkle { Self::new_advanced(ty, hashes, 0) } - /// Creates a new Merkle tree with the given type, leaf hashes, a hash to - /// use for representing empty leaves, and a minimum depth. - pub fn new_advanced(ty: MerkleType, hashes: Vec, min_depth: usize) -> Merkle { - #[cfg(feature = "counters")] - NEW_COUNTERS[&ty].fetch_add(1, Ordering::Relaxed); + pub fn new_advanced(ty: MerkleType, hashes: Vec, min_depth: usize) -> CleanMerkle { if hashes.is_empty() && min_depth == 0 { - return Merkle::default(); + return CleanMerkle::default(); } let mut depth = (hashes.len() as f64).log2().ceil() as usize; depth = depth.max(min_depth); let mut layers: Vec> = Vec::with_capacity(depth); layers.push(hashes); - let mut dirty_indices: Vec> = Vec::with_capacity(depth); let mut layer_i = 0usize; while layers.last().unwrap().len() > 1 || layers.len() < min_depth { let layer = layers.last().unwrap(); let empty_hash = empty_hash_at(ty, layer_i); let new_layer = new_layer(ty, layer, empty_hash); - dirty_indices.push(HashSet::with_capacity(new_layer.len())); layers.push(new_layer); layer_i += 1; } - let dirty_layers = Arc::new(Mutex::new(dirty_indices)); - Merkle { + CleanMerkle { ty, layers } + } + + pub fn to_cow(self) -> CleanMerkle>>> { + CleanMerkle { + ty: self.ty, + layers: Cow::Owned(self.layers), + } + } +} + +impl<'a> CleanMerkle<&'a Vec>> { + pub fn to_cow(self) -> CleanMerkle>>> { + CleanMerkle { + ty: self.ty, + layers: Cow::Borrowed(self.layers), + } + } +} + +impl<'a> CleanMerkle>>> { + pub fn to_owned(self) -> CleanMerkle>> { + CleanMerkle { + ty: self.ty, + layers: self.layers.into_owned(), + } + } +} + +impl DirtyMerkle { + /// Creates a new Merkle tree with the given type and leaf hashes. + /// The tree is built up to the minimum depth necessary to hold all the + /// leaves. + pub fn new(ty: MerkleType, hashes: Vec) -> DirtyMerkle { + Self::new_advanced(ty, hashes, 0) + } + + /// Creates a new Merkle tree with the given type, leaf hashes, a hash to + /// use for representing empty leaves, and a minimum depth. + pub fn new_advanced(ty: MerkleType, hashes: Vec, min_depth: usize) -> DirtyMerkle { + #[cfg(feature = "counters")] + NEW_COUNTERS[&ty].fetch_add(1, Ordering::Relaxed); + let clean = CleanMerkle::new_advanced(ty, hashes, min_depth); + let dirty_layers = clean + .layers + .iter() + .map(|layer| HashSet::with_capacity(layer.len())) + .collect(); + DirtyMerkle { ty, - layers: Arc::new(Mutex::new(layers)), min_depth, - dirty_layers, + data: UnsafeCell::new(DirtyMerkleData { + layers: clean.layers, + dirty_layers, + }), + clean: done_once(), } } fn rehash(&self) { - let dirty_layers = &mut self.dirty_layers.lock().unwrap(); - if dirty_layers.is_empty() || dirty_layers[0].is_empty() { - return; + self.clean.call_once(|| { + // SAFETY: We have an immutable reference and are in a Once, so nothing else can mutate this + let data = unsafe { &mut *self.data.get() }; + data.rehash(self.ty); + }) + } + + pub fn clean(&self) -> CleanMerkle<&'_ Vec>> { + self.rehash(); + // SAFETY: It's safe to read data with an immutable reference after a rehash + let data = unsafe { &*self.data.get() }; + CleanMerkle { + ty: self.ty, + layers: data.layers.borrow(), } - let layers = &mut self.layers.lock().unwrap(); - for layer_i in 1..layers.len() { - let dirty_i = layer_i - 1; - let dirt = dirty_layers[dirty_i].clone(); - for idx in dirt.iter().sorted() { - let left_child_idx = idx << 1; - let right_child_idx = left_child_idx + 1; - let left = layers[layer_i - 1][left_child_idx]; - let right = layers[layer_i - 1] - .get(right_child_idx) - .unwrap_or(empty_hash_at(self.ty, layer_i - 1)); - let new_hash = hash_node(self.ty, left, right); - if *idx < layers[layer_i].len() { - layers[layer_i][*idx] = new_hash; - } else { - layers[layer_i].push(new_hash); - } - if layer_i < layers.len() - 1 { - dirty_layers[dirty_i + 1].insert(idx >> 1); - } - } - dirty_layers[dirty_i].clear(); + } + + pub fn into_clean(self) -> CleanMerkle { + self.rehash(); + CleanMerkle { + ty: self.ty, + layers: self.data.into_inner().layers, } } pub fn root(&self) -> Bytes32 { #[cfg(feature = "counters")] ROOT_COUNTERS[&self.ty].fetch_add(1, Ordering::Relaxed); - self.rehash(); - if let Some(layer) = self.layers.lock().unwrap().last() { - assert_eq!(layer.len(), 1); - layer[0] - } else { - Bytes32::default() - } + self.clean().root() } // Returns the total number of leaves the tree can hold. #[inline] fn capacity(&self) -> usize { - let layers = self.layers.lock().unwrap(); - if layers.is_empty() { + let clean = self.clean(); + if clean.layers.is_empty() { return 0; } let base: usize = 2; - base.pow((layers.len() - 1).try_into().unwrap()) + base.pow((clean.layers.len() - 1).try_into().unwrap()) } // Returns the number of leaves in the tree. pub fn len(&self) -> usize { - self.layers.lock().unwrap()[0].len() + let clean = self.clean(); + clean.layers[0].len() } pub fn is_empty(&self) -> bool { - let layers = self.layers.lock().unwrap(); - layers.is_empty() || layers[0].is_empty() + let clean = self.clean(); + clean.layers.is_empty() || clean.layers[0].is_empty() } #[must_use] - pub fn prove(&self, idx: usize) -> Option> { - if self.layers.lock().unwrap().is_empty() || idx >= self.layers.lock().unwrap()[0].len() { - return None; - } - Some(self.prove_any(idx)) + pub fn prove(&mut self, idx: usize) -> Option> { + self.clean().prove(idx) } /// creates a merkle proof regardless of if the leaf has content #[must_use] - pub fn prove_any(&self, mut idx: usize) -> Vec { - self.rehash(); - let layers = self.layers.lock().unwrap(); - let mut proof = vec![u8::try_from(layers.len() - 1).unwrap()]; - for (layer_i, layer) in layers.iter().enumerate() { - if layer_i == layers.len() - 1 { - break; - } - let counterpart = idx ^ 1; - proof.extend( - layer - .get(counterpart) - .cloned() - .unwrap_or_else(|| *empty_hash_at(self.ty, layer_i)), - ); - idx >>= 1; - } - proof + pub fn prove_any(&mut self, idx: usize) -> Vec { + self.clean().prove_any(idx) } /// Adds a new leaf to the merkle /// Currently O(n) in the number of leaves (could be log(n)) pub fn push_leaf(&mut self, leaf: Bytes32) { - let mut leaves = self.layers.lock().unwrap().swap_remove(0); + let mut leaves = self.data.get_mut().layers.swap_remove(0); leaves.push(leaf); *self = Self::new_advanced(self.ty, leaves, self.min_depth); } @@ -360,58 +441,101 @@ impl Merkle { /// Removes the rightmost leaf from the merkle /// Currently O(n) in the number of leaves (could be log(n)) pub fn pop_leaf(&mut self) { - let mut leaves = self.layers.lock().unwrap().swap_remove(0); + let mut leaves = self.data.get_mut().layers.swap_remove(0); leaves.pop(); *self = Self::new_advanced(self.ty, leaves, self.min_depth); } // Sets the leaf at the given index to the given hash. // Panics if the index is out of bounds (since the structure doesn't grow). - pub fn set(&self, idx: usize, hash: Bytes32) { + pub fn set(&mut self, idx: usize, hash: Bytes32) { #[cfg(feature = "counters")] SET_COUNTERS[&self.ty].fetch_add(1, Ordering::Relaxed); - let mut layers = self.layers.lock().unwrap(); - if layers[0][idx] == hash { + // This dirties the merkle tree + self.clean = Once::new(); + let data = self.data.get_mut(); + if data.layers[0][idx] == hash { return; } - layers[0][idx] = hash; - self.dirty_layers.lock().unwrap()[0].insert(idx >> 1); + data.layers[0][idx] = hash; + data.dirty_layers[0].insert(idx >> 1); } /// Resizes the number of leaves the tree can hold. /// /// The extra space is filled with empty hashes. - pub fn resize(&self, new_len: usize) -> Result { + pub fn resize(&mut self, new_len: usize) -> Result { if new_len > self.capacity() { return Err( "Cannot resize to a length greater than the capacity of the tree.".to_owned(), ); } - let mut layers = self.layers.lock().unwrap(); let mut layer_size = new_len; - for (layer_i, layer) in layers.iter_mut().enumerate() { + // This dirties the merkle tree + self.clean = Once::new(); + let data = self.data.get_mut(); + for (layer_i, layer) in data.layers.iter_mut().enumerate() { layer.resize(layer_size, *empty_hash_at(self.ty, layer_i)); layer_size = max(layer_size >> 1, 1); } - let start = layers[0].len(); + let start = data.layers[0].len(); for i in start..new_len { - self.dirty_layers.lock().unwrap()[0].insert(i); + data.dirty_layers[0].insert(i); + } + Ok(data.layers[0].len()) + } +} + +impl>>> CleanMerkle { + pub fn root(&self) -> Bytes32 { + let layers = self.layers.borrow(); + if let Some(layer) = layers.last() { + assert_eq!(layer.len(), 1); + layer[0] + } else { + Bytes32::default() + } + } + + pub fn prove_any(&self, mut idx: usize) -> Vec { + let layers = self.layers.borrow(); + let mut proof = vec![u8::try_from(layers.len() - 1).unwrap()]; + for (layer_i, layer) in layers.iter().enumerate() { + if layer_i == layers.len() - 1 { + break; + } + let counterpart = idx ^ 1; + proof.extend( + layer + .get(counterpart) + .cloned() + .unwrap_or_else(|| *empty_hash_at(self.ty, layer_i)), + ); + idx >>= 1; + } + proof + } + + pub fn prove(&self, idx: usize) -> Option> { + let layers = self.layers.borrow(); + if layers.is_empty() || idx >= layers[0].len() { + return None; } - Ok(layers[0].len()) + Some(self.prove_any(idx)) } } -impl PartialEq for Merkle { +impl>>> PartialEq for CleanMerkle { fn eq(&self, other: &Self) -> bool { self.root() == other.root() } } -impl Eq for Merkle {} +impl>>> Eq for CleanMerkle {} -pub mod arc_mutex_sedre { +pub mod mutex_sedre { pub fn serialize( - data: &std::sync::Arc>, + data: &std::sync::Mutex, serializer: S, ) -> Result where @@ -423,14 +547,14 @@ pub mod arc_mutex_sedre { pub fn deserialize<'de, D, T>( deserializer: D, - ) -> Result>, D::Error> + ) -> Result, D::Error> where D: serde::Deserializer<'de>, T: serde::Deserialize<'de>, { - Ok(std::sync::Arc::new(std::sync::Mutex::new(T::deserialize( + Ok(std::sync::Mutex::new(T::deserialize( deserializer, - )?))) + )?)) } } @@ -472,7 +596,7 @@ fn resize_works() { ), ), ); - let merkle = Merkle::new(MerkleType::Value, hashes.clone()); + let mut merkle = DirtyMerkle::new(MerkleType::Value, hashes.clone()); assert_eq!(merkle.capacity(), 8); assert_eq!(merkle.root(), expected); @@ -516,9 +640,9 @@ fn resize_works() { #[test] fn correct_capacity() { - let merkle = Merkle::new(MerkleType::Value, vec![Bytes32::from([1; 32])]); + let merkle = DirtyMerkle::new(MerkleType::Value, vec![Bytes32::from([1; 32])]); assert_eq!(merkle.capacity(), 1); - let merkle = Merkle::new_advanced(MerkleType::Memory, vec![Bytes32::from([1; 32])], 11); + let merkle = DirtyMerkle::new_advanced(MerkleType::Memory, vec![Bytes32::from([1; 32])], 11); assert_eq!(merkle.capacity(), 1024); } @@ -546,18 +670,16 @@ fn emit_memory_zerohashes() { #[test] fn serialization_roundtrip() { - let merkle = Merkle::new_advanced(MerkleType::Value, vec![Bytes32::from([1; 32])], 4); - merkle.resize(4).expect("resize failed"); - merkle.set(3, Bytes32::from([2; 32])); + let mut merkle = CleanMerkle::new_advanced(MerkleType::Value, vec![Bytes32::from([1; 32])], 4); let serialized = bincode::serialize(&merkle).unwrap(); - let deserialized: Merkle = bincode::deserialize(&serialized).unwrap(); + let mut deserialized: CleanMerkle = bincode::deserialize(&serialized).unwrap(); assert_eq!(merkle, deserialized); } #[test] #[should_panic(expected = "index out of bounds")] fn set_with_bad_index_panics() { - let merkle = Merkle::new( + let mut merkle = DirtyMerkle::new( MerkleType::Value, vec![Bytes32::default(), Bytes32::default()], );