Skip to content

Commit

Permalink
feat: custom protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
Frando committed Jun 11, 2024
1 parent ea50b94 commit 79071f4
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 6 deletions.
16 changes: 14 additions & 2 deletions iroh/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
//! To shut down the node, call [`Node::shutdown`].
use std::fmt::Debug;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use std::{any::Any, path::Path};

use anyhow::{anyhow, Result};
use futures_lite::StreamExt;
Expand All @@ -23,14 +23,16 @@ use tokio_util::sync::CancellationToken;
use tokio_util::task::LocalPoolHandle;
use tracing::debug;

use crate::client::RpcService;
use crate::{client::RpcService, node::builder::ProtocolMap};

mod builder;
mod protocol;
mod rpc;
mod rpc_status;

pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig};
pub use self::rpc_status::RpcStatus;
pub use protocol::Protocol;

/// A server which implements the iroh node.
///
Expand All @@ -47,6 +49,7 @@ pub struct Node<D> {
inner: Arc<NodeInner<D>>,
task: Arc<JoinHandle<()>>,
client: crate::client::MemIroh,
protocols: ProtocolMap,
}

#[derive(derive_more::Debug)]
Expand Down Expand Up @@ -150,6 +153,15 @@ impl<D: BaoStore> Node<D> {
self.inner.endpoint.my_relay()
}

/// Returns the protocol handler for a alpn.
pub fn get_protocol<P: Protocol>(&self, alpn: &[u8]) -> Option<Arc<P>> {
let protocols = self.protocols.read().unwrap();
let protocol: Arc<dyn Protocol> = protocols.get(alpn)?.clone();
let protocol_any: Arc<dyn Any + Send + Sync> = protocol.as_arc_any();
let protocol_ref = Arc::downcast(protocol_any).ok()?;
Some(protocol_ref)
}

/// Aborts the node.
///
/// This does not gracefully terminate currently: all connections are closed and
Expand Down
60 changes: 56 additions & 4 deletions iroh/src/node/builder.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::{
collections::BTreeSet,
collections::{BTreeSet, HashMap},
net::{Ipv4Addr, SocketAddrV4},
path::{Path, PathBuf},
sync::Arc,
sync::{Arc, RwLock},
time::Duration,
};

Expand All @@ -28,11 +28,13 @@ use quic_rpc::{
RpcServer, ServiceEndpoint,
};
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
use tokio_util::{sync::CancellationToken, task::LocalPoolHandle};
use tracing::{debug, error, error_span, info, trace, warn, Instrument};

use crate::{
client::RPC_ALPN,
node::Protocol,
rpc_protocol::RpcService,
util::{fs::load_secret_key, path::IrohPaths},
};
Expand All @@ -54,6 +56,9 @@ const DEFAULT_GC_INTERVAL: Duration = Duration::from_secs(60 * 5);
const MAX_CONNECTIONS: u32 = 1024;
const MAX_STREAMS: u64 = 10;

pub(super) type ProtocolMap = Arc<RwLock<HashMap<&'static [u8], Arc<dyn Protocol>>>>;
type ProtocolBuilders<D> = Vec<(&'static [u8], Box<dyn FnOnce(Node<D>) -> Arc<dyn Protocol>>)>;

/// Builder for the [`Node`].
///
/// You must supply a blob store and a document store.
Expand Down Expand Up @@ -84,6 +89,7 @@ where
dns_resolver: Option<DnsResolver>,
node_discovery: DiscoveryConfig,
docs_store: iroh_docs::store::fs::Store,
protocols: ProtocolBuilders<D>,
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify: bool,
/// Callback to register when a gc loop is done
Expand Down Expand Up @@ -133,6 +139,7 @@ impl Default for Builder<iroh_blobs::store::mem::Store> {
rpc_endpoint: Default::default(),
gc_policy: GcPolicy::Disabled,
docs_store: iroh_docs::store::Store::memory(),
protocols: Default::default(),
node_discovery: Default::default(),
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify: false,
Expand Down Expand Up @@ -160,6 +167,7 @@ impl<D: Map> Builder<D> {
gc_policy: GcPolicy::Disabled,
docs_store,
node_discovery: Default::default(),
protocols: Default::default(),
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify: false,
gc_done_callback: None,
Expand Down Expand Up @@ -223,6 +231,7 @@ where
gc_policy: self.gc_policy,
docs_store,
node_discovery: self.node_discovery,
protocols: Default::default(),
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify: false,
gc_done_callback: self.gc_done_callback,
Expand All @@ -244,6 +253,7 @@ where
gc_policy: self.gc_policy,
docs_store: self.docs_store,
node_discovery: self.node_discovery,
protocols: Default::default(),
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify,
gc_done_callback: self.gc_done_callback,
Expand All @@ -270,6 +280,7 @@ where
gc_policy: self.gc_policy,
docs_store: self.docs_store,
node_discovery: self.node_discovery,
protocols: Default::default(),
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify,
gc_done_callback: self.gc_done_callback,
Expand Down Expand Up @@ -343,6 +354,16 @@ where
self
}

/// Accept a custom protocol.
pub fn accept(
mut self,
alpn: &'static [u8],
protocol: impl FnOnce(Node<D>) -> Arc<dyn Protocol> + 'static,
) -> Self {
self.protocols.push((alpn, Box::new(protocol)));
self
}

/// Register a callback for when GC is done.
#[cfg(any(test, feature = "test-utils"))]
pub fn register_gc_done_cb(mut self, cb: Box<dyn Fn() + Send>) -> Self {
Expand Down Expand Up @@ -481,6 +502,8 @@ where
let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1);
let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone()));

