diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 3b9173c706..0b4f3d9a0c 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -5,8 +5,8 @@ //! To shut down the node, call [`Node::shutdown`]. use std::fmt::Debug; use std::net::SocketAddr; -use std::path::Path; use std::sync::Arc; +use std::{any::Any, path::Path}; use anyhow::{anyhow, Result}; use futures_lite::StreamExt; @@ -23,14 +23,16 @@ use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; -use crate::client::RpcService; +use crate::{client::RpcService, node::builder::ProtocolMap}; mod builder; +mod protocol; mod rpc; mod rpc_status; pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig}; pub use self::rpc_status::RpcStatus; +pub use protocol::Protocol; /// A server which implements the iroh node. /// @@ -47,6 +49,7 @@ pub struct Node { inner: Arc>, task: Arc>, client: crate::client::MemIroh, + protocols: ProtocolMap, } #[derive(derive_more::Debug)] @@ -150,6 +153,15 @@ impl Node { self.inner.endpoint.my_relay() } + /// Returns the protocol handler for a alpn. + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + let protocols = self.protocols.read().unwrap(); + let protocol: Arc = protocols.get(alpn)?.clone(); + let protocol_any: Arc = protocol.as_arc_any(); + let protocol_ref = Arc::downcast(protocol_any).ok()?; + Some(protocol_ref) + } + /// Aborts the node. /// /// This does not gracefully terminate currently: all connections are closed and diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index db935479f2..2e1f38ed25 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -1,8 +1,8 @@ use std::{ - collections::BTreeSet, + collections::{BTreeSet, HashMap}, net::{Ipv4Addr, SocketAddrV4}, path::{Path, PathBuf}, - sync::Arc, + sync::{Arc, RwLock}, time::Duration, }; @@ -28,11 +28,13 @@ use quic_rpc::{ RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; +use tokio::sync::oneshot; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ client::RPC_ALPN, + node::Protocol, rpc_protocol::RpcService, util::{fs::load_secret_key, path::IrohPaths}, }; @@ -54,6 +56,9 @@ const DEFAULT_GC_INTERVAL: Duration = Duration::from_secs(60 * 5); const MAX_CONNECTIONS: u32 = 1024; const MAX_STREAMS: u64 = 10; +pub(super) type ProtocolMap = Arc>>>; +type ProtocolBuilders = Vec<(&'static [u8], Box) -> Arc>)>; + /// Builder for the [`Node`]. /// /// You must supply a blob store and a document store. @@ -84,6 +89,7 @@ where dns_resolver: Option, node_discovery: DiscoveryConfig, docs_store: iroh_docs::store::fs::Store, + protocols: ProtocolBuilders, #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, /// Callback to register when a gc loop is done @@ -133,6 +139,7 @@ impl Default for Builder { rpc_endpoint: Default::default(), gc_policy: GcPolicy::Disabled, docs_store: iroh_docs::store::Store::memory(), + protocols: Default::default(), node_discovery: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: false, @@ -160,6 +167,7 @@ impl Builder { gc_policy: GcPolicy::Disabled, docs_store, node_discovery: Default::default(), + protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: false, gc_done_callback: None, @@ -223,6 +231,7 @@ where gc_policy: self.gc_policy, docs_store, node_discovery: self.node_discovery, + protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: false, gc_done_callback: self.gc_done_callback, @@ -244,6 +253,7 @@ where gc_policy: self.gc_policy, docs_store: self.docs_store, node_discovery: self.node_discovery, + protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, gc_done_callback: self.gc_done_callback, @@ -270,6 +280,7 @@ where gc_policy: self.gc_policy, docs_store: self.docs_store, node_discovery: self.node_discovery, + protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, gc_done_callback: self.gc_done_callback, @@ -343,6 +354,16 @@ where self } + /// Accept a custom protocol. + pub fn accept( + mut self, + alpn: &'static [u8], + protocol: impl FnOnce(Node) -> Arc + 'static, + ) -> Self { + self.protocols.push((alpn, Box::new(protocol))); + self + } + /// Register a callback for when GC is done. #[cfg(any(test, feature = "test-utils"))] pub fn register_gc_done_cb(mut self, cb: Box) -> Self { @@ -481,6 +502,8 @@ where let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); + let protocols = Arc::new(RwLock::new(HashMap::new())); + let inner = Arc::new(NodeInner { db: self.blobs_store, endpoint: endpoint.clone(), @@ -492,7 +515,9 @@ where sync, downloader, }); + let (ready_tx, ready_rx) = oneshot::channel(); let task = { + let protocols = Arc::clone(&protocols); let gossip = gossip.clone(); let handler = rpc::Handler { inner: inner.clone(), @@ -501,8 +526,11 @@ where let ep = endpoint.clone(); tokio::task::spawn( async move { + // Wait until the protocol builders have run. + ready_rx.await.expect("cannot fail"); Self::run( ep, + protocols, handler, self.rpc_endpoint, internal_rpc, @@ -518,8 +546,17 @@ where inner, task: Arc::new(task), client, + protocols, }; + for (alpn, p) in self.protocols { + let protocol = p(node.clone()); + node.protocols.write().unwrap().insert(alpn, protocol); + } + + // Notify the run task that the protocols are now built. + ready_tx.send(()).expect("cannot fail"); + // spawn a task that updates the gossip endpoints. // TODO: track task let mut stream = endpoint.local_endpoints(); @@ -545,6 +582,7 @@ where #[allow(clippy::too_many_arguments)] async fn run( server: Endpoint, + protocols: ProtocolMap, handler: rpc::Handler, rpc: E, internal_rpc: impl ServiceEndpoint, @@ -615,8 +653,9 @@ where let gossip = gossip.clone(); let inner = handler.inner.clone(); let sync = handler.inner.sync.clone(); + let protocols = protocols.clone(); tokio::task::spawn(async move { - if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync).await { + if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync, protocols).await { warn!("Handling incoming connection ended with error: {err}"); } }); @@ -738,6 +777,7 @@ async fn handle_connection( node: Arc>, gossip: Gossip, sync: DocsEngine, + protocols: ProtocolMap, ) -> Result<()> { match alpn.as_bytes() { GOSSIP_ALPN => gossip.handle_connection(connecting.await?).await?, @@ -752,7 +792,19 @@ async fn handle_connection( ) .await } - _ => bail!("ignoring connection: unsupported ALPN protocol"), + alpn => { + let protocol = { + let protocols = protocols.read().unwrap(); + protocols.get(alpn).cloned() + }; + if let Some(protocol) = protocol { + drop(protocols); + let connection = connecting.await?; + protocol.accept(connection).await?; + } else { + bail!("ignoring connection: unsupported ALPN protocol"); + } + } } Ok(()) } diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs new file mode 100644 index 0000000000..4dc7dbb29d --- /dev/null +++ b/iroh/src/node/protocol.rs @@ -0,0 +1,19 @@ +use std::{any::Any, fmt, future::Future, pin::Pin, sync::Arc}; + +use iroh_net::endpoint::Connection; + +/// Trait for iroh protocol handlers. +pub trait Protocol: Sync + Send + Any + fmt::Debug + 'static { + /// Return `self` as `dyn Any`. + /// + /// Implementations can simply return `self` here. + fn as_arc_any(self: Arc) -> Arc; + + /// Accept an incoming connection. + /// + /// This runs on a freshly spawned tokio task so this can be long-running. + fn accept( + &self, + conn: Connection, + ) -> Pin> + 'static + Send + Sync>>; +}