diff --git a/crates/engine/tree/src/tree/mod.rs b/crates/engine/tree/src/tree/mod.rs
index 234a96a47d07..f41330724ecf 100644
--- a/crates/engine/tree/src/tree/mod.rs
+++ b/crates/engine/tree/src/tree/mod.rs
@@ -31,7 +31,7 @@ use reth_engine_primitives::{
EngineValidator, ForkchoiceStateTracker, OnForkChoiceUpdated,
};
use reth_errors::{ConsensusError, ProviderResult};
-use reth_evm::execute::BlockExecutorProvider;
+use reth_evm::{execute::BlockExecutorProvider, system_calls::OnStateHook};
use reth_payload_builder::PayloadBuilderHandle;
use reth_payload_builder_primitives::PayloadBuilder;
use reth_payload_primitives::PayloadBuilderAttributes;
@@ -41,15 +41,24 @@ use reth_primitives::{
};
use reth_primitives_traits::Block;
use reth_provider::{
- providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, ExecutionOutcome,
- HashedPostStateProvider, ProviderError, StateCommitmentProvider, StateProviderBox,
- StateProviderFactory, StateReader, StateRootProvider, TransactionVariant,
+ providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory,
+ ExecutionOutcome, HashedPostStateProvider, ProviderError, StateCommitmentProvider,
+ StateProviderBox, StateProviderFactory, StateReader, StateRootProvider, TransactionVariant,
};
use reth_revm::database::StateProviderDatabase;
use reth_stages_api::ControlFlow;
-use reth_trie::{updates::TrieUpdates, HashedPostState, TrieInput};
+use reth_trie::{
+ hashed_cursor::HashedPostStateCursorFactory,
+ prefix_set::TriePrefixSetsMut,
+ proof::ProofBlindedProviderFactory,
+ trie_cursor::InMemoryTrieCursorFactory,
+ updates::{TrieUpdates, TrieUpdatesSorted},
+ HashedPostState, HashedPostStateSorted, TrieInput,
+};
+use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie_parallel::root::{ParallelStateRoot, ParallelStateRootError};
use revm_primitives::EvmState;
+use root::{StateRootConfig, StateRootTask};
use std::{
cmp::Ordering,
collections::{btree_map, hash_map, BTreeMap, VecDeque},
@@ -463,6 +472,15 @@ pub enum TreeAction {
},
}
+/// Context used to keep alive the required values when returning a state hook
+/// from a scoped thread.
+struct StateHookContext<P> {
+ provider_ro: P,
+ nodes_sorted: TrieUpdatesSorted,
+ state_sorted: HashedPostStateSorted,
+    prefix_sets: Arc<TriePrefixSetsMut>,
+}
+
/// The engine API tree handler implementation.
///
/// This type is responsible for processing engine API requests, maintaining the canonical state and
@@ -2224,63 +2242,143 @@ where
let exec_time = Instant::now();
- // TODO: create StateRootTask with the receiving end of a channel and
- // pass the sending end of the channel to the state hook.
- let noop_state_hook = |_state: &EvmState| {};
- let output = self.metrics.executor.execute_metered(
- executor,
- (&block, U256::MAX).into(),
- Box::new(noop_state_hook),
- )?;
+ let persistence_not_in_progress = !self.persistence_state.in_progress();
- trace!(target: "engine::tree", elapsed=?exec_time.elapsed(), ?block_number, "Executed block");
+ let state_root_result = match std::thread::scope(|scope| {
+ let (state_root_handle, state_hook) = if persistence_not_in_progress {
+ let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
- if let Err(err) = self.consensus.validate_block_post_execution(
- &block,
- PostExecutionInput::new(&output.receipts, &output.requests),
- ) {
- // call post-block hook
- self.invalid_block_hook.on_invalid_block(
- &parent_block,
- &block.seal_slow(),
- &output,
- None,
- );
- return Err(err.into())
- }
-
- let hashed_state = self.provider.hashed_post_state(&output.state);
-
- trace!(target: "engine::tree", block=?sealed_block.num_hash(), "Calculating block state root");
- let root_time = Instant::now();
- let mut state_root_result = None;
-
- // TODO: switch to calculate state root using `StateRootTask`.
-
- // We attempt to compute state root in parallel if we are currently not persisting anything
- // to database. This is safe, because the database state cannot change until we
- // finish parallel computation. It is important that nothing is being persisted as
- // we are computing in parallel, because we initialize a different database transaction
- // per thread and it might end up with a different view of the database.
- let persistence_in_progress = self.persistence_state.in_progress();
- if !persistence_in_progress {
- state_root_result = match self
- .compute_state_root_parallel(block.header().parent_hash(), &hashed_state)
- {
- Ok((state_root, trie_output)) => Some((state_root, trie_output)),
- Err(ParallelStateRootError::Provider(ProviderError::ConsistentView(error))) => {
- debug!(target: "engine", %error, "Parallel state root computation failed consistency check, falling back");
- None
- }
- Err(error) => return Err(InsertBlockErrorKindTwo::Other(Box::new(error))),
+ let input = Arc::new(
+ self.compute_trie_input(consistent_view.clone(), block.header().parent_hash())
+ .unwrap(),
+ );
+ let state_root_config = StateRootConfig {
+ consistent_view: consistent_view.clone(),
+ input: input.clone(),
+ };
+
+ let provider_ro = consistent_view.provider_ro()?;
+ let nodes_sorted = input.nodes.clone().into_sorted();
+ let state_sorted = input.state.clone().into_sorted();
+ let prefix_sets = Arc::new(input.prefix_sets.clone());
+
+ // context will hold the values that need to be kept alive
+ let context =
+ StateHookContext { provider_ro, nodes_sorted, state_sorted, prefix_sets };
+
+ // it is ok to leak here because we are in a scoped thread, the
+ // memory will be freed when the thread completes
+ let context = Box::leak(Box::new(context));
+
+ let blinded_provider_factory = ProofBlindedProviderFactory::new(
+ InMemoryTrieCursorFactory::new(
+ DatabaseTrieCursorFactory::new(context.provider_ro.tx_ref()),
+ &context.nodes_sorted,
+ ),
+ HashedPostStateCursorFactory::new(
+ DatabaseHashedCursorFactory::new(context.provider_ro.tx_ref()),
+ &context.state_sorted,
+ ),
+ context.prefix_sets.clone(),
+ );
+
+ let state_root_task =
+ StateRootTask::new(state_root_config, blinded_provider_factory);
+ let state_hook = state_root_task.state_hook();
+                (Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box<dyn OnStateHook>)
+ } else {
+                (None, Box::new(|_state: &EvmState| {}) as Box<dyn OnStateHook>)
};
- }
- let (state_root, trie_output) = if let Some(result) = state_root_result {
- result
- } else {
- debug!(target: "engine::tree", block=?sealed_block.num_hash(), persistence_in_progress, "Failed to compute state root in parallel");
- state_provider.state_root_with_updates(hashed_state.clone())?
+ let output = self.metrics.executor.execute_metered(
+ executor,
+ (&block, U256::MAX).into(),
+ state_hook,
+ )?;
+
+ trace!(target: "engine::tree", elapsed=?exec_time.elapsed(), ?block_number, "Executed block");
+
+ if let Err(err) = self.consensus.validate_block_post_execution(
+ &block,
+ PostExecutionInput::new(&output.receipts, &output.requests),
+ ) {
+ // call post-block hook
+ self.invalid_block_hook.on_invalid_block(
+ &parent_block,
+ &block.clone().seal_slow(),
+ &output,
+ None,
+ );
+ return Err(err.into())
+ }
+
+ let hashed_state = self.provider.hashed_post_state(&output.state);
+
+ trace!(target: "engine::tree", block=?sealed_block.num_hash(), "Calculating block state root");
+ let root_time = Instant::now();
+
+ // We attempt to compute state root in parallel if we are currently not persisting
+ // anything to database. This is safe, because the database state cannot
+ // change until we finish parallel computation. It is important that nothing
+ // is being persisted as we are computing in parallel, because we initialize
+ // a different database transaction per thread and it might end up with a
+ // different view of the database.
+ if persistence_not_in_progress {
+ if let Some(state_root_handle) = state_root_handle {
+ match state_root_handle.wait_for_result() {
+ Ok((task_state_root, _task_trie_updates)) => {
+ info!(
+ target: "engine::tree",
+ block = ?sealed_block.num_hash(),
+ ?task_state_root,
+ "State root task finished"
+ );
+ }
+ Err(error) => {
+                        info!(target: "engine::tree", ?error, "Failed to wait for state root task result");
+ }
+ }
+ }
+
+ match self.compute_state_root_parallel(block.header().parent_hash(), &hashed_state)
+ {
+ Ok(result) => {
+ info!(
+ target: "engine::tree",
+ block = ?sealed_block.num_hash(),
+ regular_state_root = ?result.0,
+ "Regular root task finished"
+ );
+ Ok((Some((result.0, result.1)), hashed_state, output, root_time))
+ }
+ Err(ParallelStateRootError::Provider(ProviderError::ConsistentView(error))) => {
+ debug!(target: "engine", %error, "Parallel state root computation failed consistency check, falling back");
+ Ok((None, hashed_state, output, root_time))
+ }
+ Err(error) => Err(InsertBlockErrorKindTwo::Other(Box::new(error))),
+ }
+ } else {
+ Ok((None, hashed_state, output, root_time))
+ }
+ }) {
+ Ok((Some(res), hashed_state, output, root_time)) => {
+ (Some(res), hashed_state, output, root_time)
+ }
+ Ok((None, hashed_state, output, root_time)) => (None, hashed_state, output, root_time),
+ Err(e) => return Err(e),
+ };
+
+ let (state_root, trie_output, hashed_state, output, root_time) = match state_root_result {
+ (Some(res), hashed_state, output, root_time) => {
+ (res.0, res.1, hashed_state, output, root_time)
+ }
+ (None, hashed_state, output, root_time) => {
+ debug!(target: "engine::tree", block=?sealed_block.num_hash(), ?persistence_not_in_progress, "Failed to compute state root in parallel");
+ let (root, updates) =
+ state_provider.state_root_with_updates(hashed_state.clone())?;
+ (root, updates, hashed_state, output, root_time)
+ }
};
if state_root != block.header().state_root() {
@@ -2331,23 +2429,11 @@ where
Ok(InsertPayloadOk2::Inserted(BlockStatus2::Valid))
}
- /// Compute state root for the given hashed post state in parallel.
- ///
- /// # Returns
- ///
- /// Returns `Ok(_)` if computed successfully.
- /// Returns `Err(_)` if error was encountered during computation.
- /// `Err(ProviderError::ConsistentView(_))` can be safely ignored and fallback computation
- /// should be used instead.
- fn compute_state_root_parallel(
+ fn compute_trie_input(
&self,
+        consistent_view: ConsistentDbView<P>,
parent_hash: B256,
- hashed_state: &HashedPostState,
- ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
- // TODO: when we switch to calculate state root using `StateRootTask` this
- // method can be still useful to calculate the required `TrieInput` to
- // create the task.
- let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
+    ) -> Result<TrieInput, ParallelStateRootError> {
let mut input = TrieInput::default();
if let Some((historical, blocks)) = self.state.tree_state.blocks_by_hash(parent_hash) {
@@ -2367,6 +2453,25 @@ where
input.append(revert_state);
}
+ Ok(input)
+ }
+
+ /// Compute state root for the given hashed post state in parallel.
+ ///
+ /// # Returns
+ ///
+ /// Returns `Ok(_)` if computed successfully.
+ /// Returns `Err(_)` if error was encountered during computation.
+ /// `Err(ProviderError::ConsistentView(_))` can be safely ignored and fallback computation
+ /// should be used instead.
+ fn compute_state_root_parallel(
+ &self,
+ parent_hash: B256,
+ hashed_state: &HashedPostState,
+ ) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
+ let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
+
+ let mut input = self.compute_trie_input(consistent_view.clone(), parent_hash)?;
// Extend with block we are validating root for.
input.append_ref(hashed_state);
@@ -2648,7 +2753,7 @@ mod tests {
use reth_primitives::{Block, BlockExt, EthPrimitives};
use reth_provider::test_utils::MockEthProvider;
use reth_rpc_types_compat::engine::{block_to_payload_v1, payload::block_to_payload_v3};
- use reth_trie::updates::TrieUpdates;
+ use reth_trie::{updates::TrieUpdates, HashedPostState};
use std::{
str::FromStr,
sync::mpsc::{channel, Sender},
diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs
index f9d16e0fe400..230c3dccbdbc 100644
--- a/crates/engine/tree/src/tree/root.rs
+++ b/crates/engine/tree/src/tree/root.rs
@@ -46,7 +46,6 @@ pub struct StateRootHandle {
rx: mpsc::Receiver<StateRootResult>,
}
-#[allow(dead_code)]
impl StateRootHandle {
/// Creates a new handle from a receiver.
pub(crate) const fn new(rx: mpsc::Receiver<StateRootResult>) -> Self {