Skip to content

Commit

Permalink
refactor: get rid of query pool locks
Browse files Browse the repository at this point in the history
  • Loading branch information
carver committed Sep 21, 2024
1 parent a45e272 commit bf3d383
Showing 1 changed file with 42 additions and 57 deletions.
99 changes: 42 additions & 57 deletions portalnet/src/overlay/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ where
/// A map of active outgoing requests.
active_outgoing_requests: Arc<RwLock<HashMap<OverlayRequestId, ActiveOutgoingRequest>>>,
/// A query pool that manages find node queries.
find_node_query_pool: Arc<RwLock<QueryPool<NodeId, FindNodeQuery<NodeId>, TContentKey>>>,
find_node_query_pool: QueryPool<NodeId, FindNodeQuery<NodeId>, TContentKey>,
/// A query pool that manages find content queries.
find_content_query_pool: Arc<RwLock<QueryPool<NodeId, FindContentQuery<NodeId>, TContentKey>>>,
find_content_query_pool: QueryPool<NodeId, FindContentQuery<NodeId>, TContentKey>,
/// Timeout after which a peer in an ongoing query is marked unresponsive.
query_peer_timeout: Duration,
/// Number of peers to request data from in parallel for a single query.
Expand Down Expand Up @@ -220,8 +220,8 @@ where
command_rx,
command_tx: internal_command_tx,
active_outgoing_requests: Arc::new(RwLock::new(HashMap::new())),
find_node_query_pool: Arc::new(RwLock::new(QueryPool::new(query_timeout))),
find_content_query_pool: Arc::new(RwLock::new(QueryPool::new(query_timeout))),
find_node_query_pool: QueryPool::new(query_timeout),
find_content_query_pool: QueryPool::new(query_timeout),
query_peer_timeout,
query_parallelism,
query_num_results,
Expand Down Expand Up @@ -400,11 +400,11 @@ where
self.peers_to_ping.insert(node_id);
}
}
query_event = OverlayService::<TContentKey, TMetric, TValidator, TStore>::query_event_poll(self.find_node_query_pool.clone()) => {
query_event = OverlayService::<TContentKey, TMetric, TValidator, TStore>::query_event_poll(&mut self.find_node_query_pool) => {
self.handle_find_nodes_query_event(query_event);
}
// Handle query events for queries in the find content query pool.
query_event = OverlayService::<TContentKey, TMetric, TValidator, TStore>::query_event_poll(self.find_content_query_pool.clone()) => {
query_event = OverlayService::<TContentKey, TMetric, TValidator, TStore>::query_event_poll(&mut self.find_content_query_pool) => {
self.handle_find_content_query_event(query_event);
}
_ = OverlayService::<TContentKey, TMetric, TValidator, TStore>::bucket_maintenance_poll(self.protocol, &self.kbuckets) => {}
Expand Down Expand Up @@ -489,9 +489,9 @@ where
/// This happens when a query needs to send a request to a node, when a query has completed,
/// or when a query has timed out.
async fn query_event_poll<TQuery: Query<NodeId>>(
queries: Arc<RwLock<QueryPool<NodeId, TQuery, TContentKey>>>,
queries: &mut QueryPool<NodeId, TQuery, TContentKey>,
) -> QueryEvent<TQuery, TContentKey> {
future::poll_fn(move |_cx| match queries.write().poll() {
future::poll_fn(move |_cx| match queries.poll() {
QueryPoolState::Validating(query_info, query, sending_peer) => {
// This only happens during a FindContent query.
let content_key = match &query_info.query_type {
Expand All @@ -509,7 +509,8 @@ where
let query_result = query.pending_validation_result(sending_peer);
Poll::Ready(QueryEvent::Validating(content_key, query_result))
}
// Note that some query pools skip over Validating, straight to Finished (like the FindNode query pool)
// Note that some query pools skip over Validating, straight to Finished (like the
// FindNode query pool)
QueryPoolState::Finished(query_id, query_info, query) => {
Poll::Ready(QueryEvent::Finished(query_id, query_info, query))
}
Expand All @@ -532,8 +533,7 @@ where
}

QueryPoolState::Waiting(None) | QueryPoolState::Idle => Poll::Pending,
})
.await
}).await
}

/// Handles a `QueryEvent` from a poll on the find nodes query pool.
Expand Down Expand Up @@ -561,7 +561,7 @@ where
query.id = %query_id,
"Cannot query peer with unknown ENR",
);
if let Some((_, query)) = self.find_node_query_pool.write().get_mut(query_id) {
if let Some((_, query)) = self.find_node_query_pool.get_mut(query_id) {
query.on_failure(&node_id);
}
}
Expand Down Expand Up @@ -645,7 +645,7 @@ where
query.id = %query_id,
"Cannot query peer with unknown ENR"
);
if let Some((_, query)) = self.find_node_query_pool.write().get_mut(query_id) {
if let Some((_, query)) = self.find_node_query_pool.get_mut(query_id) {
query.on_failure(&node_id);
}
}
Expand Down Expand Up @@ -2120,7 +2120,7 @@ where
// Check whether this request was sent on behalf of a query.
// If so, advance the query with the returned data.
let local_node_id = self.local_enr().node_id();
if let Some((query_info, query)) = self.find_node_query_pool.write().get_mut(query_id) {
if let Some((query_info, query)) = self.find_node_query_pool.get_mut(query_id) {
for enr_ref in enrs.iter() {
if !query_info
.untrusted_enrs
Expand All @@ -2145,7 +2145,7 @@ where
enrs: Vec<Enr>,
) {
let local_node_id = self.local_enr().node_id();
if let Some((query_info, query)) = self.find_content_query_pool.write().get_mut(*query_id) {
if let Some((query_info, query)) = self.find_content_query_pool.get_mut(*query_id) {
// If an ENR is not present in the query's untrusted ENRs, then add the ENR.
// Ignore the local node's ENR.
let mut new_enrs: Vec<&Enr> = vec![];
Expand Down Expand Up @@ -2185,7 +2185,7 @@ where
source: Enr,
utp: u16,
) {
if let Some((query_info, query)) = self.find_content_query_pool.write().get_mut(*query_id) {
if let Some((query_info, query)) = self.find_content_query_pool.get_mut(*query_id) {
if let Some(trace) = &mut query_info.trace {
trace.node_responded_with_content(&source);
}
Expand All @@ -2204,7 +2204,7 @@ where
source: Enr,
content: Vec<u8>,
) {
let mut pool = self.find_content_query_pool.write();
let pool = &mut self.find_content_query_pool;
if let Some((query_info, query)) = pool.get_mut(*query_id) {
if let Some(trace) = &mut query_info.trace {
trace.node_responded_with_content(&source);
Expand Down Expand Up @@ -2502,7 +2502,6 @@ where
FindNodeQuery::with_config(query_config, query_info.key(), known_closest_peers);
Some(
self.find_node_query_pool
.write()
.add_query(query_info, find_nodes_query),
)
}
Expand Down Expand Up @@ -2566,11 +2565,7 @@ where
};

let query = FindContentQuery::with_config(query_config, target_key, closest_nodes);
Some(
self.find_content_query_pool
.write()
.add_query(query_info, query),
)
Some(self.find_content_query_pool.add_query(query_info, query))
}

/// Returns an ENR if one is known for the given NodeId.
Expand All @@ -2594,7 +2589,7 @@ where
}

// Check the existing find node queries for the ENR.
for (query_info, _) in self.find_node_query_pool.read().iter() {
for (query_info, _) in self.find_node_query_pool.iter() {
if let Some(enr) = query_info
.untrusted_enrs
.iter()
Expand All @@ -2605,7 +2600,7 @@ where
}

// Check the existing find content queries for the ENR.
for (query_info, _) in self.find_content_query_pool.read().iter() {
for (query_info, _) in self.find_content_query_pool.iter() {
if let Some(enr) = query_info
.untrusted_enrs
.iter()
Expand Down Expand Up @@ -2840,12 +2835,8 @@ mod tests {
command_tx,
command_rx,
active_outgoing_requests,
find_node_query_pool: Arc::new(RwLock::new(QueryPool::new(
overlay_config.query_timeout,
))),
find_content_query_pool: Arc::new(RwLock::new(QueryPool::new(
overlay_config.query_timeout,
))),
find_node_query_pool: QueryPool::new(overlay_config.query_timeout),
find_content_query_pool: QueryPool::new(overlay_config.query_timeout),
query_peer_timeout: overlay_config.query_peer_timeout,
query_parallelism: overlay_config.query_parallelism,
query_num_results: overlay_config.query_num_results,
Expand Down Expand Up @@ -3539,15 +3530,16 @@ mod tests {
let (_, target_enr) = generate_random_remote_enr();
let target_node_id = target_enr.node_id();

assert_eq!(service.find_node_query_pool.read().iter().count(), 0);
assert_eq!(service.find_node_query_pool.iter().count(), 0);

service.add_bootnodes(bootnodes, true);

// Initialize the query and call `poll` so that it starts
service.init_find_nodes_query(&target_node_id, None);
let _ = service.find_node_query_pool.write().poll();
let _ = service.find_node_query_pool.poll();

let pool = service.find_node_query_pool.read();
let expected_distances_per_peer = service.findnodes_query_distances_per_peer;
let pool = &mut service.find_node_query_pool;
let (query_info, query) = pool.iter().next().unwrap();

assert!(query_info.untrusted_enrs.contains(&bootnode1));
Expand All @@ -3559,10 +3551,7 @@ mod tests {
..
} => {
assert_eq!(target, target_node_id);
assert_eq!(
distances_to_request,
service.findnodes_query_distances_per_peer
);
assert_eq!(distances_to_request, expected_distances_per_peer);
}
_ => panic!("Unexpected query type"),
}
Expand Down Expand Up @@ -3591,7 +3580,7 @@ mod tests {
XorMetric,
MockValidator,
MemoryContentStore,
>::query_event_poll(service.find_node_query_pool.clone())
>::query_event_poll(&mut service.find_node_query_pool)
.await;
match event {
QueryEvent::Waiting(query_id, node_id, request) => {
Expand Down Expand Up @@ -3627,7 +3616,7 @@ mod tests {
XorMetric,
MockValidator,
MemoryContentStore,
>::query_event_poll(service.find_node_query_pool.clone())
>::query_event_poll(&mut service.find_node_query_pool)
.await;

// Check that the request is being sent to either node 1 or node 2. Keep track of which.
Expand All @@ -3644,7 +3633,7 @@ mod tests {
XorMetric,
MockValidator,
MemoryContentStore,
>::query_event_poll(service.find_node_query_pool.clone())
>::query_event_poll(&mut service.find_node_query_pool)
.await;

// Check that a request is being sent to the other node.
Expand All @@ -3668,7 +3657,7 @@ mod tests {
XorMetric,
MockValidator,
MemoryContentStore,
>::query_event_poll(service.find_node_query_pool.clone())
>::query_event_poll(&mut service.find_node_query_pool)
.await;

match event {
Expand Down Expand Up @@ -3711,7 +3700,7 @@ mod tests {
XorMetric,
MockValidator,
MemoryContentStore,
>::query_event_poll(service.find_node_query_pool.clone())
>::query_event_poll(&mut service.find_node_query_pool)
.await;

let (_, enr1) = generate_random_remote_enr();
Expand Down Expand Up @@ -3780,8 +3769,7 @@ mod tests {
let query_id = service.init_find_content_query(target_content_key.clone(), None, false);
let query_id = query_id.expect("Query ID for new find content query is `None`");

let pool = service.find_content_query_pool.clone();
let mut pool = pool.write();
let pool = &mut service.find_content_query_pool;
let (query_info, query) = pool
.get_mut(query_id)
.expect("Query pool does not contain query");
Expand Down Expand Up @@ -3850,8 +3838,7 @@ mod tests {

// update query in own span so mut ref is dropped after poll
{
let pool = service.find_content_query_pool.clone();
let mut pool = pool.write();
let pool = &mut service.find_content_query_pool;
let (_, query) = pool
.get_mut(query_id)
.expect("Query pool does not contain query");
Expand All @@ -3868,7 +3855,7 @@ mod tests {
vec![enr.clone()],
);

let mut pool = service.find_content_query_pool.write();
let pool = &mut service.find_content_query_pool;
let (query_info, query) = pool
.get_mut(query_id)
.expect("Query pool does not contain query");
Expand Down Expand Up @@ -3915,8 +3902,7 @@ mod tests {

// update query in own span so mut ref is dropped after poll
{
let pool = service.find_content_query_pool.clone();
let mut pool = pool.write();
let pool = &mut service.find_content_query_pool;
let (_, query) = pool
.get_mut(query_id)
.expect("Query pool does not contain query");
Expand All @@ -3930,7 +3916,7 @@ mod tests {
let content: Vec<u8> = vec![0, 1, 2, 3];
service.advance_find_content_query_with_content(&query_id, bootnode_enr, content.clone());

let mut pool = service.find_content_query_pool.write();
let pool = &mut service.find_content_query_pool;
let (_, query) = pool
.get_mut(query_id)
.expect("Query pool does not contain query");
Expand Down Expand Up @@ -3982,8 +3968,7 @@ mod tests {

// update query in own span so mut ref is dropped after poll
{
let pool = service.find_content_query_pool.clone();
let mut pool = pool.write();
let pool = &mut service.find_content_query_pool;
let (_, query) = pool
.get_mut(query_id)
.expect("Query pool does not contain query");
Expand All @@ -4000,7 +3985,7 @@ mod tests {
actual_connection_id,
);

let mut pool = service.find_content_query_pool.write();
let pool = &mut service.find_content_query_pool;
let (_, query) = pool
.get_mut(query_id)
.expect("Query pool does not contain query");
Expand Down Expand Up @@ -4054,7 +4039,7 @@ mod tests {

let query_event =
OverlayService::<_, XorMetric, MockValidator, MemoryContentStore>::query_event_poll(
service.find_content_query_pool.clone(),
&mut service.find_content_query_pool,
)
.await;

Expand Down Expand Up @@ -4093,7 +4078,7 @@ mod tests {

let query_event =
OverlayService::<_, XorMetric, MockValidator, MemoryContentStore>::query_event_poll(
service.find_content_query_pool.clone(),
&mut service.find_content_query_pool,
)
.await;

Expand All @@ -4108,7 +4093,7 @@ mod tests {
let polled = timeout(
Duration::from_millis(10),
OverlayService::<_, XorMetric, MockValidator, MemoryContentStore>::query_event_poll(
service.find_content_query_pool.clone(),
&mut service.find_content_query_pool,
),
)
.await;
Expand Down

0 comments on commit bf3d383

Please sign in to comment.