diff --git a/Cargo.lock b/Cargo.lock
index a63e49d931..ef3fa7ca76 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2437,6 +2437,7 @@ dependencies = [
  "iroh-quinn",
  "iroh-test",
  "num_cpus",
+ "once_cell",
  "parking_lot",
  "portable-atomic",
  "postcard",
diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml
index aac6f9a645..5462a5f2ff 100644
--- a/iroh/Cargo.toml
+++ b/iroh/Cargo.toml
@@ -35,6 +35,7 @@ num_cpus = { version = "1.15.0" }
 portable-atomic = "1"
 iroh-docs = { version = "0.18.0", path = "../iroh-docs" }
 iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" }
+once_cell = "1.18.0"
 parking_lot = "0.12.1"
 postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] }
 quic-rpc = { version = "0.10.0", default-features = false, features = ["flume-transport", "quinn-transport"] }
@@ -101,3 +102,7 @@ required-features = ["examples"]
 [[example]]
 name = "client"
 required-features = ["examples"]
+
+[[example]]
+name = "custom-protocol"
+required-features = ["examples"]
diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs
new file mode 100644
index 0000000000..c973b22063
--- /dev/null
+++ b/iroh/examples/custom-protocol.rs
@@ -0,0 +1,130 @@
+use std::{fmt, sync::Arc};
+
+use anyhow::Result;
+use clap::Parser;
+use futures_lite::future::Boxed as BoxedFuture;
+use iroh::{
+    blobs::store::Store,
+    net::{
+        endpoint::{get_remote_node_id, Connecting},
+        NodeId,
+    },
+    node::{Node, Protocol},
+};
+use tracing_subscriber::{prelude::*, EnvFilter};
+
+#[derive(Debug, Parser)]
+pub struct Cli {
+    #[clap(subcommand)]
+    command: Command,
+}
+
+#[derive(Debug, Parser)]
+pub enum Command {
+    Accept,
+    Connect { node: NodeId },
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    setup_logging();
+    let args = Cli::parse();
+    // create a new node
+    let node = iroh::node::Node::memory()
+        .accept(EXAMPLE_ALPN, |node| ExampleProto::build(node))
+        .spawn()
+        .await?;
+
+    // print the node id if this is the accepting side
+    match args.command {
+        Command::Accept => {
+            let node_id = node.node_id();
+            println!("node id: {node_id}");
+            // wait until ctrl-c
+            tokio::signal::ctrl_c().await?;
+        }
+        Command::Connect { node: node_id } => {
+            let proto = ExampleProto::get_from_node(&node, EXAMPLE_ALPN).expect("it is registered");
+            proto.connect(node_id).await?;
+        }
+    }
+
+    node.shutdown().await?;
+
+    Ok(())
+}
+
+const EXAMPLE_ALPN: &[u8] = b"example-proto/0";
+
+#[derive(Debug)]
+struct ExampleProto<S> {
+    node: Node<S>,
+}
+
+impl<S: Store + fmt::Debug> Protocol for ExampleProto<S> {
+    fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> {
+        Box::pin(async move {
+            let conn = conn.await?;
+            let remote_node_id = get_remote_node_id(&conn)?;
+            println!("accepted connection from {remote_node_id}");
+            let mut send_stream = conn.open_uni().await?;
+            // not that this is something you'd necessarily want to do, but let's create a new blob
+            // for each incoming connection. this could be any mechanism, but we want to demonstrate
+            // how to use a custom protocol together with built-in iroh functionality
+            let content = format!("this blob is created for my beloved peer {remote_node_id} ♥");
+            let hash = self
+                .node
+                .blobs()
+                .add_bytes(content.as_bytes().to_vec())
+                .await?;
+            // send the hash over our custom proto
+            send_stream.write_all(hash.hash.as_bytes()).await?;
+            send_stream.finish().await?;
+            println!("closing connection from {remote_node_id}");
+            Ok(())
+        })
+    }
+}
+
+impl<S: Store + fmt::Debug> ExampleProto<S> {
+    fn build(node: Node<S>) -> Arc<Self> {
+        Arc::new(Self { node })
+    }
+
+    fn get_from_node(node: &Node<S>, alpn: &'static [u8]) -> Option<Arc<Self>> {
+        node.get_protocol::<ExampleProto<S>>(alpn)
+    }
+
+    async fn connect(&self, remote_node_id: NodeId) -> Result<()> {
+        println!("our node id: {}", self.node.node_id());
+        println!("connecting to {remote_node_id}");
+        let conn = self
+            .node
+            .endpoint()
+            .connect_by_node_id(&remote_node_id, EXAMPLE_ALPN)
+            .await?;
+        let mut recv_stream = conn.accept_uni().await?;
+        let hash_bytes = recv_stream.read_to_end(32).await?;
+        let hash = iroh::blobs::Hash::from_bytes(hash_bytes.try_into().unwrap());
+        println!("received hash: {hash}");
+        self.node
+            .blobs()
+            .download(hash, remote_node_id.into())
+            .await?
+            .await?;
+        println!("blob downloaded");
+        let content = self.node.blobs().read_to_bytes(hash).await?;
+        let message = String::from_utf8(content.to_vec())?;
+        println!("blob content: {message}");
+        Ok(())
+    }
+}
+
+// set the RUST_LOG env var to one of {debug,info,warn} to see logging info
+pub fn setup_logging() {
+    tracing_subscriber::registry()
+        .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
+        .with(EnvFilter::from_default_env())
+        .try_init()
+        .ok();
+}
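For orientation, the handler above leans on the blobs service, but a custom protocol does not have to. The sketch below is not part of this diff: it is roughly the smallest handler one could register through the new `Builder::accept`, assuming the `Protocol` API introduced here; the `Echo` type and `ECHO_ALPN` value are made-up names, and the stream calls mirror the `Connecting`-based API used in the example above.

```rust
// A minimal sketch, assuming the `Builder::accept` / `Protocol` API added in this diff.
// `Echo` and `ECHO_ALPN` are hypothetical names used for illustration only.
use std::sync::Arc;

use anyhow::Result;
use futures_lite::future::Boxed as BoxedFuture;
use iroh::{net::endpoint::Connecting, node::Protocol};

const ECHO_ALPN: &[u8] = b"my-app/echo/0";

#[derive(Debug)]
struct Echo;

impl Protocol for Echo {
    fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> {
        Box::pin(async move {
            let conn = conn.await?;
            // Echo a single bidirectional stream back to the peer.
            let (mut send, mut recv) = conn.accept_bi().await?;
            let bytes = recv.read_to_end(1024).await?;
            send.write_all(&bytes).await?;
            send.finish().await?;
            Ok(())
        })
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    // Registering the handler before `spawn` advertises ECHO_ALPN and routes
    // incoming connections with that ALPN to `Echo::accept`.
    let node = iroh::node::Node::memory()
        .accept(ECHO_ALPN, |_node| Arc::new(Echo))
        .spawn()
        .await?;
    println!("echo node running as {}", node.node_id());
    tokio::signal::ctrl_c().await?;
    node.shutdown().await?;
    Ok(())
}
```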
diff --git a/iroh/src/node.rs b/iroh/src/node.rs
index 3b9173c706..36cf4705a9 100644
--- a/iroh/src/node.rs
+++ b/iroh/src/node.rs
@@ -16,6 +16,7 @@ use iroh_blobs::store::Store as BaoStore;
 use iroh_docs::engine::Engine;
 use iroh_net::util::AbortingJoinHandle;
 use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint};
+use once_cell::sync::OnceCell;
 use quic_rpc::transport::flume::FlumeConnection;
 use quic_rpc::RpcClient;
 use tokio::task::JoinHandle;
@@ -23,14 +24,16 @@ use tokio_util::sync::CancellationToken;
 use tokio_util::task::LocalPoolHandle;
 use tracing::debug;
 
-use crate::client::RpcService;
+use crate::{client::RpcService, node::protocol::ProtocolMap};
 
 mod builder;
+mod protocol;
 mod rpc;
 mod rpc_status;
 
 pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig};
 pub use self::rpc_status::RpcStatus;
+pub use protocol::Protocol;
 
 /// A server which implements the iroh node.
 ///
@@ -45,8 +48,9 @@ pub use self::rpc_status::RpcStatus;
 #[derive(Debug, Clone)]
 pub struct Node<D> {
     inner: Arc<NodeInner<D>>,
-    task: Arc<JoinHandle<()>>,
+    task: Arc<OnceCell<JoinHandle<()>>>,
     client: crate::client::MemIroh,
+    protocols: ProtocolMap,
 }
 
 #[derive(derive_more::Debug)]
@@ -150,6 +154,11 @@ impl<D: BaoStore> Node<D> {
         self.inner.endpoint.my_relay()
     }
 
+    /// Returns the protocol handler for an ALPN, if one is registered.
+    pub fn get_protocol<P: Protocol>(&self, alpn: &[u8]) -> Option<Arc<P>> {
+        self.protocols.get(alpn)
+    }
+
     /// Aborts the node.
     ///
     /// This does not gracefully terminate currently: all connections are closed and
@@ -161,7 +170,8 @@ impl<D: BaoStore> Node<D> {
     pub async fn shutdown(self) -> Result<()> {
         self.inner.cancel_token.cancel();
 
-        if let Ok(task) = Arc::try_unwrap(self.task) {
+        if let Ok(mut task) = Arc::try_unwrap(self.task) {
+            let task = task.take().expect("cannot be empty");
             task.await?;
         }
         Ok(())
diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs
index db935479f2..d23732a08c 100644
--- a/iroh/src/node/builder.rs
+++ b/iroh/src/node/builder.rs
@@ -33,6 +33,7 @@ use tracing::{debug, error, error_span, info, trace, warn, Instrument};
 
 use crate::{
     client::RPC_ALPN,
+    node::{protocol::ProtocolMap, Protocol},
     rpc_protocol::RpcService,
     util::{fs::load_secret_key, path::IrohPaths},
 };
@@ -54,6 +55,11 @@ const DEFAULT_GC_INTERVAL: Duration = Duration::from_secs(60 * 5);
 const MAX_CONNECTIONS: u32 = 1024;
 const MAX_STREAMS: u64 = 10;
 
+type ProtocolBuilders<D> = Vec<(
+    &'static [u8],
+    Box<dyn FnOnce(Node<D>) -> Arc<dyn Protocol> + Send + 'static>,
+)>;
+
 /// Builder for the [`Node`].
 ///
 /// You must supply a blob store and a document store.
@@ -84,6 +90,7 @@ where
     dns_resolver: Option<DnsResolver>,
     node_discovery: DiscoveryConfig,
     docs_store: iroh_docs::store::fs::Store,
+    protocols: ProtocolBuilders<D>,
     #[cfg(any(test, feature = "test-utils"))]
     insecure_skip_relay_cert_verify: bool,
     /// Callback to register when a gc loop is done
@@ -133,6 +140,7 @@ impl Default for Builder<iroh_blobs::store::mem::Store> {
             rpc_endpoint: Default::default(),
             gc_policy: GcPolicy::Disabled,
             docs_store: iroh_docs::store::Store::memory(),
+            protocols: Default::default(),
            node_discovery: Default::default(),
             #[cfg(any(test, feature = "test-utils"))]
             insecure_skip_relay_cert_verify: false,
@@ -160,6 +168,7 @@ impl<D: Map> Builder<D> {
             gc_policy: GcPolicy::Disabled,
             docs_store,
             node_discovery: Default::default(),
+            protocols: Default::default(),
             #[cfg(any(test, feature = "test-utils"))]
             insecure_skip_relay_cert_verify: false,
             gc_done_callback: None,
@@ -223,6 +232,7 @@ where
             gc_policy: self.gc_policy,
             docs_store,
             node_discovery: self.node_discovery,
+            protocols: Default::default(),
             #[cfg(any(test, feature = "test-utils"))]
             insecure_skip_relay_cert_verify: false,
             gc_done_callback: self.gc_done_callback,
@@ -244,6 +254,7 @@ where
             gc_policy: self.gc_policy,
             docs_store: self.docs_store,
             node_discovery: self.node_discovery,
+            protocols: Default::default(),
             #[cfg(any(test, feature = "test-utils"))]
             insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify,
             gc_done_callback: self.gc_done_callback,
@@ -270,6 +281,7 @@ where
             gc_policy: self.gc_policy,
             docs_store: self.docs_store,
             node_discovery: self.node_discovery,
+            protocols: Default::default(),
             #[cfg(any(test, feature = "test-utils"))]
             insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify,
             gc_done_callback: self.gc_done_callback,
@@ -343,6 +355,16 @@ where
         self
     }
 
+    /// Accept a custom protocol on the given ALPN.
+    pub fn accept(
+        mut self,
+        alpn: &'static [u8],
+        protocol: impl FnOnce(Node<D>) -> Arc<dyn Protocol> + Send + 'static,
+    ) -> Self {
+        self.protocols.push((alpn, Box::new(protocol)));
+        self
+    }
+
     /// Register a callback for when GC is done.
     #[cfg(any(test, feature = "test-utils"))]
     pub fn register_gc_done_cb(mut self, cb: Box<dyn Fn() + Send>) -> Self {
@@ -401,10 +423,16 @@ where
             }
         };
 
+        let alpns = PROTOCOLS
+            .iter()
+            .chain(self.protocols.iter().map(|(alpn, _)| alpn))
+            .map(|p| p.to_vec())
+            .collect();
+
         let endpoint = Endpoint::builder()
             .secret_key(self.secret_key.clone())
             .proxy_from_env()
-            .alpns(PROTOCOLS.iter().map(|p| p.to_vec()).collect())
+            .alpns(alpns)
             .keylog(self.keylog)
             .transport_config(transport_config)
             .concurrent_connections(MAX_CONNECTIONS)
@@ -481,6 +509,8 @@ where
         let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1);
         let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone()));
 
+        let protocols = ProtocolMap::default();
+
         let inner = Arc::new(NodeInner {
             db: self.blobs_store,
             endpoint: endpoint.clone(),
@@ -492,7 +522,21 @@ where
             sync,
             downloader,
         });
+
+        let node = Node {
+            inner: inner.clone(),
+            task: Default::default(),
+            client,
+            protocols: protocols.clone(),
+        };
+
+        for (alpn, p) in self.protocols {
+            let protocol = p(node.clone());
+            protocols.insert(alpn, protocol);
+        }
+
         let task = {
+            let protocols = protocols.clone();
             let gossip = gossip.clone();
             let handler = rpc::Handler {
                 inner: inner.clone(),
@@ -503,6 +547,7 @@ where
                 async move {
                     Self::run(
                         ep,
+                        protocols,
                         handler,
                         self.rpc_endpoint,
                         internal_rpc,
@@ -514,11 +559,7 @@ where
             )
         };
 
-        let node = Node {
-            inner,
-            task: Arc::new(task),
-            client,
-        };
+        node.task.set(task).expect("was empty");
 
         // spawn a task that updates the gossip endpoints.
         // TODO: track task
@@ -545,6 +586,7 @@ where
     #[allow(clippy::too_many_arguments)]
     async fn run(
         server: Endpoint,
+        protocols: ProtocolMap,
         handler: rpc::Handler<D>,
         rpc: E,
         internal_rpc: impl ServiceEndpoint<RpcService>,
@@ -615,8 +657,9 @@ where
                         let gossip = gossip.clone();
                         let inner = handler.inner.clone();
                         let sync = handler.inner.sync.clone();
+                        let protocols = protocols.clone();
                         tokio::task::spawn(async move {
-                            if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync).await {
+                            if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync, protocols).await {
                                 warn!("Handling incoming connection ended with error: {err}");
                             }
                         });
@@ -738,6 +781,7 @@ async fn handle_connection<D: BaoStore>(
     node: Arc<NodeInner<D>>,
     gossip: Gossip,
     sync: DocsEngine,
+    protocols: ProtocolMap,
 ) -> Result<()> {
     match alpn.as_bytes() {
         GOSSIP_ALPN => gossip.handle_connection(connecting.await?).await?,
@@ -752,7 +796,14 @@ async fn handle_connection<D: BaoStore>(
             )
             .await
         }
-        _ => bail!("ignoring connection: unsupported ALPN protocol"),
+        alpn => {
+            let protocol = protocols.get_any(alpn);
+            if let Some(protocol) = protocol {
+                protocol.accept(connecting).await?;
+            } else {
+                bail!("ignoring connection: unsupported ALPN protocol");
+            }
+        }
     }
     Ok(())
 }
diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs
new file mode 100644
index 0000000000..139ebbda8a
--- /dev/null
+++ b/iroh/src/node/protocol.rs
@@ -0,0 +1,58 @@
+use std::{
+    any::Any,
+    collections::HashMap,
+    fmt,
+    sync::{Arc, RwLock},
+};
+
+use anyhow::Result;
+use futures_lite::future::Boxed as BoxedFuture;
+use iroh_net::endpoint::Connecting;
+
+/// Trait for iroh protocol handlers.
+pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static {
+    /// Handle an incoming connection.
+    ///
+    /// This runs on a freshly spawned tokio task, so it can be long-running.
+    fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>>;
+}
+
+/// Helper trait to facilitate casting from `Arc<dyn Protocol>` to `Arc<dyn Any>`.
+///
+/// This trait has a blanket implementation, so there is no need to implement it yourself.
+pub trait IntoArcAny {
+    fn into_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
+}
+
+impl<T: Send + Sync + 'static> IntoArcAny for T {
+    fn into_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
+        self
+    }
+}
+
+/// Map of registered protocol handlers.
+#[allow(clippy::type_complexity)]
+#[derive(Debug, Clone, Default)]
+pub struct ProtocolMap(Arc<RwLock<HashMap<&'static [u8], Arc<dyn Protocol>>>>);
+
+impl ProtocolMap {
+    /// Returns the registered protocol handler for an ALPN as a concrete type.
+    pub fn get<P: Protocol>(&self, alpn: &[u8]) -> Option<Arc<P>> {
+        let protocols = self.0.read().unwrap();
+        let protocol: Arc<dyn Protocol> = protocols.get(alpn)?.clone();
+        let protocol_any: Arc<dyn Any + Send + Sync> = protocol.into_arc_any();
+        let protocol_ref = Arc::downcast(protocol_any).ok()?;
+        Some(protocol_ref)
+    }
+
+    /// Returns the registered protocol handler for an ALPN as a `dyn Protocol`.
+    pub fn get_any(&self, alpn: &[u8]) -> Option<Arc<dyn Protocol>> {
+        let protocols = self.0.read().unwrap();
+        let protocol: Arc<dyn Protocol> = protocols.get(alpn)?.clone();
+        Some(protocol)
+    }
+
+    pub(super) fn insert(&self, alpn: &'static [u8], protocol: Arc<dyn Protocol>) {
+        self.0.write().unwrap().insert(alpn, protocol);
+    }
+}
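`ProtocolMap::get` recovers the concrete handler type from the type-erased `Arc<dyn Protocol>` via the `IntoArcAny` blanket implementation plus `Arc::downcast`. The stand-alone sketch below illustrates that round-trip in isolation using only the standard library; the `Handler` and `Ping` names are made up for illustration and are not part of this diff.

```rust
// A self-contained sketch of the downcast pattern behind `ProtocolMap::get`:
// erase a handler as `Arc<dyn Trait>`, hop to `Arc<dyn Any + Send + Sync>` via a
// blanket `IntoArcAny` impl, then recover the concrete type with `Arc::downcast`.
// `Handler` and `Ping` are illustrative names, not part of the diff.
use std::{any::Any, sync::Arc};

trait IntoArcAny {
    fn into_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
}

// Blanket impl: every `Send + Sync + 'static` type can be cast to `dyn Any`.
impl<T: Send + Sync + 'static> IntoArcAny for T {
    fn into_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }
}

trait Handler: IntoArcAny + Send + Sync + 'static {}

#[derive(Debug)]
struct Ping;

impl Handler for Ping {}

// Mirrors `ProtocolMap::get`: returns `None` if the concrete type does not match.
fn get_ping(erased: Arc<dyn Handler>) -> Option<Arc<Ping>> {
    let any: Arc<dyn Any + Send + Sync> = erased.into_arc_any();
    Arc::downcast(any).ok()
}

fn main() {
    // Store the handler type-erased, as the registry does.
    let erased: Arc<dyn Handler> = Arc::new(Ping);
    let concrete = get_ping(erased).expect("concrete type matches");
    println!("recovered {concrete:?}");
}
```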