Skip to content

Commit

Permalink
improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Frando committed Jun 12, 2024
1 parent a369048 commit 63fbcc0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 21 deletions.
16 changes: 7 additions & 9 deletions iroh/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ pub use protocol::Protocol;
#[derive(Debug, Clone)]
pub struct Node<D> {
inner: Arc<NodeInner<D>>,
task: Arc<OnceCell<JoinHandle<()>>>,
client: crate::client::MemIroh,
protocols: ProtocolMap,
}

#[derive(derive_more::Debug)]
Expand All @@ -62,7 +60,9 @@ struct NodeInner<D> {
#[debug("rt")]
rt: LocalPoolHandle,
downloader: Downloader,
tasks: Mutex<Option<JoinSet<()>>>,
task: OnceCell<JoinHandle<()>>,
protocols: ProtocolMap,
tasks: Mutex<JoinSet<()>>,
}

/// In memory node.
Expand Down Expand Up @@ -153,7 +153,7 @@ impl<D: BaoStore> Node<D> {

/// Returns the protocol handler for a alpn.
pub fn get_protocol<P: Protocol>(&self, alpn: &[u8]) -> Option<Arc<P>> {
self.protocols.get::<P>(alpn)
self.inner.protocols.get::<P>(alpn)
}

fn downloader(&self) -> &Downloader {
Expand All @@ -171,11 +171,9 @@ impl<D: BaoStore> Node<D> {
pub async fn shutdown(self) -> Result<()> {
self.inner.cancel_token.cancel();

if let Ok(mut task) = Arc::try_unwrap(self.task) {
task.take().expect("cannot be empty").await?;
}
if let Some(mut tasks) = self.inner.tasks.lock().unwrap().take() {
tasks.abort_all();
if let Ok(mut inner) = Arc::try_unwrap(self.inner) {
inner.task.take().expect("cannot be empty").await?;
inner.tasks.lock().unwrap().abort_all();
}
Ok(())
}
Expand Down
23 changes: 11 additions & 12 deletions iroh/src/node/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use quic_rpc::{
RpcServer, ServiceEndpoint,
};
use serde::{Deserialize, Serialize};
use tokio::task::JoinSet;
use tokio_util::{sync::CancellationToken, task::LocalPoolHandle};
use tracing::{debug, error, error_span, info, trace, warn, Instrument};

Expand Down Expand Up @@ -541,7 +540,6 @@ where
debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr());

let blobs_store = self.blobs_store.clone();
let mut tasks = JoinSet::new();

// initialize the downloader
let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone());
Expand All @@ -560,13 +558,13 @@ where
tasks: Default::default(),
rt: lp.clone(),
downloader,
task: Default::default(),
protocols: protocols.clone(),
});

let node = Node {
inner: inner.clone(),
task: Default::default(),
client,
protocols: protocols.clone(),
};

for (alpn, p) in self.protocols {
Expand All @@ -583,22 +581,27 @@ where
.instrument(error_span!("node", %me)),
)
};
node.task.set(task).expect("was empty");
node.inner.task.set(task).expect("was empty");

if let GcPolicy::Interval(gc_period) = self.gc_policy {
tracing::info!("Starting GC task with interval {:?}", gc_period);
let db = blobs_store.clone();
let gc_done_callback = self.gc_done_callback.take();
let sync = protocols.get::<DocsEngine>(DOCS_ALPN);

tasks.spawn_local(Self::gc_loop(db, sync, gc_period, gc_done_callback));
node.inner.tasks.lock().unwrap().spawn_local(Self::gc_loop(
db,
sync,
gc_period,
gc_done_callback,
));
}

// spawn a task that updates the gossip endpoints.
let mut stream = endpoint.local_endpoints();
let gossip = protocols.get::<Gossip>(GOSSIP_ALPN);
if let Some(gossip) = gossip {
tasks.spawn(async move {
node.inner.tasks.lock().unwrap().spawn(async move {
while let Some(eps) = stream.next().await {
if let Err(err) = gossip.update_endpoints(&eps) {
warn!("Failed to update gossip endpoints: {err:?}");
Expand All @@ -608,8 +611,6 @@ where
});
}

*(node.inner.tasks.lock().unwrap()) = Some(tasks);

// Wait for a single endpoint update, to make sure
// we found some endpoints
tokio::time::timeout(ENDPOINT_WAIT, endpoint.local_endpoints().next())
Expand Down Expand Up @@ -713,9 +714,7 @@ where
}
};
let protocols = protocols.clone();
let mut tasks_guard = inner.tasks.lock().unwrap();
let tasks = tasks_guard.as_mut().expect("only empty after shutdown");
tasks.spawn(async move {
inner.tasks.lock().unwrap().spawn(async move {
if let Err(err) = handle_connection(connecting, alpn, protocols).await {
warn!("Handling incoming connection ended with error: {err}");
}
Expand Down
1 change: 1 addition & 0 deletions iroh/src/node/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ impl<T: Send + Sync + 'static> IntoArcAny for T {
}

#[derive(Debug, Clone, Default)]
#[allow(clippy::type_complexity)]
pub struct ProtocolMap(Arc<RwLock<HashMap<&'static [u8], Arc<dyn Protocol>>>>);

impl ProtocolMap {
Expand Down

0 comments on commit 63fbcc0

Please sign in to comment.