diff --git a/Cargo.lock b/Cargo.lock index b8d68cbee6d3d..1148866a60d5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7203,6 +7203,7 @@ dependencies = [ "crossbeam-channel", "futures", "metrics", + "rayon", "reth-beacon-consensus", "reth-blockchain-tree", "reth-blockchain-tree-api", diff --git a/crates/engine/tree/Cargo.toml b/crates/engine/tree/Cargo.toml index 0429d46c5c194..2aca962e5d1ed 100644 --- a/crates/engine/tree/Cargo.toml +++ b/crates/engine/tree/Cargo.toml @@ -51,6 +51,7 @@ metrics.workspace = true reth-metrics = { workspace = true, features = ["common"] } # misc +rayon.workspace = true tracing.workspace = true # optional deps for test-utils diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index 45cf5a7803106..75d4dc62d2f68 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -1,12 +1,15 @@ //! State root task related functionality. -use reth_provider::providers::ConsistentDbView; -use reth_trie::{updates::TrieUpdates, TrieInput}; -use reth_trie_parallel::root::ParallelStateRootError; -use revm_primitives::{EvmState, B256}; -use std::sync::{ - mpsc::{self, Receiver, RecvError}, - Arc, +use reth_provider::{providers::ConsistentDbView, BlockReader, DatabaseProviderFactory}; +use reth_trie::{updates::TrieUpdates, HashedPostState, HashedStorage, MultiProof, TrieInput}; +use reth_trie_parallel::{proof::ParallelProof, root::ParallelStateRootError}; +use revm_primitives::{keccak256, EvmState, HashMap, HashSet, B256}; +use std::{ + collections::VecDeque, + sync::{ + mpsc::{self, Receiver, RecvError}, + Arc, + }, }; use tracing::debug; @@ -74,19 +77,20 @@ pub(crate) struct StateRootTask { state_stream: StdReceiverStream, /// Task configuration. config: StateRootConfig, + /// Current state. + state: HashedPostState, + /// Channels to retrieve proof calculation results from. + pending_proofs: VecDeque>>, } #[allow(dead_code)] impl StateRootTask where - Factory: Send + 'static, + Factory: DatabaseProviderFactory + Clone + Send + Sync + 'static, { /// Creates a new `StateRootTask`. - pub(crate) const fn new( - config: StateRootConfig, - state_stream: StdReceiverStream, - ) -> Self { - Self { config, state_stream } + pub(crate) fn new(config: StateRootConfig, state_stream: StdReceiverStream) -> Self { + Self { config, state_stream, state: Default::default(), pending_proofs: Default::default() } } /// Spawns the state root task and returns a handle to await its result. @@ -106,23 +110,69 @@ where /// Handles state updates. fn on_state_update( - _view: &reth_provider::providers::ConsistentDbView, - _input: &std::sync::Arc, - _state: EvmState, + view: ConsistentDbView, + input: Arc, + update: EvmState, + state: &mut HashedPostState, + pending_proofs: &mut VecDeque>>, ) { - // Default implementation of state update handling - // TODO: calculate hashed state update and dispatch proof gathering for it. + let mut hashed_state_update = HashedPostState::default(); + for (address, account) in update { + if account.is_touched() { + let hashed_address = keccak256(address); + + let destroyed = account.is_selfdestructed(); + hashed_state_update.accounts.insert( + hashed_address, + if destroyed || account.is_empty() { None } else { Some(account.info.into()) }, + ); + + if destroyed || !account.storage.is_empty() { + let storage = account.storage.into_iter().filter_map(|(slot, value)| { + (!destroyed && value.is_changed()) + .then(|| (keccak256(B256::from(slot)), value.present_value)) + }); + hashed_state_update + .storages + .insert(hashed_address, HashedStorage::from_iter(destroyed, storage)); + } + } + } + + // Dispatch proof gathering for this state update + let targets = hashed_state_update + .accounts + .keys() + .filter(|hashed_address| { + !state.accounts.contains_key(*hashed_address) && + !state.storages.contains_key(*hashed_address) + }) + .map(|hashed_address| (*hashed_address, HashSet::default())) + .chain(hashed_state_update.storages.iter().map(|(hashed_address, storage)| { + (*hashed_address, storage.storage.keys().copied().collect()) + })) + .collect::>(); + + let (tx, rx) = mpsc::sync_channel(1); + rayon::spawn(move || { + let result = ParallelProof::new(view, input).multiproof(targets); + let _ = tx.send(result); + }); + + pending_proofs.push_back(rx); + + state.extend(hashed_state_update); } -} -#[allow(dead_code)] -impl StateRootTask -where - Factory: Send + 'static, -{ - fn run(self) -> StateRootResult { - while let Ok(state) = self.state_stream.recv() { - Self::on_state_update(&self.config.consistent_view, &self.config.input, state); + fn run(mut self) -> StateRootResult { + while let Ok(update) = self.state_stream.recv() { + Self::on_state_update( + self.config.consistent_view.clone(), + self.config.input.clone(), + update, + &mut self.state, + &mut self.pending_proofs, + ); } // TODO: diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs index 4cb99b50d0c85..52538e3e30ffe 100644 --- a/crates/trie/parallel/src/proof.rs +++ b/crates/trie/parallel/src/proof.rs @@ -1,5 +1,8 @@ use crate::{root::ParallelStateRootError, stats::ParallelTrieTracker, StorageRootTargets}; -use alloy_primitives::{map::HashSet, B256}; +use alloy_primitives::{ + map::{HashMap, HashSet}, + B256, +}; use alloy_rlp::{BufMut, Encodable}; use itertools::Itertools; use reth_db::DatabaseError; @@ -18,7 +21,7 @@ use reth_trie::{ }; use reth_trie_common::proof::ProofRetainer; use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory}; -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; use tracing::debug; #[cfg(feature = "metrics")] @@ -30,7 +33,7 @@ pub struct ParallelProof { /// Consistent view of the database. view: ConsistentDbView, /// Trie input. - input: TrieInput, + input: Arc, /// Parallel state root metrics. #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics, @@ -38,7 +41,7 @@ pub struct ParallelProof { impl ParallelProof { /// Create new state proof generator. - pub fn new(view: ConsistentDbView, input: TrieInput) -> Self { + pub fn new(view: ConsistentDbView, input: Arc) -> Self { Self { view, input, @@ -59,8 +62,8 @@ where ) -> Result { let mut tracker = ParallelTrieTracker::default(); - let trie_nodes_sorted = Arc::new(self.input.nodes.into_sorted()); - let hashed_state_sorted = Arc::new(self.input.state.into_sorted()); + let trie_nodes_sorted = self.input.nodes.clone().into_sorted(); + let hashed_state_sorted = self.input.state.clone().into_sorted(); // Extend prefix sets with targets let mut prefix_sets = self.input.prefix_sets.clone();