Skip to content

Commit

Permalink
feat(engine): proof fetching on state update for StateRootTask
Browse files Browse the repository at this point in the history
  • Loading branch information
fgimenez committed Nov 11, 2024
1 parent d0f48d5 commit 3f60a20
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 33 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/engine/tree/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ metrics.workspace = true
reth-metrics = { workspace = true, features = ["common"] }

# misc
rayon.workspace = true
tracing.workspace = true

# optional deps for test-utils
Expand Down
104 changes: 77 additions & 27 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
//! State root task related functionality.
use reth_provider::providers::ConsistentDbView;
use reth_trie::{updates::TrieUpdates, TrieInput};
use reth_trie_parallel::root::ParallelStateRootError;
use revm_primitives::{EvmState, B256};
use std::sync::{
mpsc::{self, Receiver, RecvError},
Arc,
use reth_provider::{providers::ConsistentDbView, BlockReader, DatabaseProviderFactory};
use reth_trie::{updates::TrieUpdates, HashedPostState, HashedStorage, MultiProof, TrieInput};
use reth_trie_parallel::{proof::ParallelProof, root::ParallelStateRootError};
use revm_primitives::{keccak256, EvmState, HashMap, HashSet, B256};
use std::{
collections::VecDeque,
sync::{
mpsc::{self, Receiver, RecvError},
Arc,
},
};
use tracing::debug;

Expand Down Expand Up @@ -74,19 +77,20 @@ pub(crate) struct StateRootTask<Factory> {
state_stream: StdReceiverStream,
/// Task configuration.
config: StateRootConfig<Factory>,
/// Current state.
state: HashedPostState,
/// Channels to retrieve proof calculation results from.
pending_proofs: VecDeque<Receiver<Result<MultiProof, ParallelStateRootError>>>,
}

#[allow(dead_code)]
impl<Factory> StateRootTask<Factory>
where
Factory: Send + 'static,
Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + Send + Sync + 'static,
{
/// Creates a new `StateRootTask`.
pub(crate) const fn new(
config: StateRootConfig<Factory>,
state_stream: StdReceiverStream,
) -> Self {
Self { config, state_stream }
pub(crate) fn new(config: StateRootConfig<Factory>, state_stream: StdReceiverStream) -> Self {
Self { config, state_stream, state: Default::default(), pending_proofs: Default::default() }
}

/// Spawns the state root task and returns a handle to await its result.
Expand All @@ -106,23 +110,69 @@ where

/// Handles state updates.
fn on_state_update(
_view: &reth_provider::providers::ConsistentDbView<impl Send + 'static>,
_input: &std::sync::Arc<reth_trie::TrieInput>,
_state: EvmState,
view: ConsistentDbView<Factory>,
input: Arc<TrieInput>,
update: EvmState,
state: &mut HashedPostState,
pending_proofs: &mut VecDeque<Receiver<Result<MultiProof, ParallelStateRootError>>>,
) {
// Default implementation of state update handling
// TODO: calculate hashed state update and dispatch proof gathering for it.
let mut hashed_state_update = HashedPostState::default();
for (address, account) in update {
if account.is_touched() {
let hashed_address = keccak256(address);

let destroyed = account.is_selfdestructed();
hashed_state_update.accounts.insert(
hashed_address,
if destroyed || account.is_empty() { None } else { Some(account.info.into()) },
);

if destroyed || !account.storage.is_empty() {
let storage = account.storage.into_iter().filter_map(|(slot, value)| {
(!destroyed && value.is_changed())
.then(|| (keccak256(B256::from(slot)), value.present_value))
});
hashed_state_update
.storages
.insert(hashed_address, HashedStorage::from_iter(destroyed, storage));
}
}
}

// Dispatch proof gathering for this state update
let targets = hashed_state_update
.accounts
.keys()
.filter(|hashed_address| {
!state.accounts.contains_key(*hashed_address) &&
!state.storages.contains_key(*hashed_address)
})
.map(|hashed_address| (*hashed_address, HashSet::default()))
.chain(hashed_state_update.storages.iter().map(|(hashed_address, storage)| {
(*hashed_address, storage.storage.keys().copied().collect())
}))
.collect::<HashMap<_, _>>();

let (tx, rx) = mpsc::sync_channel(1);
rayon::spawn(move || {
let result = ParallelProof::new(view, input).multiproof(targets);
let _ = tx.send(result);
});

pending_proofs.push_back(rx);

state.extend(hashed_state_update);
}
}

#[allow(dead_code)]
impl<Factory> StateRootTask<Factory>
where
Factory: Send + 'static,
{
fn run(self) -> StateRootResult {
while let Ok(state) = self.state_stream.recv() {
Self::on_state_update(&self.config.consistent_view, &self.config.input, state);
fn run(mut self) -> StateRootResult {
while let Ok(update) = self.state_stream.recv() {
Self::on_state_update(
self.config.consistent_view.clone(),
self.config.input.clone(),
update,
&mut self.state,
&mut self.pending_proofs,
);
}

// TODO:
Expand Down
15 changes: 9 additions & 6 deletions crates/trie/parallel/src/proof.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::{root::ParallelStateRootError, stats::ParallelTrieTracker, StorageRootTargets};
use alloy_primitives::{map::HashSet, B256};
use alloy_primitives::{
map::{HashMap, HashSet},
B256,
};
use alloy_rlp::{BufMut, Encodable};
use itertools::Itertools;
use reth_db::DatabaseError;
Expand All @@ -18,7 +21,7 @@ use reth_trie::{
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use tracing::debug;

#[cfg(feature = "metrics")]
Expand All @@ -30,15 +33,15 @@ pub struct ParallelProof<Factory> {
/// Consistent view of the database.
view: ConsistentDbView<Factory>,
/// Trie input.
input: TrieInput,
input: Arc<TrieInput>,
/// Parallel state root metrics.
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics,
}

impl<Factory> ParallelProof<Factory> {
/// Create new state proof generator.
pub fn new(view: ConsistentDbView<Factory>, input: TrieInput) -> Self {
pub fn new(view: ConsistentDbView<Factory>, input: Arc<TrieInput>) -> Self {
Self {
view,
input,
Expand All @@ -59,8 +62,8 @@ where
) -> Result<MultiProof, ParallelStateRootError> {
let mut tracker = ParallelTrieTracker::default();

let trie_nodes_sorted = Arc::new(self.input.nodes.into_sorted());
let hashed_state_sorted = Arc::new(self.input.state.into_sorted());
let trie_nodes_sorted = self.input.nodes.clone().into_sorted();
let hashed_state_sorted = self.input.state.clone().into_sorted();

// Extend prefix sets with targets
let mut prefix_sets = self.input.prefix_sets.clone();
Expand Down

0 comments on commit 3f60a20

Please sign in to comment.