diff --git a/Cargo.lock b/Cargo.lock index a63e49d931..ef3fa7ca76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2437,6 +2437,7 @@ dependencies = [ "iroh-quinn", "iroh-test", "num_cpus", + "once_cell", "parking_lot", "portable-atomic", "postcard", diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index 5130f336c2..5462a5f2ff 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -35,6 +35,7 @@ num_cpus = { version = "1.15.0" } portable-atomic = "1" iroh-docs = { version = "0.18.0", path = "../iroh-docs" } iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" } +once_cell = "1.18.0" parking_lot = "0.12.1" postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] } quic-rpc = { version = "0.10.0", default-features = false, features = ["flume-transport", "quinn-transport"] } diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index 5507e3b098..be68c7e44a 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -62,10 +62,6 @@ struct ExampleProto { } impl Protocol for ExampleProto { - fn as_arc_any(self: Arc) -> Arc { - self - } - fn handle_connection(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { self.handle_connection(conn.await?).await }) } diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 0b4f3d9a0c..36cf4705a9 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -5,8 +5,8 @@ //! To shut down the node, call [`Node::shutdown`]. use std::fmt::Debug; use std::net::SocketAddr; +use std::path::Path; use std::sync::Arc; -use std::{any::Any, path::Path}; use anyhow::{anyhow, Result}; use futures_lite::StreamExt; @@ -16,6 +16,7 @@ use iroh_blobs::store::Store as BaoStore; use iroh_docs::engine::Engine; use iroh_net::util::AbortingJoinHandle; use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint}; +use once_cell::sync::OnceCell; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; use tokio::task::JoinHandle; @@ -23,7 +24,7 @@ use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; -use crate::{client::RpcService, node::builder::ProtocolMap}; +use crate::{client::RpcService, node::protocol::ProtocolMap}; mod builder; mod protocol; @@ -47,7 +48,7 @@ pub use protocol::Protocol; #[derive(Debug, Clone)] pub struct Node { inner: Arc>, - task: Arc>, + task: Arc>>, client: crate::client::MemIroh, protocols: ProtocolMap, } @@ -155,11 +156,7 @@ impl Node { /// Returns the protocol handler for a alpn. pub fn get_protocol(&self, alpn: &[u8]) -> Option> { - let protocols = self.protocols.read().unwrap(); - let protocol: Arc = protocols.get(alpn)?.clone(); - let protocol_any: Arc = protocol.as_arc_any(); - let protocol_ref = Arc::downcast(protocol_any).ok()?; - Some(protocol_ref) + self.protocols.get(alpn) } /// Aborts the node. @@ -173,7 +170,8 @@ impl Node { pub async fn shutdown(self) -> Result<()> { self.inner.cancel_token.cancel(); - if let Ok(task) = Arc::try_unwrap(self.task) { + if let Ok(mut task) = Arc::try_unwrap(self.task) { + let task = task.take().expect("cannot be empty"); task.await?; } Ok(()) diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 941ff8915f..34724b64e1 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -1,8 +1,8 @@ use std::{ - collections::{BTreeSet, HashMap}, + collections::BTreeSet, net::{Ipv4Addr, SocketAddrV4}, path::{Path, PathBuf}, - sync::{Arc, RwLock}, + sync::Arc, time::Duration, }; @@ -28,13 +28,12 @@ use quic_rpc::{ RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; -use tokio::sync::oneshot; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ client::RPC_ALPN, - node::Protocol, + node::{protocol::ProtocolMap, Protocol}, rpc_protocol::RpcService, util::{fs::load_secret_key, path::IrohPaths}, }; @@ -56,7 +55,6 @@ const DEFAULT_GC_INTERVAL: Duration = Duration::from_secs(60 * 5); const MAX_CONNECTIONS: u32 = 1024; const MAX_STREAMS: u64 = 10; -pub(super) type ProtocolMap = Arc>>>; type ProtocolBuilders = Vec<( &'static [u8], Box) -> Arc + Send + 'static>, @@ -511,7 +509,7 @@ where let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); - let protocols = Arc::new(RwLock::new(HashMap::new())); + let protocols = ProtocolMap::default(); let inner = Arc::new(NodeInner { db: self.blobs_store, @@ -524,9 +522,21 @@ where sync, downloader, }); - let (ready_tx, ready_rx) = oneshot::channel(); + + let node = Node { + inner: inner.clone(), + task: Default::default(), + client, + protocols: protocols.clone(), + }; + + for (alpn, p) in self.protocols { + let protocol = p(node.clone()); + protocols.insert(alpn, protocol); + } + let task = { - let protocols = Arc::clone(&protocols); + let protocols = protocols.clone(); let gossip = gossip.clone(); let handler = rpc::Handler { inner: inner.clone(), @@ -535,8 +545,6 @@ where let ep = endpoint.clone(); tokio::task::spawn( async move { - // Wait until the protocol builders have run. - ready_rx.await.expect("cannot fail"); Self::run( ep, protocols, @@ -551,20 +559,7 @@ where ) }; - let node = Node { - inner, - task: Arc::new(task), - client, - protocols, - }; - - for (alpn, p) in self.protocols { - let protocol = p(node.clone()); - node.protocols.write().unwrap().insert(alpn, protocol); - } - - // Notify the run task that the protocols are now built. - ready_tx.send(()).expect("cannot fail"); + node.task.set(task).expect("was empty"); // spawn a task that updates the gossip endpoints. // TODO: track task @@ -802,10 +797,7 @@ async fn handle_connection( .await } alpn => { - let protocol = { - let protocols = protocols.read().unwrap(); - protocols.get(alpn).cloned() - }; + let protocol = protocols.get_any(alpn); if let Some(protocol) = protocol { protocol.handle_connection(connecting).await?; } else { diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index dd9db9d84c..0bba70cc73 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -1,18 +1,58 @@ -use std::{any::Any, fmt, sync::Arc}; +use std::{ + any::Any, + collections::HashMap, + fmt, + sync::{Arc, RwLock}, +}; use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use iroh_net::endpoint::Connecting; /// Trait for iroh protocol handlers. -pub trait Protocol: Send + Sync + Any + fmt::Debug + 'static { - /// Return `self` as `dyn Any`. - /// - /// Implementations can simply return `self` here. - fn as_arc_any(self: Arc) -> Arc; - +pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { /// Handle an incoming connection. /// /// This runs on a freshly spawned tokio task so this can be long-running. fn handle_connection(self: Arc, conn: Connecting) -> BoxedFuture>; } + +/// Helper trait to facilite casting from `Arc` to `Arc`. +/// +/// This trait has a blanket implementation so there is no need to implement this yourself. +pub trait IntoArcAny { + fn into_arc_any(self: Arc) -> Arc; +} + +impl IntoArcAny for T { + fn into_arc_any(self: Arc) -> Arc { + self + } +} + +/// Map of registered protocol handlers. +#[allow(clippy::type_complexity)] +#[derive(Debug, Clone, Default)] +pub struct ProtocolMap(Arc>>>); + +impl ProtocolMap { + /// Returns the registered protocol handler for an ALPN as a concrete type. + pub fn get(&self, alpn: &[u8]) -> Option> { + let protocols = self.0.read().unwrap(); + let protocol: Arc = protocols.get(alpn)?.clone(); + let protocol_any: Arc = protocol.into_arc_any(); + let protocol_ref = Arc::downcast(protocol_any).ok()?; + Some(protocol_ref) + } + + /// Returns the registered protocol handler for an ALPN as a `dyn Protocol`. + pub fn get_any(&self, alpn: &[u8]) -> Option> { + let protocols = self.0.read().unwrap(); + let protocol: Arc = protocols.get(alpn)?.clone(); + Some(protocol) + } + + pub(super) fn insert(&self, alpn: &'static [u8], protocol: Arc) { + self.0.write().unwrap().insert(alpn, protocol); + } +}