diff --git a/Cargo.lock b/Cargo.lock index a63e49d9312..ef3fa7ca764 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2437,6 +2437,7 @@ dependencies = [ "iroh-quinn", "iroh-test", "num_cpus", + "once_cell", "parking_lot", "portable-atomic", "postcard", diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index 5130f336c26..a8b92488f72 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -32,6 +32,7 @@ iroh-io = { version = "0.6.0", features = ["stats"] } iroh-metrics = { version = "0.18.0", path = "../iroh-metrics", optional = true } iroh-net = { version = "0.18.0", path = "../iroh-net" } num_cpus = { version = "1.15.0" } +once_cell = "1.17.0" portable-atomic = "1" iroh-docs = { version = "0.18.0", path = "../iroh-docs" } iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" } diff --git a/iroh/src/node.rs b/iroh/src/node.rs index fc7b9a7522a..669f7ff992b 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -14,8 +14,8 @@ use iroh_base::key::PublicKey; use iroh_blobs::downloader::Downloader; use iroh_blobs::store::Store as BaoStore; use iroh_docs::engine::Engine; - use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint}; +use once_cell::sync::OnceCell; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; use tokio::task::{JoinHandle, JoinSet}; @@ -47,7 +47,7 @@ pub use protocol::Protocol; #[derive(Debug, Clone)] pub struct Node { inner: Arc>, - task: Arc>, + task: Arc>>, client: crate::client::MemIroh, protocols: ProtocolMap, tasks: Arc>>>, @@ -174,8 +174,8 @@ impl Node { pub async fn shutdown(self) -> Result<()> { self.inner.cancel_token.cancel(); - if let Ok(task) = Arc::try_unwrap(self.task) { - task.await?; + if let Ok(mut task) = Arc::try_unwrap(self.task) { + task.take().expect("cannot be empty").await?; } if let Some(mut tasks) = self.tasks.lock().unwrap().take() { tasks.abort_all(); diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 10538fb4ee0..effa32e19bc 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -28,7 +28,7 @@ use quic_rpc::{ RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; -use tokio::{sync::oneshot, task::JoinSet}; +use tokio::task::JoinSet; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; @@ -553,24 +553,9 @@ where rt: lp.clone(), downloader, }); - let (ready_tx, ready_rx) = oneshot::channel(); - let task = { - let protocols = protocols.clone(); - let me = endpoint.node_id().fmt_short(); - let inner = inner.clone(); - tokio::task::spawn( - async move { - // Wait until the protocol builders have run. - ready_rx.await.expect("cannot fail"); - Self::run(inner, protocols, self.rpc_endpoint, internal_rpc).await - } - .instrument(error_span!("node", %me)), - ) - }; - let node = Node { - inner, - task: Arc::new(task), + inner: inner.clone(), + task: Default::default(), client, protocols: protocols.clone(), tasks: Default::default(), @@ -581,6 +566,17 @@ where protocols.insert(alpn, protocol); } + let task = { + let protocols = protocols.clone(); + let me = endpoint.node_id().fmt_short(); + let inner = inner.clone(); + tokio::task::spawn( + async move { Self::run(inner, protocols, self.rpc_endpoint, internal_rpc).await } + .instrument(error_span!("node", %me)), + ) + }; + node.task.set(task).expect("was empty"); + let sync = protocols .get::(DOCS_ALPN) .context("docs engine not registered")?; @@ -593,9 +589,6 @@ where tasks.spawn_local(Self::gc_loop(db, sync, gc_period, gc_done_callback)); } - // Notify the run task that the protocols are now built. - ready_tx.send(()).expect("cannot fail"); - // spawn a task that updates the gossip endpoints. let mut stream = endpoint.local_endpoints(); let gossip = protocols.get::(GOSSIP_ALPN);