let protocols = Arc::new(RwLock::new(HashMap::new()));

let inner = Arc::new(NodeInner {
db: self.blobs_store,
endpoint: endpoint.clone(),
Expand All @@ -492,7 +515,9 @@ where
sync,
downloader,
});
let (ready_tx, ready_rx) = oneshot::channel();
let task = {
let protocols = Arc::clone(&protocols);
let gossip = gossip.clone();
let handler = rpc::Handler {
inner: inner.clone(),
Expand All @@ -501,8 +526,11 @@ where
let ep = endpoint.clone();
tokio::task::spawn(
async move {
// Wait until the protocol builders have run.
ready_rx.await.expect("cannot fail");
Self::run(
ep,
protocols,
handler,
self.rpc_endpoint,
internal_rpc,
Expand All @@ -518,8 +546,17 @@ where
inner,
task: Arc::new(task),
client,
protocols,
};

for (alpn, p) in self.protocols {
let protocol = p(node.clone());
node.protocols.write().unwrap().insert(alpn, protocol);
}

// Notify the run task that the protocols are now built.
ready_tx.send(()).expect("cannot fail");

// spawn a task that updates the gossip endpoints.
// TODO: track task
let mut stream = endpoint.local_endpoints();
Expand All @@ -545,6 +582,7 @@ where
#[allow(clippy::too_many_arguments)]
async fn run(
server: Endpoint,
protocols: ProtocolMap,
handler: rpc::Handler<D>,
rpc: E,
internal_rpc: impl ServiceEndpoint<RpcService>,
Expand Down Expand Up @@ -615,8 +653,9 @@ where
let gossip = gossip.clone();
let inner = handler.inner.clone();
let sync = handler.inner.sync.clone();
let protocols = protocols.clone();
tokio::task::spawn(async move {
if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync).await {
if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync, protocols).await {
warn!("Handling incoming connection ended with error: {err}");
}
});
Expand Down Expand Up @@ -738,6 +777,7 @@ async fn handle_connection<D: BaoStore>(
node: Arc<NodeInner<D>>,
gossip: Gossip,
sync: DocsEngine,
protocols: ProtocolMap,
) -> Result<()> {
match alpn.as_bytes() {
GOSSIP_ALPN => gossip.handle_connection(connecting.await?).await?,
Expand All @@ -752,7 +792,19 @@ async fn handle_connection<D: BaoStore>(
)
.await
}
_ => bail!("ignoring connection: unsupported ALPN protocol"),
alpn => {
let protocol = {
let protocols = protocols.read().unwrap();
protocols.get(alpn).cloned()
};
if let Some(protocol) = protocol {
drop(protocols);
let connection = connecting.await?;
protocol.accept(connection).await?;
} else {
bail!("ignoring connection: unsupported ALPN protocol");
}
}
}
Ok(())
}
Expand Down
19 changes: 19 additions & 0 deletions iroh/src/node/protocol.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use std::{any::Any, fmt, future::Future, pin::Pin, sync::Arc};

use iroh_net::endpoint::Connection;

/// Trait for iroh protocol handlers.
pub trait Protocol: Sync + Send + Any + fmt::Debug + 'static {
/// Return `self` as `dyn Any`.
///
/// Implementations can simply return `self` here.
fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;

/// Accept an incoming connection.
///
/// This runs on a freshly spawned tokio task so this can be long-running.
fn accept(
&self,
conn: Connection,
) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + 'static + Send + Sync>>;
}

0 comments on commit 79071f4

Please sign in to comment.