Skip to content

Commit

Permalink
Use sync locks for stake table state (#2404)
Browse files Browse the repository at this point in the history
* Use sync locks for stake table state

* fixup
  • Loading branch information
sveitser authored Dec 16, 2024
1 parent 6482b18 commit 5c32108
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 26 deletions.
13 changes: 7 additions & 6 deletions types/src/v0/impls/l1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ use lru::LruCache;
use serde::{de::DeserializeOwned, Serialize};
use std::{
cmp::{min, Ordering},
collections::BTreeMap,
fmt::Debug,
num::NonZeroUsize,
sync::Arc,
sync::{self, Arc},
time::Duration,
};
use tokio::{
Expand Down Expand Up @@ -343,6 +344,7 @@ impl L1Client {
provider: Arc::new(provider),
events_max_block_range: opt.l1_events_max_block_range,
state: Arc::new(Mutex::new(L1State::new(opt.l1_blocks_cache_size))),
stake_table_state: sync::Arc::new(sync::RwLock::new(BTreeMap::new())),
sender,
receiver: receiver.deactivate(),
update_task: Default::default(),
Expand Down Expand Up @@ -382,7 +384,7 @@ impl L1Client {
epoch: EpochNumber,
) {
let retry_delay = self.retry_delay;
let state = self.state.clone();
let state = self.stake_table_state.clone();

let span = tracing::warn_span!("L1 client memberships update");

Expand All @@ -397,8 +399,8 @@ impl L1Client {
);
}
Ok(stake_tables) => {
let mut state = state.lock().await;
let _ = state.stake_tables.insert(epoch, stake_tables);
let mut state = state.write().unwrap();
let _ = state.insert(epoch, stake_tables);
}
}

Expand Down Expand Up @@ -563,7 +565,7 @@ impl L1Client {
}

pub fn stake_table(&self, epoch: &EpochNumber) -> StakeTables {
if let Some(stake_tables) = self.state.blocking_lock().stake_tables.get(epoch) {
if let Some(stake_tables) = self.stake_table_state.read().unwrap().get(epoch) {
stake_tables.clone()
} else {
// It would be nice if we could update l1_cache of stake
Expand Down Expand Up @@ -852,7 +854,6 @@ impl L1State {
Self {
snapshot: Default::default(),
finalized: LruCache::new(cache_size),
stake_tables: Default::default(),
}
}

Expand Down
48 changes: 30 additions & 18 deletions types/src/v0/impls/stake_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use super::{
L1Client, NodeState, PubKey, SeqTypes,
};

use async_lock::RwLock;
use contract_bindings::permissioned_stake_table::StakersUpdatedFilter;
use ethers::types::U256;
use hotshot::types::SignatureKey as _;
Expand All @@ -19,6 +18,7 @@ use hotshot_types::{
PeerConfig,
};
use itertools::Itertools;
use std::sync::RwLock;
use std::{
cmp::max,
collections::{BTreeMap, BTreeSet, HashMap},
Expand Down Expand Up @@ -174,7 +174,7 @@ impl EpochCommittees {
.filter(|entry| entry.stake() > U256::zero())
.collect();

let mut state = self.state.write_blocking();
let mut state = self.state.write().unwrap();

let committee = Committee {
eligible_leaders,
Expand Down Expand Up @@ -313,7 +313,7 @@ impl Membership<SeqTypes> for EpochCommittees {

/// Get the stake table for the current view
fn stake_table(&self, epoch: Epoch) -> Vec<StakeTableEntry<PubKey>> {
if let Some(st) = self.state.read_blocking().get(&epoch) {
if let Some(st) = self.state.read().unwrap().get(&epoch) {
st.indexed_stake_table.clone().into_values().collect()
} else {
self.update_stake_table(epoch, self.l1_client.stake_table(&epoch))
Expand All @@ -324,7 +324,7 @@ impl Membership<SeqTypes> for EpochCommittees {
}
/// Get the stake table for the current view
fn da_stake_table(&self, epoch: Epoch) -> Vec<StakeTableEntry<PubKey>> {
if let Some(sc) = self.state.read_blocking().get(&epoch) {
if let Some(sc) = self.state.read().unwrap().get(&epoch) {
sc.indexed_da_members.clone().into_values().collect()
} else {
self.update_stake_table(epoch, self.l1_client.stake_table(&epoch))
Expand All @@ -340,7 +340,7 @@ impl Membership<SeqTypes> for EpochCommittees {
_view_number: <SeqTypes as NodeType>::View,
epoch: Epoch,
) -> BTreeSet<PubKey> {
if let Some(sc) = self.state.read_blocking().get(&epoch) {
if let Some(sc) = self.state.read().unwrap().get(&epoch) {
sc.indexed_stake_table.clone().into_keys().collect()
} else {
self.update_stake_table(epoch, self.l1_client.stake_table(&epoch))
Expand All @@ -356,7 +356,7 @@ impl Membership<SeqTypes> for EpochCommittees {
_view_number: <SeqTypes as NodeType>::View,
epoch: Epoch,
) -> BTreeSet<PubKey> {
if let Some(sc) = self.state.read_blocking().get(&epoch) {
if let Some(sc) = self.state.read().unwrap().get(&epoch) {
sc.indexed_da_members.clone().into_keys().collect()
} else {
self.update_stake_table(epoch, self.l1_client.stake_table(&epoch))
Expand All @@ -373,7 +373,8 @@ impl Membership<SeqTypes> for EpochCommittees {
epoch: Epoch,
) -> BTreeSet<PubKey> {
self.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.unwrap()
.eligible_leaders
Expand All @@ -386,7 +387,8 @@ impl Membership<SeqTypes> for EpochCommittees {
fn stake(&self, pub_key: &PubKey, epoch: Epoch) -> Option<StakeTableEntry<PubKey>> {
// Only return the stake if it is above zero
self.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.and_then(|h| h.indexed_stake_table.get(pub_key).cloned())
}
Expand All @@ -395,15 +397,17 @@ impl Membership<SeqTypes> for EpochCommittees {
fn da_stake(&self, pub_key: &PubKey, epoch: Epoch) -> Option<StakeTableEntry<PubKey>> {
// Only return the stake if it is above zero
self.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.and_then(|h| h.indexed_da_members.get(pub_key).cloned())
}

/// Check if a node has stake in the committee
fn has_stake(&self, pub_key: &PubKey, epoch: Epoch) -> bool {
self.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.and_then(|h| h.indexed_stake_table.get(pub_key))
.map_or(false, |x| x.stake() > U256::zero())
Expand All @@ -412,7 +416,8 @@ impl Membership<SeqTypes> for EpochCommittees {
/// Check if a node has stake in the committee
fn has_da_stake(&self, pub_key: &PubKey, epoch: Epoch) -> bool {
self.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.and_then(|h| h.indexed_da_members.get(pub_key))
.map_or(false, |x| x.stake() > U256::zero())
Expand All @@ -426,7 +431,8 @@ impl Membership<SeqTypes> for EpochCommittees {
) -> Result<PubKey, Self::Error> {
let leaders = self
.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.ok_or(LeaderLookupError)?
.eligible_leaders
Expand All @@ -440,7 +446,8 @@ impl Membership<SeqTypes> for EpochCommittees {
/// Get the total number of nodes in the committee
fn total_nodes(&self, epoch: Epoch) -> usize {
self.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.map(|sc| sc.indexed_stake_table.len())
.unwrap_or_default()
Expand All @@ -449,7 +456,8 @@ impl Membership<SeqTypes> for EpochCommittees {
/// Get the total number of DA nodes in the committee
fn da_total_nodes(&self, epoch: Epoch) -> usize {
self.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.map(|sc: &Committee| sc.indexed_da_members.len())
.unwrap_or_default()
Expand All @@ -459,7 +467,8 @@ impl Membership<SeqTypes> for EpochCommittees {
fn success_threshold(&self, epoch: Epoch) -> NonZeroU64 {
let quorum = self
.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.unwrap()
.indexed_stake_table
Expand All @@ -471,7 +480,8 @@ impl Membership<SeqTypes> for EpochCommittees {
fn da_success_threshold(&self, epoch: Epoch) -> NonZeroU64 {
let da = self
.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.unwrap()
.indexed_da_members
Expand All @@ -483,7 +493,8 @@ impl Membership<SeqTypes> for EpochCommittees {
fn failure_threshold(&self, epoch: Epoch) -> NonZeroU64 {
let quorum = self
.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.unwrap()
.indexed_stake_table
Expand All @@ -496,7 +507,8 @@ impl Membership<SeqTypes> for EpochCommittees {
fn upgrade_threshold(&self, epoch: Epoch) -> NonZeroU64 {
let quorum = self
.state
.read_blocking()
.read()
.unwrap()
.get(&epoch)
.unwrap()
.indexed_stake_table
Expand Down
11 changes: 9 additions & 2 deletions types/src/v0/v0_1/l1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ use hotshot_types::{
};
use lru::LruCache;
use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, num::NonZeroUsize, sync::Arc, time::Duration};
use std::{
collections::BTreeMap,
num::NonZeroUsize,
sync::{self, Arc},
time::Duration,
};
use tokio::{
sync::{Mutex, RwLock},
task::JoinHandle,
Expand Down Expand Up @@ -111,6 +116,9 @@ pub struct L1Client {
pub(crate) events_max_block_range: u64,
/// Shared state updated by an asynchronous task which polls the L1.
pub(crate) state: Arc<Mutex<L1State>>,
/// TODO: We need to be able to take out sync locks on this part of the
/// state. until the trait definition of Membership is updated in HotShot.
pub(crate) stake_table_state: sync::Arc<sync::RwLock<BTreeMap<EpochNumber, StakeTables>>>,
/// Channel used by the async update task to send events to clients.
pub(crate) sender: Sender<L1Event>,
/// Receiver for events from the async update task.
Expand Down Expand Up @@ -140,7 +148,6 @@ pub(crate) enum RpcClient {
pub(crate) struct L1State {
pub(crate) snapshot: L1Snapshot,
pub(crate) finalized: LruCache<u64, L1BlockInfo>,
pub(crate) stake_tables: BTreeMap<EpochNumber, StakeTables>,
}

#[derive(Clone, Debug)]
Expand Down

0 comments on commit 5c32108

Please sign in to comment.