Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(engine): wire StateRootTask in EngineApiTreeHandler #12639

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
253 changes: 179 additions & 74 deletions crates/engine/tree/src/tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use reth_engine_primitives::{
EngineValidator, ForkchoiceStateTracker, OnForkChoiceUpdated,
};
use reth_errors::{ConsensusError, ProviderResult};
use reth_evm::execute::BlockExecutorProvider;
use reth_evm::{execute::BlockExecutorProvider, system_calls::OnStateHook};
use reth_payload_builder::PayloadBuilderHandle;
use reth_payload_builder_primitives::PayloadBuilder;
use reth_payload_primitives::PayloadBuilderAttributes;
Expand All @@ -41,15 +41,24 @@ use reth_primitives::{
};
use reth_primitives_traits::Block;
use reth_provider::{
providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, ExecutionOutcome,
HashedPostStateProvider, ProviderError, StateCommitmentProvider, StateProviderBox,
StateProviderFactory, StateReader, StateRootProvider, TransactionVariant,
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory,
ExecutionOutcome, HashedPostStateProvider, ProviderError, StateCommitmentProvider,
StateProviderBox, StateProviderFactory, StateReader, StateRootProvider, TransactionVariant,
};
use reth_revm::database::StateProviderDatabase;
use reth_stages_api::ControlFlow;
use reth_trie::{updates::TrieUpdates, HashedPostState, TrieInput};
use reth_trie::{
hashed_cursor::HashedPostStateCursorFactory,
prefix_set::TriePrefixSetsMut,
proof::ProofBlindedProviderFactory,
trie_cursor::InMemoryTrieCursorFactory,
updates::{TrieUpdates, TrieUpdatesSorted},
HashedPostState, HashedPostStateSorted, TrieInput,
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie_parallel::root::{ParallelStateRoot, ParallelStateRootError};
use revm_primitives::EvmState;
use root::{StateRootConfig, StateRootTask};
use std::{
cmp::Ordering,
collections::{btree_map, hash_map, BTreeMap, VecDeque},
Expand Down Expand Up @@ -463,6 +472,15 @@ pub enum TreeAction {
},
}

/// Context used to keep alive the required values when returning a state hook
/// from a scoped thread.
///
/// The state-root task's `ProofBlindedProviderFactory` borrows the sorted trie
/// updates, sorted hashed state, and the read-only provider's transaction.
/// Bundling them in one struct (which the caller leaks inside the scoped
/// thread — see the `Box::leak` at the construction site) guarantees those
/// borrows stay valid for as long as the factory is in use.
struct StateHookContext<P> {
    // Read-only database provider obtained from the `ConsistentDbView`;
    // its transaction (`tx_ref()`) backs both cursor factories.
    provider_ro: P,
    // In-memory trie updates from the computed `TrieInput`, pre-sorted for
    // consumption by `InMemoryTrieCursorFactory`.
    nodes_sorted: TrieUpdatesSorted,
    // Hashed post-state from the `TrieInput`, pre-sorted for
    // `HashedPostStateCursorFactory`.
    state_sorted: HashedPostStateSorted,
    // Prefix sets cloned out of the `TrieInput`; shared with the blinded
    // provider factory via `Arc`.
    prefix_sets: Arc<TriePrefixSetsMut>,
}

/// The engine API tree handler implementation.
///
/// This type is responsible for processing engine API requests, maintaining the canonical state and
Expand Down Expand Up @@ -2224,63 +2242,143 @@ where

let exec_time = Instant::now();

// TODO: create StateRootTask with the receiving end of a channel and
// pass the sending end of the channel to the state hook.
let noop_state_hook = |_state: &EvmState| {};
let output = self.metrics.executor.execute_metered(
executor,
(&block, U256::MAX).into(),
Box::new(noop_state_hook),
)?;
let persistence_not_in_progress = !self.persistence_state.in_progress();

trace!(target: "engine::tree", elapsed=?exec_time.elapsed(), ?block_number, "Executed block");
let state_root_result = match std::thread::scope(|scope| {
let (state_root_handle, state_hook) = if persistence_not_in_progress {
let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;

if let Err(err) = self.consensus.validate_block_post_execution(
&block,
PostExecutionInput::new(&output.receipts, &output.requests),
) {
// call post-block hook
self.invalid_block_hook.on_invalid_block(
&parent_block,
&block.seal_slow(),
&output,
None,
);
return Err(err.into())
}

let hashed_state = self.provider.hashed_post_state(&output.state);

trace!(target: "engine::tree", block=?sealed_block.num_hash(), "Calculating block state root");
let root_time = Instant::now();
let mut state_root_result = None;

// TODO: switch to calculate state root using `StateRootTask`.

// We attempt to compute state root in parallel if we are currently not persisting anything
// to database. This is safe, because the database state cannot change until we
// finish parallel computation. It is important that nothing is being persisted as
// we are computing in parallel, because we initialize a different database transaction
// per thread and it might end up with a different view of the database.
let persistence_in_progress = self.persistence_state.in_progress();
if !persistence_in_progress {
state_root_result = match self
.compute_state_root_parallel(block.header().parent_hash(), &hashed_state)
{
Ok((state_root, trie_output)) => Some((state_root, trie_output)),
Err(ParallelStateRootError::Provider(ProviderError::ConsistentView(error))) => {
debug!(target: "engine", %error, "Parallel state root computation failed consistency check, falling back");
None
}
Err(error) => return Err(InsertBlockErrorKindTwo::Other(Box::new(error))),
let input = Arc::new(
self.compute_trie_input(consistent_view.clone(), block.header().parent_hash())
.unwrap(),
);
let state_root_config = StateRootConfig {
consistent_view: consistent_view.clone(),
input: input.clone(),
};

let provider_ro = consistent_view.provider_ro()?;
let nodes_sorted = input.nodes.clone().into_sorted();
let state_sorted = input.state.clone().into_sorted();
let prefix_sets = Arc::new(input.prefix_sets.clone());

// context will hold the values that need to be kept alive
let context =
StateHookContext { provider_ro, nodes_sorted, state_sorted, prefix_sets };

// it is ok to leak here because we are in a scoped thread, the
// memory will be freed when the thread completes
let context = Box::leak(Box::new(context));

let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(context.provider_ro.tx_ref()),
&context.nodes_sorted,
),
HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(context.provider_ro.tx_ref()),
&context.state_sorted,
),
context.prefix_sets.clone(),
);

let state_root_task =
StateRootTask::new(state_root_config, blinded_provider_factory);
let state_hook = state_root_task.state_hook();
(Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box<dyn OnStateHook>)
} else {
(None, Box::new(|_state: &EvmState| {}) as Box<dyn OnStateHook>)
};
}

let (state_root, trie_output) = if let Some(result) = state_root_result {
result
} else {
debug!(target: "engine::tree", block=?sealed_block.num_hash(), persistence_in_progress, "Failed to compute state root in parallel");
state_provider.state_root_with_updates(hashed_state.clone())?
let output = self.metrics.executor.execute_metered(
executor,
(&block, U256::MAX).into(),
state_hook,
)?;

trace!(target: "engine::tree", elapsed=?exec_time.elapsed(), ?block_number, "Executed block");

if let Err(err) = self.consensus.validate_block_post_execution(
&block,
PostExecutionInput::new(&output.receipts, &output.requests),
) {
// call post-block hook
self.invalid_block_hook.on_invalid_block(
&parent_block,
&block.clone().seal_slow(),
&output,
None,
);
return Err(err.into())
}

let hashed_state = self.provider.hashed_post_state(&output.state);

trace!(target: "engine::tree", block=?sealed_block.num_hash(), "Calculating block state root");
let root_time = Instant::now();

// We attempt to compute state root in parallel if we are currently not persisting
// anything to database. This is safe, because the database state cannot
// change until we finish parallel computation. It is important that nothing
// is being persisted as we are computing in parallel, because we initialize
// a different database transaction per thread and it might end up with a
// different view of the database.
if persistence_not_in_progress {
if let Some(state_root_handle) = state_root_handle {
match state_root_handle.wait_for_result() {
Ok((task_state_root, _task_trie_updates)) => {
info!(
target: "engine::tree",
block = ?sealed_block.num_hash(),
?task_state_root,
"State root task finished"
);
}
Err(error) => {
info!(target: "engine::tree", ?error, "Failed to wait for state root task
result");
}
}
}

match self.compute_state_root_parallel(block.header().parent_hash(), &hashed_state)
{
Ok(result) => {
info!(
target: "engine::tree",
block = ?sealed_block.num_hash(),
regular_state_root = ?result.0,
"Regular root task finished"
);
Ok((Some((result.0, result.1)), hashed_state, output, root_time))
}
Err(ParallelStateRootError::Provider(ProviderError::ConsistentView(error))) => {
debug!(target: "engine", %error, "Parallel state root computation failed consistency check, falling back");
Ok((None, hashed_state, output, root_time))
}
Err(error) => Err(InsertBlockErrorKindTwo::Other(Box::new(error))),
}
} else {
Ok((None, hashed_state, output, root_time))
}
}) {
Ok((Some(res), hashed_state, output, root_time)) => {
(Some(res), hashed_state, output, root_time)
}
Ok((None, hashed_state, output, root_time)) => (None, hashed_state, output, root_time),
Err(e) => return Err(e),
};

let (state_root, trie_output, hashed_state, output, root_time) = match state_root_result {
(Some(res), hashed_state, output, root_time) => {
(res.0, res.1, hashed_state, output, root_time)
}
(None, hashed_state, output, root_time) => {
debug!(target: "engine::tree", block=?sealed_block.num_hash(), ?persistence_not_in_progress, "Failed to compute state root in parallel");
let (root, updates) =
state_provider.state_root_with_updates(hashed_state.clone())?;
(root, updates, hashed_state, output, root_time)
}
};

if state_root != block.header().state_root() {
Expand Down Expand Up @@ -2331,23 +2429,11 @@ where
Ok(InsertPayloadOk2::Inserted(BlockStatus2::Valid))
}

/// Compute state root for the given hashed post state in parallel.
///
/// # Returns
///
/// Returns `Ok(_)` if computed successfully.
/// Returns `Err(_)` if error was encountered during computation.
/// `Err(ProviderError::ConsistentView(_))` can be safely ignored and fallback computation
/// should be used instead.
fn compute_state_root_parallel(
fn compute_trie_input(
&self,
consistent_view: ConsistentDbView<P>,
parent_hash: B256,
hashed_state: &HashedPostState,
) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
// TODO: when we switch to calculate state root using `StateRootTask` this
// method can be still useful to calculate the required `TrieInput` to
// create the task.
let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
) -> Result<TrieInput, ParallelStateRootError> {
let mut input = TrieInput::default();

if let Some((historical, blocks)) = self.state.tree_state.blocks_by_hash(parent_hash) {
Expand All @@ -2367,6 +2453,25 @@ where
input.append(revert_state);
}

Ok(input)
}

/// Compute state root for the given hashed post state in parallel.
///
/// # Returns
///
/// Returns `Ok(_)` if computed successfully.
/// Returns `Err(_)` if error was encountered during computation.
/// `Err(ProviderError::ConsistentView(_))` can be safely ignored and fallback computation
/// should be used instead.
fn compute_state_root_parallel(
&self,
parent_hash: B256,
hashed_state: &HashedPostState,
) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;

let mut input = self.compute_trie_input(consistent_view.clone(), parent_hash)?;
// Extend with block we are validating root for.
input.append_ref(hashed_state);

Expand Down Expand Up @@ -2648,7 +2753,7 @@ mod tests {
use reth_primitives::{Block, BlockExt, EthPrimitives};
use reth_provider::test_utils::MockEthProvider;
use reth_rpc_types_compat::engine::{block_to_payload_v1, payload::block_to_payload_v3};
use reth_trie::updates::TrieUpdates;
use reth_trie::{updates::TrieUpdates, HashedPostState};
use std::{
str::FromStr,
sync::mpsc::{channel, Sender},
Expand Down
1 change: 0 additions & 1 deletion crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ pub struct StateRootHandle {
rx: mpsc::Receiver<StateRootResult>,
}

#[allow(dead_code)]
impl StateRootHandle {
/// Creates a new handle from a receiver.
pub(crate) const fn new(rx: mpsc::Receiver<StateRootResult>) -> Self {
Expand Down
Loading