Skip to content

Commit

Permalink
Revert "box dyn"
Browse files Browse the repository at this point in the history
This reverts commit f0487b9.
  • Loading branch information
shekhirin committed Dec 11, 2024
1 parent f0487b9 commit 93ad324
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 79 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

112 changes: 62 additions & 50 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ use reth_trie_sparse::{
};
use revm_primitives::{keccak256, EvmState, B256};
use std::{
self,
collections::BTreeMap,
ops::Deref,
sync::{
mpsc::{self, channel, Receiver, Sender},
Arc,
},
thread::{self},
time::{Duration, Instant},
};
use tracing::{debug, error, trace};
Expand Down Expand Up @@ -70,7 +70,7 @@ pub struct StateRootConfig<Factory> {
/// Messages used internally by the state root task
#[derive(Debug)]
#[allow(dead_code)]
pub enum StateRootMessage {
pub enum StateRootMessage<BPF: BlindedProviderFactory> {
/// New state update from transaction execution
StateUpdate(EvmState),
/// Proof calculation completed for a specific state update
Expand All @@ -85,7 +85,7 @@ pub enum StateRootMessage {
/// State root calculation completed
RootCalculated {
/// The updated sparse trie
trie: BoxSparseStateTrie,
trie: Box<SparseStateTrie<BPF>>,
/// Time taken to calculate the root
elapsed: Duration,
},
Expand Down Expand Up @@ -161,24 +161,24 @@ impl ProofSequencer {

/// A wrapper for the sender that signals completion when dropped
#[allow(dead_code)]
pub(crate) struct StateHookSender(Sender<StateRootMessage>);
pub(crate) struct StateHookSender<BPF: BlindedProviderFactory>(Sender<StateRootMessage<BPF>>);

#[allow(dead_code)]
impl StateHookSender {
pub(crate) const fn new(inner: Sender<StateRootMessage>) -> Self {
impl<BPF: BlindedProviderFactory> StateHookSender<BPF> {
pub(crate) const fn new(inner: Sender<StateRootMessage<BPF>>) -> Self {
Self(inner)
}
}

impl Deref for StateHookSender {
type Target = Sender<StateRootMessage>;
impl<BPF: BlindedProviderFactory> Deref for StateHookSender<BPF> {
type Target = Sender<StateRootMessage<BPF>>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl Drop for StateHookSender {
impl<BPF: BlindedProviderFactory> Drop for StateHookSender<BPF> {
fn drop(&mut self) {
// Send completion signal when the sender is dropped
let _ = self.0.send(StateRootMessage::FinishedStateUpdates);
Expand Down Expand Up @@ -217,15 +217,6 @@ fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
hashed_state
}

type BoxBlindedProviderFactory = Box<
dyn BlindedProviderFactory<
AccountNodeProvider = Box<dyn BlindedProvider<Error = SparseTrieError> + Send + Sync>,
StorageNodeProvider = Box<dyn BlindedProvider<Error = SparseTrieError> + Send + Sync>,
> + Send
+ Sync,
>;
type BoxSparseStateTrie = Box<SparseStateTrie<BoxBlindedProviderFactory>>;

/// Standalone task that receives a transaction state stream and updates relevant
/// data structures to calculate state root.
///
Expand All @@ -235,37 +226,40 @@ type BoxSparseStateTrie = Box<SparseStateTrie<BoxBlindedProviderFactory>>;
/// to the tree.
/// Then it updates relevant leaves according to the result of the transaction.
#[derive(Debug)]
pub struct StateRootTask<Factory> {
pub struct StateRootTask<Factory, BPF: BlindedProviderFactory> {
/// Task configuration.
config: StateRootConfig<Factory>,
/// Receiver for state root related messages.
rx: Receiver<StateRootMessage>,
rx: Receiver<StateRootMessage<BPF>>,
/// Sender for state root related messages.
tx: Sender<StateRootMessage>,
tx: Sender<StateRootMessage<BPF>>,
/// Proof targets that have been already fetched.
fetched_proof_targets: MultiProofTargets,
/// Proof sequencing handler.
proof_sequencer: ProofSequencer,
/// The sparse trie used for the state root calculation. If [`None`], then update is in
/// progress.
sparse_trie: Option<BoxSparseStateTrie>,
sparse_trie: Option<Box<SparseStateTrie<BPF>>>,
}

#[allow(dead_code)]
impl<Factory> StateRootTask<Factory>
impl<'env, Factory, ABP, SBP, BPF> StateRootTask<Factory, BPF>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
+ Clone
+ Send
+ Sync
+ 'static,
ABP: BlindedProvider<Error = SparseTrieError> + Send + Sync + 'env,
SBP: BlindedProvider<Error = SparseTrieError> + Send + Sync + 'env,
BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP>
+ Send
+ Sync
+ 'env,
{
/// Creates a new state root task with the unified message channel
pub fn new(
config: StateRootConfig<Factory>,
blinded_provider: BoxBlindedProviderFactory,
) -> Self {
pub fn new(config: StateRootConfig<Factory>, blinded_provider: BPF) -> Self {
let (tx, rx) = channel();

Self {
Expand All @@ -279,13 +273,14 @@ where
}

/// Spawns the state root task and returns a handle to await its result.
pub fn spawn(self) -> StateRootHandle {
pub fn spawn<'scope>(self, scope: &'scope thread::Scope<'scope, 'env>) -> StateRootHandle {
let (tx, rx) = mpsc::sync_channel(1);
std::thread::Builder::new()
.name("State Root Task".to_string())
.spawn(move || {
.spawn_scoped(scope, move || {
debug!(target: "engine::tree", "Starting state root task");
let result = self.run();

let result = rayon::scope(|scope| self.run(scope));
let _ = tx.send(result);
})
.expect("failed to spawn state root thread");
Expand All @@ -308,12 +303,13 @@ where
///
/// Returns proof targets derived from the state update.
fn on_state_update(
scope: &rayon::Scope<'env>,
view: ConsistentDbView<Factory>,
input: Arc<TrieInput>,
update: EvmState,
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage>,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
) {
let hashed_state_update = evm_state_to_hashed_post_state(update);

Expand All @@ -323,7 +319,7 @@ where
}

// Dispatch proof gathering for this state update
rayon::spawn(move || {
scope.spawn(move |_| {
let provider = match view.provider_ro() {
Ok(provider) => provider,
Err(error) => {
Expand Down Expand Up @@ -376,7 +372,12 @@ where
}

/// Spawns root calculation with the current state and proofs.
fn spawn_root_calculation(&mut self, state: HashedPostState, multiproof: MultiProof) {
fn spawn_root_calculation(
&mut self,
scope: &rayon::Scope<'env>,
state: HashedPostState,
multiproof: MultiProof,
) {
let Some(trie) = self.sparse_trie.take() else { return };

trace!(
Expand All @@ -390,7 +391,7 @@ where
let targets = get_proof_targets(&state, &HashMap::default());

let tx = self.tx.clone();
rayon::spawn(move || {
scope.spawn(move |_| {
let result = update_sparse_trie(trie, multiproof, targets, state);
match result {
Ok((trie, elapsed)) => {
Expand All @@ -408,7 +409,7 @@ where
});
}

fn run(mut self) -> StateRootResult {
fn run(mut self, scope: &rayon::Scope<'env>) -> StateRootResult {
let mut current_state_update = HashedPostState::default();
let mut current_multiproof = MultiProof::default();
let mut updates_received = 0;
Expand All @@ -428,6 +429,7 @@ where
"Received new state update"
);
Self::on_state_update(
scope,
self.config.consistent_view.clone(),
self.config.input.clone(),
update,
Expand Down Expand Up @@ -457,7 +459,11 @@ where
current_multiproof.extend(combined_proof);
current_state_update.extend(combined_state_update);
} else {
self.spawn_root_calculation(combined_state_update, combined_proof);
self.spawn_root_calculation(
scope,
combined_state_update,
combined_proof,
);
}
}
}
Expand Down Expand Up @@ -495,6 +501,7 @@ where
"Spawning subsequent root calculation"
);
self.spawn_root_calculation(
scope,
std::mem::take(&mut current_state_update),
std::mem::take(&mut current_multiproof),
);
Expand Down Expand Up @@ -569,12 +576,16 @@ fn get_proof_targets(

/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
/// time it took.
fn update_sparse_trie(
mut trie: BoxSparseStateTrie,
fn update_sparse_trie<
ABP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
SBP: BlindedProvider<Error = SparseTrieError> + Send + Sync,
BPF: BlindedProviderFactory<AccountNodeProvider = ABP, StorageNodeProvider = SBP> + Send + Sync,
>(
mut trie: Box<SparseStateTrie<BPF>>,
multiproof: MultiProof,
targets: MultiProofTargets,
state: HashedPostState,
) -> SparseStateTrieResult<(BoxSparseStateTrie, Duration)> {
) -> SparseStateTrieResult<(Box<SparseStateTrie<BPF>>, Duration)> {
trace!(target: "engine::root::sparse", "Updating sparse trie");
let started_at = Instant::now();

Expand Down Expand Up @@ -775,18 +786,19 @@ mod tests {
&state_sorted,
),
Arc::new(config.input.prefix_sets.clone()),
)
.boxed();
let task = StateRootTask::new(config, blinded_provider_factory);
let mut state_hook = task.state_hook();
let handle = task.spawn();

for update in state_updates {
state_hook.on_state(&update);
}
drop(state_hook);
);
let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
let mut state_hook = task.state_hook();
let handle = task.spawn(std_scope);

let (root_from_task, _) = handle.wait_for_result().expect("task failed");
for update in state_updates {
state_hook.on_state(&update);
}
drop(state_hook);

handle.wait_for_result().expect("task failed")
});
let root_from_base = state_root(accumulated_state);

assert_eq!(
Expand Down
1 change: 0 additions & 1 deletion crates/trie/sparse/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ alloy-primitives.workspace = true
alloy-rlp.workspace = true

# misc
auto_impl.workspace = true
smallvec = { workspace = true, features = ["const_new"] }
thiserror.workspace = true

Expand Down
2 changes: 0 additions & 2 deletions crates/trie/sparse/src/blinded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use reth_execution_errors::SparseTrieError;
use reth_trie_common::Nibbles;

/// Factory for instantiating blinded node providers.
#[auto_impl::auto_impl(Box)]
pub trait BlindedProviderFactory {
/// Type capable of fetching blinded account nodes.
type AccountNodeProvider: BlindedProvider;
Expand All @@ -20,7 +19,6 @@ pub trait BlindedProviderFactory {
}

/// Trie node provider for retrieving blinded nodes.
#[auto_impl::auto_impl(Box)]
pub trait BlindedProvider {
/// The error type for the provider.
type Error: Into<SparseTrieError>;
Expand Down
25 changes: 0 additions & 25 deletions crates/trie/trie/src/proof/blinded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ impl<T, H> ProofBlindedProviderFactory<T, H> {
) -> Self {
Self { trie_cursor_factory, hashed_cursor_factory, prefix_sets }
}

pub const fn boxed(self) -> BoxProofBlindedProviderFactory<T, H> {
BoxProofBlindedProviderFactory(self)
}
}

impl<T, H> BlindedProviderFactory for ProofBlindedProviderFactory<T, H>
Expand Down Expand Up @@ -61,27 +57,6 @@ where
}
}

/// Boxed version of [`ProofBlindedProviderFactory`].
#[derive(Debug)]
pub struct BoxProofBlindedProviderFactory<T, H>(ProofBlindedProviderFactory<T, H>);

impl<T, H> BlindedProviderFactory for BoxProofBlindedProviderFactory<T, H>
where
T: TrieCursorFactory + Clone + Send + Sync,
H: HashedCursorFactory + Clone + Send + Sync,
{
type AccountNodeProvider = Box<ProofBlindedAccountProvider<T, H>>;
type StorageNodeProvider = Box<ProofBlindedStorageProvider<T, H>>;

fn account_node_provider(&self) -> Self::AccountNodeProvider {
Box::new(self.0.account_node_provider())
}

fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
Box::new(self.0.storage_node_provider(account))
}
}

/// Blinded provider for retrieving account trie nodes by path.
#[derive(Debug)]
pub struct ProofBlindedAccountProvider<T, H> {
Expand Down

0 comments on commit 93ad324

Please sign in to comment.