diff --git a/iroh-blobs/src/store/fs.rs b/iroh-blobs/src/store/fs.rs index 5febe54457..e9e113a603 100644 --- a/iroh-blobs/src/store/fs.rs +++ b/iroh-blobs/src/store/fs.rs @@ -1486,6 +1486,8 @@ impl Actor { let mut msgs = PeekableFlumeReceiver::new(self.state.msgs.clone()); while let Some(msg) = msgs.recv() { if let ActorMessage::Shutdown { tx } = msg { + // Make sure the database is dropped before we send the reply. + drop(self); if let Some(tx) = tx { tx.send(()).ok(); } diff --git a/iroh-blobs/src/store/traits.rs b/iroh-blobs/src/store/traits.rs index e0ec3e6b39..2a91d1c0f3 100644 --- a/iroh-blobs/src/store/traits.rs +++ b/iroh-blobs/src/store/traits.rs @@ -295,7 +295,7 @@ pub trait ReadableStore: Map { } /// The mutable part of a Bao store. -pub trait Store: ReadableStore + MapMut { +pub trait Store: ReadableStore + MapMut + std::fmt::Debug { /// This trait method imports a file from a local path. /// /// `data` is the path to the file. diff --git a/iroh-net/src/endpoint.rs b/iroh-net/src/endpoint.rs index e739c2606e..b741f47178 100644 --- a/iroh-net/src/endpoint.rs +++ b/iroh-net/src/endpoint.rs @@ -125,15 +125,12 @@ impl Builder { } }; let secret_key = self.secret_key.unwrap_or_else(SecretKey::generate); - let mut server_config = make_server_config( - &secret_key, - self.alpn_protocols, - self.transport_config, - self.keylog, - )?; - if let Some(c) = self.concurrent_connections { - server_config.concurrent_connections(c); - } + let static_config = StaticConfig { + transport_config: Arc::new(self.transport_config.unwrap_or_default()), + keylog: self.keylog, + concurrent_connections: self.concurrent_connections, + secret_key: secret_key.clone(), + }; let dns_resolver = self .dns_resolver .unwrap_or_else(|| default_resolver().clone()); @@ -149,7 +146,7 @@ impl Builder { #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, }; - Endpoint::bind(Some(server_config), msock_opts, self.keylog).await + Endpoint::bind(static_config, msock_opts, self.alpn_protocols).await } // # The very common methods everyone basically needs. @@ -296,17 +293,41 @@ impl Builder { } } +/// Configuration for a [`quinn::Endpoint`] that cannot be changed at runtime. +#[derive(Debug)] +struct StaticConfig { + secret_key: SecretKey, + transport_config: Arc, + keylog: bool, + concurrent_connections: Option, +} + +impl StaticConfig { + /// Create a [`quinn::ServerConfig`] with the specified ALPN protocols. + fn create_server_config(&self, alpn_protocols: Vec>) -> Result { + let mut server_config = make_server_config( + &self.secret_key, + alpn_protocols, + self.transport_config.clone(), + self.keylog, + )?; + if let Some(c) = self.concurrent_connections { + server_config.concurrent_connections(c); + } + Ok(server_config) + } +} + /// Creates a [`quinn::ServerConfig`] with the given secret key and limits. pub fn make_server_config( secret_key: &SecretKey, alpn_protocols: Vec>, - transport_config: Option, + transport_config: Arc, keylog: bool, ) -> Result { let tls_server_config = tls::make_server_config(secret_key, alpn_protocols, keylog)?; let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_server_config)); - server_config.transport_config(Arc::new(transport_config.unwrap_or_default())); - + server_config.transport_config(transport_config); Ok(server_config) } @@ -334,12 +355,11 @@ pub fn make_server_config( /// [QUIC]: https://quicwg.org #[derive(Clone, Debug)] pub struct Endpoint { - secret_key: Arc, msock: Handle, endpoint: quinn::Endpoint, rtt_actor: Arc, - keylog: bool, cancel_token: CancellationToken, + static_config: Arc, } impl Endpoint { @@ -359,16 +379,17 @@ impl Endpoint { /// This is for internal use, the public interface is the [`Builder`] obtained from /// [Self::builder]. See the methods on the builder for documentation of the parameters. async fn bind( - server_config: Option, + static_config: StaticConfig, msock_opts: magicsock::Options, - keylog: bool, + initial_alpns: Vec>, ) -> Result { - let secret_key = msock_opts.secret_key.clone(); - let span = info_span!("magic_ep", me = %secret_key.public().fmt_short()); + let span = info_span!("magic_ep", me = %static_config.secret_key.public().fmt_short()); let _guard = span.enter(); let msock = magicsock::MagicSock::spawn(msock_opts).await?; trace!("created magicsock"); + let server_config = static_config.create_server_config(initial_alpns)?; + let mut endpoint_config = quinn::EndpointConfig::default(); // Setting this to false means that quinn will ignore packets that have the QUIC fixed bit // set to 0. The fixed bit is the 3rd bit of the first byte of a packet. @@ -379,22 +400,31 @@ impl Endpoint { let endpoint = quinn::Endpoint::new_with_abstract_socket( endpoint_config, - server_config, + Some(server_config), msock.clone(), Arc::new(quinn::TokioRuntime), )?; trace!("created quinn endpoint"); Ok(Self { - secret_key: Arc::new(secret_key), msock, endpoint, rtt_actor: Arc::new(rtt_actor::RttHandle::new()), - keylog, cancel_token: CancellationToken::new(), + static_config: Arc::new(static_config), }) } + /// Set the list of accepted ALPN protocols. + /// + /// This will only affect new incoming connections. + /// Note that this *overrides* the current list of ALPNs. + pub fn set_alpns(&self, alpns: Vec>) -> Result<()> { + let server_config = self.static_config.create_server_config(alpns)?; + self.endpoint.set_server_config(Some(server_config)); + Ok(()) + } + // # Methods for establishing connectivity. /// Connects to a remote [`Endpoint`]. @@ -480,10 +510,10 @@ impl Endpoint { let client_config = { let alpn_protocols = vec![alpn.to_vec()]; let tls_client_config = tls::make_client_config( - &self.secret_key, + &self.static_config.secret_key, Some(*node_id), alpn_protocols, - self.keylog, + self.static_config.keylog, )?; let mut client_config = quinn::ClientConfig::new(Arc::new(tls_client_config)); let mut transport_config = quinn::TransportConfig::default(); @@ -579,7 +609,7 @@ impl Endpoint { /// Returns the secret_key of this endpoint. pub fn secret_key(&self) -> &SecretKey { - &self.secret_key + &self.static_config.secret_key } /// Returns the node id of this endpoint. @@ -587,7 +617,7 @@ impl Endpoint { /// This ID is the unique addressing information of this node and other peers must know /// it to be able to connect to this node. pub fn node_id(&self) -> NodeId { - self.secret_key.public() + self.static_config.secret_key.public() } /// Returns the current [`NodeAddr`] for this endpoint. diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index 8b19019e91..15404847da 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -102,3 +102,7 @@ required-features = ["examples"] [[example]] name = "client" required-features = ["examples"] + +[[example]] +name = "custom-protocol" +required-features = ["examples"] diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs new file mode 100644 index 0000000000..4a12687725 --- /dev/null +++ b/iroh/examples/custom-protocol.rs @@ -0,0 +1,127 @@ +use std::sync::Arc; + +use anyhow::Result; +use clap::Parser; +use futures_lite::future::Boxed as BoxedFuture; +use iroh::{ + client::MemIroh, + net::{ + endpoint::{get_remote_node_id, Connecting}, + Endpoint, NodeId, + }, + node::ProtocolHandler, +}; +use tracing_subscriber::{prelude::*, EnvFilter}; + +#[derive(Debug, Parser)] +pub struct Cli { + #[clap(subcommand)] + command: Command, +} + +#[derive(Debug, Parser)] +pub enum Command { + Accept, + Connect { node: NodeId }, +} + +#[tokio::main] +async fn main() -> Result<()> { + setup_logging(); + let args = Cli::parse(); + // create a new node + let builder = iroh::node::Node::memory().build().await?; + let proto = ExampleProto::new(builder.client().clone(), builder.endpoint().clone()); + let node = builder + .accept(EXAMPLE_ALPN, Arc::new(proto.clone())) + .spawn() + .await?; + + // print the ticket if this is the accepting side + match args.command { + Command::Accept => { + let node_id = node.node_id(); + println!("node id: {node_id}"); + // wait until ctrl-c + tokio::signal::ctrl_c().await?; + } + Command::Connect { node: node_id } => { + proto.connect(node_id).await?; + } + } + + node.shutdown().await?; + + Ok(()) +} + +const EXAMPLE_ALPN: &[u8] = b"example-proto/0"; + +#[derive(Debug, Clone)] +struct ExampleProto { + client: MemIroh, + endpoint: Endpoint, +} + +impl ProtocolHandler for ExampleProto { + fn accept(self: Arc, connecting: Connecting) -> BoxedFuture> { + Box::pin(async move { + let connection = connecting.await?; + let peer = get_remote_node_id(&connection)?; + println!("accepted connection from {peer}"); + let mut send_stream = connection.open_uni().await?; + // Let's create a new blob for each incoming connection. + // This functions as an example of using existing iroh functionality within a protocol + // (you likely don't want to create a new blob for each connection for real) + let content = format!("this blob is created for my beloved peer {peer} ♥"); + let hash = self + .client + .blobs() + .add_bytes(content.as_bytes().to_vec()) + .await?; + // Send the hash over our custom protocol. + send_stream.write_all(hash.hash.as_bytes()).await?; + send_stream.finish().await?; + println!("closing connection from {peer}"); + Ok(()) + }) + } +} + +impl ExampleProto { + pub fn new(client: MemIroh, endpoint: Endpoint) -> Self { + Self { client, endpoint } + } + + pub async fn connect(&self, remote_node_id: NodeId) -> Result<()> { + println!("our node id: {}", self.endpoint.node_id()); + println!("connecting to {remote_node_id}"); + let conn = self + .endpoint + .connect_by_node_id(&remote_node_id, EXAMPLE_ALPN) + .await?; + let mut recv_stream = conn.accept_uni().await?; + let hash_bytes = recv_stream.read_to_end(32).await?; + let hash = iroh::blobs::Hash::from_bytes(hash_bytes.try_into().unwrap()); + println!("received hash: {hash}"); + self.client + .blobs() + .download(hash, remote_node_id.into()) + .await? + .await?; + println!("blob downloaded"); + let content = self.client.blobs().read_to_bytes(hash).await?; + let message = String::from_utf8(content.to_vec())?; + println!("blob content: {message}"); + Ok(()) + } +} + +/// Set the RUST_LOG env var to one of {debug,info,warn} to see logging. +fn setup_logging() { + tracing_subscriber::registry() + .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr)) + .with(EnvFilter::from_default_env()) + .try_init() + .ok(); +} diff --git a/iroh/src/node.rs b/iroh/src/node.rs index ac1bee9548..ae9a5ddb69 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -14,25 +14,26 @@ use iroh_base::key::PublicKey; use iroh_blobs::downloader::Downloader; use iroh_blobs::store::Store as BaoStore; use iroh_docs::engine::Engine; -use iroh_net::endpoint::DirectAddrsStream; +use iroh_gossip::net::Gossip; use iroh_net::key::SecretKey; -use iroh_net::util::AbortingJoinHandle; use iroh_net::Endpoint; +use iroh_net::{endpoint::DirectAddrsStream, util::SharedAbortingJoinHandle}; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; -use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; -use crate::client::RpcService; +use crate::{client::RpcService, node::protocol::ProtocolMap}; mod builder; +mod protocol; mod rpc; mod rpc_status; pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig}; pub use self::rpc_status::RpcStatus; +pub use protocol::ProtocolHandler; /// A server which implements the iroh node. /// @@ -47,22 +48,22 @@ pub use self::rpc_status::RpcStatus; #[derive(Debug, Clone)] pub struct Node { inner: Arc>, - task: Arc>, client: crate::client::MemIroh, + task: SharedAbortingJoinHandle<()>, + protocols: Arc, } #[derive(derive_more::Debug)] struct NodeInner { db: D, + sync: DocsEngine, endpoint: Endpoint, + gossip: Gossip, secret_key: SecretKey, cancel_token: CancellationToken, controller: FlumeConnection, - #[allow(dead_code)] - gc_task: Option>, #[debug("rt")] rt: LocalPoolHandle, - pub(crate) sync: DocsEngine, downloader: Downloader, } @@ -152,20 +153,21 @@ impl Node { self.inner.endpoint.home_relay() } - /// Aborts the node. + /// Shutdown the node. /// /// This does not gracefully terminate currently: all connections are closed and - /// anything in-transit is lost. The task will stop running. - /// If this is the last copy of the `Node`, this will finish once the task is - /// fully shutdown. + /// anything in-transit is lost. The shutdown behaviour will become more graceful + /// in the future. /// - /// The shutdown behaviour will become more graceful in the future. + /// Returns a future that completes once all tasks terminated and all resources are closed. + /// The future resolves to an error if the main task panicked. pub async fn shutdown(self) -> Result<()> { + // Trigger shutdown of the main run task by activating the cancel token. self.inner.cancel_token.cancel(); - if let Ok(task) = Arc::try_unwrap(self.task) { - task.await?; - } + // Wait for the main task to terminate. + self.task.await.map_err(|err| anyhow!(err))?; + Ok(()) } @@ -173,6 +175,14 @@ impl Node { pub fn cancel_token(&self) -> CancellationToken { self.inner.cancel_token.clone() } + + /// Returns a protocol handler for an ALPN. + /// + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` + /// does not match the passed type. + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + self.protocols.get_typed(alpn) + } } impl std::ops::Deref for Node { diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 69a9a451b4..5a266127cf 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -6,7 +6,7 @@ use std::{ time::Duration, }; -use anyhow::{bail, Context, Result}; +use anyhow::{Context, Result}; use futures_lite::StreamExt; use iroh_base::key::SecretKey; use iroh_blobs::{ @@ -24,23 +24,28 @@ use iroh_net::{ Endpoint, }; use quic_rpc::{ - transport::{misc::DummyServerEndpoint, quinn::QuinnServerEndpoint}, + transport::{ + flume::FlumeServerEndpoint, misc::DummyServerEndpoint, quinn::QuinnServerEndpoint, + }, RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; +use tokio::task::JoinSet; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ client::RPC_ALPN, + node::{ + protocol::{BlobsProtocol, ProtocolMap}, + ProtocolHandler, + }, rpc_protocol::RpcService, util::{fs::load_secret_key, path::IrohPaths}, }; use super::{rpc, rpc_status::RpcStatus, DocsEngine, Node, NodeInner}; -pub const PROTOCOLS: [&[u8]; 3] = [iroh_blobs::protocol::ALPN, GOSSIP_ALPN, DOCS_ALPN]; - /// Default bind address for the node. /// 11204 is "iroh" in leetspeak pub const DEFAULT_BIND_PORT: u16 = 11204; @@ -83,7 +88,7 @@ where gc_policy: GcPolicy, dns_resolver: Option, node_discovery: DiscoveryConfig, - docs_store: iroh_docs::store::fs::Store, + docs_store: iroh_docs::store::Store, #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, /// Callback to register when a gc loop is done @@ -183,7 +188,9 @@ where tokio::fs::create_dir_all(&blob_dir).await?; let blobs_store = iroh_blobs::store::fs::Store::load(&blob_dir) .await - .with_context(|| format!("Failed to load iroh database from {}", blob_dir.display()))?; + .with_context(|| { + format!("Failed to load blobs database from {}", blob_dir.display()) + })?; let docs_store = iroh_docs::store::fs::Store::persistent(IrohPaths::DocsDatabase.with_root(root))?; @@ -366,20 +373,28 @@ where /// connections. The returned [`Node`] can be used to control the task as well as /// get information about it. pub async fn spawn(self) -> Result> { - // We clone the blob store to shut it down in case the node fails to spawn. + let unspawned_node = self.build().await?; + unspawned_node.spawn().await + } + + /// Build a node without spawning it. + /// + /// Returns an `ProtocolBuilder`, on which custom protocols can be registered with + /// [`ProtocolBuilder::accept`]. To spawn the node, call [`ProtocolBuilder::spawn`]. + pub async fn build(self) -> Result> { + // Clone the blob store to shutdown in case of error. let blobs_store = self.blobs_store.clone(); - match self.spawn_inner().await { + match self.build_inner().await { Ok(node) => Ok(node), Err(err) => { - debug!("failed to spawn node, shutting down"); blobs_store.shutdown().await; Err(err) } } } - async fn spawn_inner(mut self) -> Result> { - trace!("spawning node"); + async fn build_inner(self) -> Result> { + trace!("building node"); let lp = LocalPoolHandle::new(num_cpus::get()); let mut transport_config = quinn::TransportConfig::default(); @@ -404,7 +419,6 @@ where let endpoint = Endpoint::builder() .secret_key(self.secret_key.clone()) .proxy_from_env() - .alpns(PROTOCOLS.iter().map(|p| p.to_vec()).collect()) .keylog(self.keylog) .transport_config(transport_config) .concurrent_connections(MAX_CONNECTIONS) @@ -435,8 +449,6 @@ where let cancel_token = CancellationToken::new(); - debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); - let addr = endpoint.node_addr().await?; // initialize the gossip protocol @@ -464,95 +476,55 @@ where default_author_storage, ) .await?; - let sync_db = sync.sync.clone(); let sync = DocsEngine(sync); - let gc_task = if let GcPolicy::Interval(gc_period) = self.gc_policy { - tracing::info!("Starting GC task with interval {:?}", gc_period); - let db = self.blobs_store.clone(); - let gc_done_callback = self.gc_done_callback.take(); - - let task = - lp.spawn_pinned(move || Self::gc_loop(db, sync_db, gc_period, gc_done_callback)); - Some(task.into()) - } else { - None - }; + // Initialize the internal RPC connection. let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); + debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); let inner = Arc::new(NodeInner { - db: self.blobs_store, + db: self.blobs_store.clone(), + sync, endpoint: endpoint.clone(), secret_key: self.secret_key, controller, cancel_token, - gc_task, - rt: lp.clone(), - sync, + rt: lp, downloader, + gossip, }); - let task = { - let gossip = gossip.clone(); - let handler = rpc::Handler { - inner: inner.clone(), - }; - let me = endpoint.node_id().fmt_short(); - let ep = endpoint.clone(); - tokio::task::spawn( - async move { - Self::run( - ep, - handler, - self.rpc_endpoint, - internal_rpc, - gossip, - ) - .await - } - .instrument(error_span!("node", %me)), - ) - }; - let node = Node { + let node = ProtocolBuilder { inner, - task: Arc::new(task), client, + protocols: Default::default(), + internal_rpc, + gc_policy: self.gc_policy, + gc_done_callback: self.gc_done_callback, + rpc_endpoint: self.rpc_endpoint, }; - // spawn a task that updates the gossip endpoints. - // TODO: track task - let mut stream = endpoint.direct_addresses(); - tokio::task::spawn(async move { - while let Some(eps) = stream.next().await { - if let Err(err) = gossip.update_direct_addresses(&eps) { - warn!("Failed to update gossip endpoints: {err:?}"); - } - } - warn!("failed to retrieve local endpoints"); - }); - - // Wait for a single endpoint update, to make sure - // we found some endpoints - tokio::time::timeout(ENDPOINT_WAIT, endpoint.direct_addresses().next()) - .await - .context("waiting for endpoint")? - .context("no endpoints")?; + let node = node.register_iroh_protocols(); Ok(node) } - #[allow(clippy::too_many_arguments)] async fn run( - server: Endpoint, - handler: rpc::Handler, + inner: Arc>, rpc: E, internal_rpc: impl ServiceEndpoint, - gossip: Gossip, + protocols: Arc, + mut join_set: JoinSet>, ) { + let endpoint = inner.endpoint.clone(); + + let handler = rpc::Handler { + inner: inner.clone(), + }; let rpc = RpcServer::new(rpc); let internal_rpc = RpcServer::new(internal_rpc); - let (ipv4, ipv6) = server.bound_sockets(); + let (ipv4, ipv6) = endpoint.bound_sockets(); debug!( "listening at: {}{}", ipv4, @@ -561,24 +533,19 @@ where let cancel_token = handler.inner.cancel_token.clone(); - // forward our initial endpoints to the gossip protocol + // forward the initial endpoints to the gossip protocol. // it may happen the the first endpoint update callback is missed because the gossip cell // is only initialized once the endpoint is fully bound - if let Some(local_endpoints) = server.direct_addresses().next().await { - debug!(me = ?server.node_id(), "gossip initial update: {local_endpoints:?}"); - gossip.update_direct_addresses(&local_endpoints).ok(); + if let Some(direct_addresses) = endpoint.direct_addresses().next().await { + debug!(me = ?endpoint.node_id(), "gossip initial update: {direct_addresses:?}"); + inner.gossip.update_direct_addresses(&direct_addresses).ok(); } + loop { tokio::select! { biased; _ = cancel_token.cancelled() => { - // clean shutdown of the blobs db to close the write transaction - handler.inner.db.shutdown().await; - - if let Err(err) = handler.inner.sync.shutdown().await { - warn!("sync shutdown error: {:?}", err); - } - break + break; }, // handle rpc requests. This will do nothing if rpc is not configured, since // accept is just a pending future. @@ -603,42 +570,49 @@ where } } }, - // handle incoming p2p connections - Some(mut connecting) = server.accept() => { - let alpn = match connecting.alpn().await { - Ok(alpn) => alpn, - Err(err) => { - error!("invalid handshake: {:?}", err); - continue; - } - }; - let gossip = gossip.clone(); - let inner = handler.inner.clone(); - let sync = handler.inner.sync.clone(); - tokio::task::spawn(async move { - if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync).await { - warn!("Handling incoming connection ended with error: {err}"); - } + // handle incoming p2p connections. + Some(connecting) = endpoint.accept() => { + let protocols = protocols.clone(); + join_set.spawn(async move { + handle_connection(connecting, protocols).await; + Ok(()) }); }, + // handle task terminations and quit on panics. + res = join_set.join_next(), if !join_set.is_empty() => { + if let Some(Err(err)) = res { + error!("Task failed: {err:?}"); + break; + } + }, else => break, } } - // Closing the Endpoint is the equivalent of calling Connection::close on all - // connections: Operations will immediately fail with - // ConnectionError::LocallyClosed. All streams are interrupted, this is not - // graceful. + // Shutdown the different parts of the node concurrently. let error_code = Closed::ProviderTerminating; - server - .close(error_code.into(), error_code.reason()) - .await - .ok(); + // We ignore all errors during shutdown. + let _ = tokio::join!( + // Close the endpoint. + // Closing the Endpoint is the equivalent of calling Connection::close on all + // connections: Operations will immediately fail with ConnectionError::LocallyClosed. + // All streams are interrupted, this is not graceful. + endpoint.close(error_code.into(), error_code.reason()), + // Shutdown sync engine. + inner.sync.shutdown(), + // Shutdown blobs store engine. + inner.db.shutdown(), + // Shutdown protocol handlers. + protocols.shutdown(), + ); + + // Abort remaining tasks. + join_set.shutdown().await; } async fn gc_loop( db: D, - ds: iroh_docs::actor::SyncHandle, + ds: DocsEngine, gc_period: Duration, done_cb: Option>, ) { @@ -655,7 +629,8 @@ where tokio::time::sleep(gc_period).await; tracing::debug!("Starting GC"); live.clear(); - let doc_hashes = match ds.content_hashes().await { + + let doc_hashes = match ds.sync.content_hashes().await { Ok(hashes) => hashes, Err(err) => { tracing::warn!("Error getting doc hashes: {}", err); @@ -715,6 +690,237 @@ where } } +/// A node that is initialized but not yet spawned. +/// +/// This is returned from [`Builder::build`] and may be used to register custom protocols with +/// [`Self::accept`]. It provides access to the services which are already started, the node's +/// endpoint and a client to the node. +/// +/// Note that RPC calls performed with client returned from [`Self::client`] will not complete +/// until the node is spawned. +#[derive(derive_more::Debug)] +pub struct ProtocolBuilder { + inner: Arc>, + client: crate::client::MemIroh, + internal_rpc: FlumeServerEndpoint, + rpc_endpoint: E, + protocols: ProtocolMap, + #[debug("callback")] + gc_done_callback: Option>, + gc_policy: GcPolicy, +} + +impl> ProtocolBuilder { + /// Register a protocol handler for incoming connections. + /// + /// Use this to register custom protocols onto the iroh node. Whenever a new connection for + /// `alpn` comes in, it is passed to this protocol handler. + /// + /// See the [`ProtocolHandler`] trait for details. + /// + /// Example usage: + /// + /// ```rust + /// # use std::sync::Arc; + /// # use anyhow::Result; + /// # use futures_lite::future::Boxed as BoxedFuture; + /// # use iroh::{node::{Node, ProtocolHandler}, net::endpoint::Connecting, client::MemIroh}; + /// # + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// + /// const MY_ALPN: &[u8] = b"my-protocol/1"; + /// + /// #[derive(Debug)] + /// struct MyProtocol { + /// client: MemIroh + /// } + /// + /// impl ProtocolHandler for MyProtocol { + /// fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + /// todo!(); + /// } + /// } + /// + /// let unspawned_node = Node::memory() + /// .build() + /// .await?; + /// + /// let client = unspawned_node.client().clone(); + /// let handler = MyProtocol { client }; + /// + /// let node = unspawned_node + /// .accept(MY_ALPN, Arc::new(handler)) + /// .spawn() + /// .await?; + /// # node.shutdown().await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// + pub fn accept(mut self, alpn: &'static [u8], handler: Arc) -> Self { + self.protocols.insert(alpn, handler); + self + } + + /// Return a client to control this node over an in-memory channel. + /// + /// Note that RPC calls performed with the client will not complete until the node is + /// spawned. + pub fn client(&self) -> &crate::client::MemIroh { + &self.client + } + + /// Returns the [`Endpoint`] of the node. + pub fn endpoint(&self) -> &Endpoint { + &self.inner.endpoint + } + + /// Returns the [`crate::blobs::store::Store`] used by the node. + pub fn blobs_db(&self) -> &D { + &self.inner.db + } + + /// Returns a reference to the used [`LocalPoolHandle`]. + pub fn local_pool_handle(&self) -> &LocalPoolHandle { + &self.inner.rt + } + + /// Returns a reference to the [`Downloader`] used by the node. + pub fn downloader(&self) -> &Downloader { + &self.inner.downloader + } + + /// Returns a reference to the [`Gossip`] handle used by the node. + pub fn gossip(&self) -> &Gossip { + &self.inner.gossip + } + + /// Returns a protocol handler for an ALPN. + /// + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` + /// does not match the passed type. + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + self.protocols.get_typed(alpn) + } + + /// Register the core iroh protocols (blobs, gossip, docs). + fn register_iroh_protocols(mut self) -> Self { + // Register blobs. + let blobs_proto = + BlobsProtocol::new(self.blobs_db().clone(), self.local_pool_handle().clone()); + self = self.accept(iroh_blobs::protocol::ALPN, Arc::new(blobs_proto)); + + // Register gossip. + let gossip = self.gossip().clone(); + self = self.accept(GOSSIP_ALPN, Arc::new(gossip)); + + // Register docs. + let docs = self.inner.sync.clone(); + self = self.accept(DOCS_ALPN, Arc::new(docs)); + + self + } + + /// Spawn the node and start accepting connections. + pub async fn spawn(self) -> Result> { + let Self { + inner, + client, + internal_rpc, + rpc_endpoint, + protocols, + gc_done_callback, + gc_policy, + } = self; + let protocols = Arc::new(protocols); + let protocols_clone = protocols.clone(); + + // Create the actual spawn future in an async block so that we can shutdown the protocols in case of + // error. + let node_fut = async move { + let mut join_set = JoinSet::new(); + + // Spawn a task for the garbage collection. + if let GcPolicy::Interval(gc_period) = gc_policy { + tracing::info!("Starting GC task with interval {:?}", gc_period); + let lp = inner.rt.clone(); + let docs = inner.sync.clone(); + let blobs_store = inner.db.clone(); + let handle = lp.spawn_pinned(move || { + Builder::::gc_loop(blobs_store, docs, gc_period, gc_done_callback) + }); + // We cannot spawn tasks that run on the local pool directly into the join set, + // so instead we create a new task that supervises the local task. + join_set.spawn(async move { + if let Err(err) = handle.await { + return Err(anyhow::Error::from(err)); + } + Ok(()) + }); + } + + // Spawn a task that updates the gossip endpoints. + let mut stream = inner.endpoint.direct_addresses(); + let gossip = inner.gossip.clone(); + join_set.spawn(async move { + while let Some(eps) = stream.next().await { + if let Err(err) = gossip.update_direct_addresses(&eps) { + warn!("Failed to update direct addresses for gossip: {err:?}"); + } + } + warn!("failed to retrieve local endpoints"); + Ok(()) + }); + + // Update the endpoint with our alpns. + let alpns = protocols + .alpns() + .map(|alpn| alpn.to_vec()) + .collect::>(); + inner.endpoint.set_alpns(alpns)?; + + // Spawn the main task and store it in the node for structured termination in shutdown. + let task = tokio::task::spawn( + Builder::run( + inner.clone(), + rpc_endpoint, + internal_rpc, + protocols.clone(), + join_set, + ) + .instrument(error_span!("node", me=%inner.endpoint.node_id().fmt_short())), + ); + + let node = Node { + inner, + client, + protocols, + task: task.into(), + }; + + // Wait for a single endpoint update, to make sure + // we found some endpoints + tokio::time::timeout(ENDPOINT_WAIT, node.endpoint().direct_addresses().next()) + .await + .context("waiting for endpoint")? + .context("no endpoints")?; + + Ok(node) + }; + + match node_fut.await { + Ok(node) => Ok(node), + Err(err) => { + // Shutdown the protocols in case of error. + protocols_clone.shutdown().await; + Err(err) + } + } + } +} + /// Policy for garbage collection. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum GcPolicy { @@ -730,31 +936,24 @@ impl Default for GcPolicy { } } -// TODO: Restructure this code to not take all these arguments. -#[allow(clippy::too_many_arguments)] -async fn handle_connection( - connecting: iroh_net::endpoint::Connecting, - alpn: Vec, - node: Arc>, - gossip: Gossip, - sync: DocsEngine, -) -> Result<()> { - match alpn.as_ref() { - GOSSIP_ALPN => gossip.handle_connection(connecting.await?).await?, - DOCS_ALPN => sync.handle_connection(connecting).await?, - alpn if alpn == iroh_blobs::protocol::ALPN => { - let connection = connecting.await?; - iroh_blobs::provider::handle_connection( - connection, - node.db.clone(), - MockEventSender, - node.rt.clone(), - ) - .await +async fn handle_connection( + mut connecting: iroh_net::endpoint::Connecting, + protocols: Arc, +) { + let alpn = match connecting.alpn().await { + Ok(alpn) => alpn, + Err(err) => { + warn!("Ignoring connection: invalid handshake: {:?}", err); + return; } - _ => bail!("ignoring connection: unsupported ALPN protocol"), + }; + let Some(handler) = protocols.get(&alpn) else { + warn!("Ignoring connection: unsupported ALPN protocol"); + return; + }; + if let Err(err) = handler.accept(connecting).await { + warn!("Handling incoming connection ended with error: {err}"); } - Ok(()) } const DEFAULT_RPC_PORT: u16 = 0x1337; @@ -774,7 +973,7 @@ fn make_rpc_endpoint( let mut server_config = iroh_net::endpoint::make_server_config( secret_key, vec![RPC_ALPN.to_vec()], - Some(transport_config), + Arc::new(transport_config), false, )?; server_config.concurrent_connections(MAX_RPC_CONNECTIONS); @@ -804,12 +1003,3 @@ fn make_rpc_endpoint( Ok((rpc_endpoint, actual_rpc_port)) } - -#[derive(Debug, Clone)] -struct MockEventSender; - -impl iroh_blobs::provider::EventSender for MockEventSender { - fn send(&self, _event: iroh_blobs::provider::Event) -> futures_lite::future::Boxed<()> { - Box::pin(std::future::ready(())) - } -} diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs new file mode 100644 index 0000000000..25106e7c38 --- /dev/null +++ b/iroh/src/node/protocol.rs @@ -0,0 +1,127 @@ +use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; + +use anyhow::Result; +use futures_lite::future::Boxed as BoxedFuture; +use futures_util::future::join_all; +use iroh_net::endpoint::Connecting; + +use crate::node::DocsEngine; + +/// Handler for incoming connections. +/// +/// An iroh node can accept connections for arbitrary ALPN protocols. By default, the iroh node +/// only accepts connections for the ALPNs of the core iroh protocols (blobs, gossip, docs). +/// +/// With this trait, you can handle incoming connections for custom protocols. +/// +/// Implement this trait on a struct that should handle incoming connections. +/// The protocol handler must then be registered on the node for an ALPN protocol with +/// [`crate::node::builder::ProtocolBuilder::accept`]. +pub trait ProtocolHandler: Send + Sync + IntoArcAny + fmt::Debug + 'static { + /// Handle an incoming connection. + /// + /// This runs on a freshly spawned tokio task so this can be long-running. + fn accept(self: Arc, conn: Connecting) -> BoxedFuture>; + + /// Called when the node shuts down. + fn shutdown(self: Arc) -> BoxedFuture<()> { + Box::pin(async move {}) + } +} + +/// Helper trait to facilite casting from `Arc` to `Arc`. +/// +/// This trait has a blanket implementation so there is no need to implement this yourself. +pub trait IntoArcAny { + fn into_arc_any(self: Arc) -> Arc; +} + +impl IntoArcAny for T { + fn into_arc_any(self: Arc) -> Arc { + self + } +} + +#[derive(Debug, Clone, Default)] +pub(super) struct ProtocolMap(BTreeMap<&'static [u8], Arc>); + +impl ProtocolMap { + /// Returns the registered protocol handler for an ALPN as a concrete type. + pub fn get_typed(&self, alpn: &[u8]) -> Option> { + let protocol: Arc = self.0.get(alpn)?.clone(); + let protocol_any: Arc = protocol.into_arc_any(); + let protocol_ref = Arc::downcast(protocol_any).ok()?; + Some(protocol_ref) + } + + /// Returns the registered protocol handler for an ALPN as a [`Arc`]. + pub fn get(&self, alpn: &[u8]) -> Option> { + self.0.get(alpn).cloned() + } + + /// Insert a protocol handler. + pub fn insert(&mut self, alpn: &'static [u8], handler: Arc) { + self.0.insert(alpn, handler); + } + + /// Returns an iterator of all registered ALPN protocol identifiers. + pub fn alpns(&self) -> impl Iterator { + self.0.keys() + } + + /// Shutdown all protocol handlers. + /// + /// Calls and awaits [`ProtocolHandler::shutdown`] for all registered handlers concurrently. + pub async fn shutdown(&self) { + let handlers = self.0.values().cloned().map(ProtocolHandler::shutdown); + join_all(handlers).await; + } +} + +#[derive(Debug)] +pub(crate) struct BlobsProtocol { + rt: tokio_util::task::LocalPoolHandle, + store: S, +} + +impl BlobsProtocol { + pub fn new(store: S, rt: tokio_util::task::LocalPoolHandle) -> Self { + Self { rt, store } + } +} + +impl ProtocolHandler for BlobsProtocol { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { + iroh_blobs::provider::handle_connection( + conn.await?, + self.store.clone(), + MockEventSender, + self.rt.clone(), + ) + .await; + Ok(()) + }) + } +} + +#[derive(Debug, Clone)] +struct MockEventSender; + +impl iroh_blobs::provider::EventSender for MockEventSender { + fn send(&self, _event: iroh_blobs::provider::Event) -> futures_lite::future::Boxed<()> { + Box::pin(std::future::ready(())) + } +} + +impl ProtocolHandler for iroh_gossip::net::Gossip { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn.await?).await }) + } +} + +impl ProtocolHandler for DocsEngine { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn).await }) + } +}