diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index cabd9c7e0669..491ecf6f3baf 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -5,11 +5,12 @@ use alloy_primitives::{ }; use alloy_rlp::{BufMut, Encodable}; use itertools::Itertools; +use rayon::iter::{ParallelBridge, ParallelIterator}; use reth_db::DatabaseError; use reth_execution_errors::StorageRootError; use reth_provider::{ providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError, - StateCommitmentProvider, + ProviderResult, StateCommitmentProvider, }; use reth_trie::{ hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory}, @@ -25,7 +26,7 @@ use reth_trie::{ use reth_trie_common::proof::ProofRetainer; use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; use std::sync::Arc; -use tracing::{debug, error}; +use tracing::debug; #[cfg(feature = "metrics")] use crate::metrics::ParallelStateRootMetrics; @@ -112,36 +113,31 @@ where prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())), prefix_sets.storage_prefix_sets.clone(), ); + let storage_root_targets_len = storage_root_targets.len(); // Pre-calculate storage roots for accounts which were changed. - tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64); + tracker.set_precomputed_storage_roots(storage_root_targets_len as u64); debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-generating storage proofs"); - let mut storage_proofs = - B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default()); - for (hashed_address, prefix_set) in - storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address) - { - let view = self.view.clone(); - let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default(); - - let trie_nodes_sorted = self.nodes_sorted.clone(); - let hashed_state_sorted = self.state_sorted.clone(); - - let (tx, rx) = std::sync::mpsc::sync_channel(1); + let mut storage_proofs = storage_root_targets + .into_iter() + .sorted_unstable_by_key(|(address, _)| *address) + .par_bridge() + .map_init( + || (self.view.clone(), self.nodes_sorted.clone(), self.state_sorted.clone()), + |(view, trie_nodes_sorted, hashed_state_sorted), (hashed_address, prefix_set)| { + let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default(); - rayon::spawn_fifo(move || { - let result = (|| -> Result<_, ParallelStateRootError> { let provider_ro = view.provider_ro()?; let trie_cursor_factory = InMemoryTrieCursorFactory::new( DatabaseTrieCursorFactory::new(provider_ro.tx_ref()), - &trie_nodes_sorted, + trie_nodes_sorted, ); let hashed_cursor_factory = HashedPostStateCursorFactory::new( DatabaseHashedCursorFactory::new(provider_ro.tx_ref()), - &hashed_state_sorted, + hashed_state_sorted, ); - StorageProof::new_hashed( + let result = StorageProof::new_hashed( trie_cursor_factory, hashed_cursor_factory, hashed_address, @@ -149,14 +145,36 @@ where .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned())) .with_branch_node_hash_masks(self.collect_branch_node_hash_masks) .storage_multiproof(target_slots) - .map_err(|e| ParallelStateRootError::Other(e.to_string())) - })(); - if let Err(err) = tx.send(result) { - error!(target: "trie::parallel", ?hashed_address, err_content = ?err.0, "Failed to send proof result"); - } - }); - storage_proofs.insert(hashed_address, rx); - } + .map_err(|e| ParallelStateRootError::Other(e.to_string())); + + ProviderResult::Ok((hashed_address, result)) + }, + ) + .try_fold(B256HashMap::default, |mut acc, result| { + let (hashed_address, result) = result?; + + acc.insert(hashed_address, result); + ProviderResult::Ok(acc) + }) + .reduce( + || { + Ok(B256HashMap::with_capacity_and_hasher( + storage_root_targets_len, + Default::default(), + )) + }, + |m1, m2| { + let mut m1 = m1?; + let m2 = m2?; + m1.extend(m2); + Ok(m1) + }, + ) + .map_err(|err| { + ParallelStateRootError::StorageRoot(StorageRootError::Database( + DatabaseError::Other(format!("{err:?}")), + )) + })?; let provider_ro = self.view.provider_ro()?; let trie_cursor_factory = InMemoryTrieCursorFactory::new( @@ -199,13 +217,7 @@ where } TrieElement::Leaf(hashed_address, account) => { let storage_multiproof = match storage_proofs.remove(&hashed_address) { - Some(rx) => rx.recv().map_err(|_| { - ParallelStateRootError::StorageRoot(StorageRootError::Database( - DatabaseError::Other(format!( - "channel closed for {hashed_address}" - )), - )) - })??, + Some(result) => result?, // Since we do not store all intermediate nodes in the database, there might // be a possibility of re-adding a non-modified leaf to the hash builder. None => {