diff --git a/Cargo.lock b/Cargo.lock index a44088f901a9..bdfbd5680f4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9634,7 +9634,6 @@ dependencies = [ "alloy-rlp", "arbitrary", "assert_matches", - "auto_impl", "criterion", "itertools 0.13.0", "pretty_assertions", diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index 345f296458e4..7254cc882a7e 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -20,13 +20,13 @@ use reth_trie_sparse::{ }; use revm_primitives::{keccak256, EvmState, B256}; use std::{ - self, collections::BTreeMap, ops::Deref, sync::{ mpsc::{self, channel, Receiver, Sender}, Arc, }, + thread::{self}, time::{Duration, Instant}, }; use tracing::{debug, error, trace}; @@ -70,7 +70,7 @@ pub struct StateRootConfig { /// Messages used internally by the state root task #[derive(Debug)] #[allow(dead_code)] -pub enum StateRootMessage { +pub enum StateRootMessage { /// New state update from transaction execution StateUpdate(EvmState), /// Proof calculation completed for a specific state update @@ -85,7 +85,7 @@ pub enum StateRootMessage { /// State root calculation completed RootCalculated { /// The updated sparse trie - trie: BoxSparseStateTrie, + trie: Box>, /// Time taken to calculate the root elapsed: Duration, }, @@ -161,24 +161,24 @@ impl ProofSequencer { /// A wrapper for the sender that signals completion when dropped #[allow(dead_code)] -pub(crate) struct StateHookSender(Sender); +pub(crate) struct StateHookSender(Sender>); #[allow(dead_code)] -impl StateHookSender { - pub(crate) const fn new(inner: Sender) -> Self { +impl StateHookSender { + pub(crate) const fn new(inner: Sender>) -> Self { Self(inner) } } -impl Deref for StateHookSender { - type Target = Sender; +impl Deref for StateHookSender { + type Target = Sender>; fn deref(&self) -> &Self::Target { &self.0 } } -impl Drop for StateHookSender { +impl Drop for StateHookSender { fn drop(&mut self) { // Send completion signal when the sender is dropped let _ = self.0.send(StateRootMessage::FinishedStateUpdates); @@ -217,15 +217,6 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState { hashed_state } -type BoxBlindedProviderFactory = Box< - dyn BlindedProviderFactory< - AccountNodeProvider = Box + Send + Sync>, - StorageNodeProvider = Box + Send + Sync>, - > + Send - + Sync, ->; -type BoxSparseStateTrie = Box>; - /// Standalone task that receives a transaction state stream and updates relevant /// data structures to calculate state root. /// @@ -235,24 +226,24 @@ type BoxSparseStateTrie = Box>; /// to the tree. /// Then it updates relevant leaves according to the result of the transaction. #[derive(Debug)] -pub struct StateRootTask { +pub struct StateRootTask { /// Task configuration. config: StateRootConfig, /// Receiver for state root related messages. - rx: Receiver, + rx: Receiver>, /// Sender for state root related messages. - tx: Sender, + tx: Sender>, /// Proof targets that have been already fetched. fetched_proof_targets: MultiProofTargets, /// Proof sequencing handler. proof_sequencer: ProofSequencer, /// The sparse trie used for the state root calculation. If [`None`], then update is in /// progress. - sparse_trie: Option, + sparse_trie: Option>>, } #[allow(dead_code)] -impl StateRootTask +impl<'env, Factory, ABP, SBP, BPF> StateRootTask where Factory: DatabaseProviderFactory + StateCommitmentProvider @@ -260,12 +251,15 @@ where + Send + Sync + 'static, + ABP: BlindedProvider + Send + Sync + 'env, + SBP: BlindedProvider + Send + Sync + 'env, + BPF: BlindedProviderFactory + + Send + + Sync + + 'env, { /// Creates a new state root task with the unified message channel - pub fn new( - config: StateRootConfig, - blinded_provider: BoxBlindedProviderFactory, - ) -> Self { + pub fn new(config: StateRootConfig, blinded_provider: BPF) -> Self { let (tx, rx) = channel(); Self { @@ -279,13 +273,14 @@ where } /// Spawns the state root task and returns a handle to await its result. - pub fn spawn(self) -> StateRootHandle { + pub fn spawn<'scope>(self, scope: &'scope thread::Scope<'scope, 'env>) -> StateRootHandle { let (tx, rx) = mpsc::sync_channel(1); std::thread::Builder::new() .name("State Root Task".to_string()) - .spawn(move || { + .spawn_scoped(scope, move || { debug!(target: "engine::tree", "Starting state root task"); - let result = self.run(); + + let result = rayon::scope(|scope| self.run(scope)); let _ = tx.send(result); }) .expect("failed to spawn state root thread"); @@ -308,12 +303,13 @@ where /// /// Returns proof targets derived from the state update. fn on_state_update( + scope: &rayon::Scope<'env>, view: ConsistentDbView, input: Arc, update: EvmState, fetched_proof_targets: &mut MultiProofTargets, proof_sequence_number: u64, - state_root_message_sender: Sender, + state_root_message_sender: Sender>, ) { let hashed_state_update = evm_state_to_hashed_post_state(update); @@ -323,7 +319,7 @@ where } // Dispatch proof gathering for this state update - rayon::spawn(move || { + scope.spawn(move |_| { let provider = match view.provider_ro() { Ok(provider) => provider, Err(error) => { @@ -376,7 +372,12 @@ where } /// Spawns root calculation with the current state and proofs. - fn spawn_root_calculation(&mut self, state: HashedPostState, multiproof: MultiProof) { + fn spawn_root_calculation( + &mut self, + scope: &rayon::Scope<'env>, + state: HashedPostState, + multiproof: MultiProof, + ) { let Some(trie) = self.sparse_trie.take() else { return }; trace!( @@ -390,7 +391,7 @@ where let targets = get_proof_targets(&state, &HashMap::default()); let tx = self.tx.clone(); - rayon::spawn(move || { + scope.spawn(move |_| { let result = update_sparse_trie(trie, multiproof, targets, state); match result { Ok((trie, elapsed)) => { @@ -408,7 +409,7 @@ where }); } - fn run(mut self) -> StateRootResult { + fn run(mut self, scope: &rayon::Scope<'env>) -> StateRootResult { let mut current_state_update = HashedPostState::default(); let mut current_multiproof = MultiProof::default(); let mut updates_received = 0; @@ -428,6 +429,7 @@ where "Received new state update" ); Self::on_state_update( + scope, self.config.consistent_view.clone(), self.config.input.clone(), update, @@ -457,7 +459,11 @@ where current_multiproof.extend(combined_proof); current_state_update.extend(combined_state_update); } else { - self.spawn_root_calculation(combined_state_update, combined_proof); + self.spawn_root_calculation( + scope, + combined_state_update, + combined_proof, + ); } } } @@ -495,6 +501,7 @@ where "Spawning subsequent root calculation" ); self.spawn_root_calculation( + scope, std::mem::take(&mut current_state_update), std::mem::take(&mut current_multiproof), ); @@ -569,12 +576,16 @@ fn get_proof_targets( /// Updates the sparse trie with the given proofs and state, and returns the updated trie and the /// time it took. -fn update_sparse_trie( - mut trie: BoxSparseStateTrie, +fn update_sparse_trie< + ABP: BlindedProvider + Send + Sync, + SBP: BlindedProvider + Send + Sync, + BPF: BlindedProviderFactory + Send + Sync, +>( + mut trie: Box>, multiproof: MultiProof, targets: MultiProofTargets, state: HashedPostState, -) -> SparseStateTrieResult<(BoxSparseStateTrie, Duration)> { +) -> SparseStateTrieResult<(Box>, Duration)> { trace!(target: "engine::root::sparse", "Updating sparse trie"); let started_at = Instant::now(); @@ -775,18 +786,19 @@ mod tests { &state_sorted, ), Arc::new(config.input.prefix_sets.clone()), - ) - .boxed(); - let task = StateRootTask::new(config, blinded_provider_factory); - let mut state_hook = task.state_hook(); - let handle = task.spawn(); - - for update in state_updates { - state_hook.on_state(&update); - } - drop(state_hook); + ); + let (root_from_task, _) = std::thread::scope(|std_scope| { + let task = StateRootTask::new(config, blinded_provider_factory); + let mut state_hook = task.state_hook(); + let handle = task.spawn(std_scope); - let (root_from_task, _) = handle.wait_for_result().expect("task failed"); + for update in state_updates { + state_hook.on_state(&update); + } + drop(state_hook); + + handle.wait_for_result().expect("task failed") + }); let root_from_base = state_root(accumulated_state); assert_eq!( diff --git a/crates/trie/sparse/Cargo.toml b/crates/trie/sparse/Cargo.toml index 863bb9e63482..205451ef72a8 100644 --- a/crates/trie/sparse/Cargo.toml +++ b/crates/trie/sparse/Cargo.toml @@ -23,7 +23,6 @@ alloy-primitives.workspace = true alloy-rlp.workspace = true # misc -auto_impl.workspace = true smallvec = { workspace = true, features = ["const_new"] } thiserror.workspace = true diff --git a/crates/trie/sparse/src/blinded.rs b/crates/trie/sparse/src/blinded.rs index 097d7fdd8bb7..22471cf99ffd 100644 --- a/crates/trie/sparse/src/blinded.rs +++ b/crates/trie/sparse/src/blinded.rs @@ -5,7 +5,6 @@ use reth_execution_errors::SparseTrieError; use reth_trie_common::Nibbles; /// Factory for instantiating blinded node providers. -#[auto_impl::auto_impl(Box)] pub trait BlindedProviderFactory { /// Type capable of fetching blinded account nodes. type AccountNodeProvider: BlindedProvider; @@ -20,7 +19,6 @@ pub trait BlindedProviderFactory { } /// Trie node provider for retrieving blinded nodes. -#[auto_impl::auto_impl(Box)] pub trait BlindedProvider { /// The error type for the provider. type Error: Into; diff --git a/crates/trie/trie/src/proof/blinded.rs b/crates/trie/trie/src/proof/blinded.rs index 8535b232698b..55f8bdfbc48c 100644 --- a/crates/trie/trie/src/proof/blinded.rs +++ b/crates/trie/trie/src/proof/blinded.rs @@ -29,10 +29,6 @@ impl ProofBlindedProviderFactory { ) -> Self { Self { trie_cursor_factory, hashed_cursor_factory, prefix_sets } } - - pub const fn boxed(self) -> BoxProofBlindedProviderFactory { - BoxProofBlindedProviderFactory(self) - } } impl BlindedProviderFactory for ProofBlindedProviderFactory @@ -61,27 +57,6 @@ where } } -/// Boxed version of [`ProofBlindedProviderFactory`]. -#[derive(Debug)] -pub struct BoxProofBlindedProviderFactory(ProofBlindedProviderFactory); - -impl BlindedProviderFactory for BoxProofBlindedProviderFactory -where - T: TrieCursorFactory + Clone + Send + Sync, - H: HashedCursorFactory + Clone + Send + Sync, -{ - type AccountNodeProvider = Box>; - type StorageNodeProvider = Box>; - - fn account_node_provider(&self) -> Self::AccountNodeProvider { - Box::new(self.0.account_node_provider()) - } - - fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider { - Box::new(self.0.storage_node_provider(account)) - } -} - /// Blinded provider for retrieving account trie nodes by path. #[derive(Debug)] pub struct ProofBlindedAccountProvider {