Skip to content

Commit

Permalink
improve code structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Frando committed Jun 12, 2024
1 parent ee043e5 commit db35136
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 48 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions iroh/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ num_cpus = { version = "1.15.0" }
portable-atomic = "1"
iroh-docs = { version = "0.18.0", path = "../iroh-docs" }
iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" }
once_cell = "1.18.0"
parking_lot = "0.12.1"
postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] }
quic-rpc = { version = "0.10.0", default-features = false, features = ["flume-transport", "quinn-transport"] }
Expand Down
4 changes: 0 additions & 4 deletions iroh/examples/custom-protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ struct ExampleProto<S> {
}

impl<S: Store + fmt::Debug> Protocol for ExampleProto<S> {
fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
self
}

fn handle_connection(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> {
Box::pin(async move { self.handle_connection(conn.await?).await })
}
Expand Down
16 changes: 7 additions & 9 deletions iroh/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
//! To shut down the node, call [`Node::shutdown`].
use std::fmt::Debug;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use std::{any::Any, path::Path};

use anyhow::{anyhow, Result};
use futures_lite::StreamExt;
Expand All @@ -16,14 +16,15 @@ use iroh_blobs::store::Store as BaoStore;
use iroh_docs::engine::Engine;
use iroh_net::util::AbortingJoinHandle;
use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint};
use once_cell::sync::OnceCell;
use quic_rpc::transport::flume::FlumeConnection;
use quic_rpc::RpcClient;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tokio_util::task::LocalPoolHandle;
use tracing::debug;

use crate::{client::RpcService, node::builder::ProtocolMap};
use crate::{client::RpcService, node::protocol::ProtocolMap};

mod builder;
mod protocol;
Expand All @@ -47,7 +48,7 @@ pub use protocol::Protocol;
#[derive(Debug, Clone)]
pub struct Node<D> {
inner: Arc<NodeInner<D>>,
task: Arc<JoinHandle<()>>,
task: Arc<OnceCell<JoinHandle<()>>>,
client: crate::client::MemIroh,
protocols: ProtocolMap,
}
Expand Down Expand Up @@ -155,11 +156,7 @@ impl<D: BaoStore> Node<D> {

/// Returns the protocol handler for a alpn.
pub fn get_protocol<P: Protocol>(&self, alpn: &[u8]) -> Option<Arc<P>> {
let protocols = self.protocols.read().unwrap();
let protocol: Arc<dyn Protocol> = protocols.get(alpn)?.clone();
let protocol_any: Arc<dyn Any + Send + Sync> = protocol.as_arc_any();
let protocol_ref = Arc::downcast(protocol_any).ok()?;
Some(protocol_ref)
self.protocols.get(alpn)
}

/// Aborts the node.
Expand All @@ -173,7 +170,8 @@ impl<D: BaoStore> Node<D> {
pub async fn shutdown(self) -> Result<()> {
self.inner.cancel_token.cancel();

if let Ok(task) = Arc::try_unwrap(self.task) {
if let Ok(mut task) = Arc::try_unwrap(self.task) {
let task = task.take().expect("cannot be empty");
task.await?;
}
Ok(())
Expand Down
48 changes: 20 additions & 28 deletions iroh/src/node/builder.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::{
collections::{BTreeSet, HashMap},
collections::BTreeSet,
net::{Ipv4Addr, SocketAddrV4},
path::{Path, PathBuf},
sync::{Arc, RwLock},
sync::Arc,
time::Duration,
};

Expand All @@ -28,13 +28,12 @@ use quic_rpc::{
RpcServer, ServiceEndpoint,
};
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
use tokio_util::{sync::CancellationToken, task::LocalPoolHandle};
use tracing::{debug, error, error_span, info, trace, warn, Instrument};

use crate::{
client::RPC_ALPN,
node::Protocol,
node::{protocol::ProtocolMap, Protocol},
rpc_protocol::RpcService,
util::{fs::load_secret_key, path::IrohPaths},
};
Expand All @@ -56,7 +55,6 @@ const DEFAULT_GC_INTERVAL: Duration = Duration::from_secs(60 * 5);
const MAX_CONNECTIONS: u32 = 1024;
const MAX_STREAMS: u64 = 10;

pub(super) type ProtocolMap = Arc<RwLock<HashMap<&'static [u8], Arc<dyn Protocol>>>>;
type ProtocolBuilders<D> = Vec<(
&'static [u8],
Box<dyn FnOnce(Node<D>) -> Arc<dyn Protocol> + Send + 'static>,
Expand Down Expand Up @@ -511,7 +509,7 @@ where
let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1);
let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone()));

let protocols = Arc::new(RwLock::new(HashMap::new()));
let protocols = ProtocolMap::default();

let inner = Arc::new(NodeInner {
db: self.blobs_store,
Expand All @@ -524,9 +522,21 @@ where
sync,
downloader,
});
let (ready_tx, ready_rx) = oneshot::channel();

let node = Node {
inner: inner.clone(),
task: Default::default(),
client,
protocols: protocols.clone(),
};

for (alpn, p) in self.protocols {
let protocol = p(node.clone());
protocols.insert(alpn, protocol);
}

let task = {
let protocols = Arc::clone(&protocols);
let protocols = protocols.clone();
let gossip = gossip.clone();
let handler = rpc::Handler {
inner: inner.clone(),
Expand All @@ -535,8 +545,6 @@ where
let ep = endpoint.clone();
tokio::task::spawn(
async move {
// Wait until the protocol builders have run.
ready_rx.await.expect("cannot fail");
Self::run(
ep,
protocols,
Expand All @@ -551,20 +559,7 @@ where
)
};

let node = Node {
inner,
task: Arc::new(task),
client,
protocols,
};

for (alpn, p) in self.protocols {
let protocol = p(node.clone());
node.protocols.write().unwrap().insert(alpn, protocol);
}

// Notify the run task that the protocols are now built.
ready_tx.send(()).expect("cannot fail");
node.task.set(task).expect("was empty");

// spawn a task that updates the gossip endpoints.
// TODO: track task
Expand Down Expand Up @@ -802,10 +797,7 @@ async fn handle_connection<D: BaoStore>(
.await
}
alpn => {
let protocol = {
let protocols = protocols.read().unwrap();
protocols.get(alpn).cloned()
};
let protocol = protocols.get_any(alpn);
if let Some(protocol) = protocol {
protocol.handle_connection(connecting).await?;
} else {
Expand Down
54 changes: 47 additions & 7 deletions iroh/src/node/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,58 @@
use std::{any::Any, fmt, sync::Arc};
use std::{
any::Any,
collections::HashMap,
fmt,
sync::{Arc, RwLock},
};

use anyhow::Result;
use futures_lite::future::Boxed as BoxedFuture;
use iroh_net::endpoint::Connecting;

/// Trait for iroh protocol handlers.
pub trait Protocol: Send + Sync + Any + fmt::Debug + 'static {
/// Return `self` as `dyn Any`.
///
/// Implementations can simply return `self` here.
fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;

pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static {
/// Handle an incoming connection.
///
/// This runs on a freshly spawned tokio task so this can be long-running.
fn handle_connection(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>>;
}

/// Helper trait to facilite casting from `Arc<dyn T>` to `Arc<dyn Any>`.
///
/// This trait has a blanket implementation so there is no need to implement this yourself.
pub trait IntoArcAny {
fn into_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
}

impl<T: Send + Sync + 'static> IntoArcAny for T {
fn into_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
self
}
}

/// Map of registered protocol handlers.
#[allow(clippy::type_complexity)]
#[derive(Debug, Clone, Default)]
pub struct ProtocolMap(Arc<RwLock<HashMap<&'static [u8], Arc<dyn Protocol>>>>);

impl ProtocolMap {
/// Returns the registered protocol handler for an ALPN as a concrete type.
pub fn get<P: Protocol>(&self, alpn: &[u8]) -> Option<Arc<P>> {
let protocols = self.0.read().unwrap();
let protocol: Arc<dyn Protocol> = protocols.get(alpn)?.clone();
let protocol_any: Arc<dyn Any + Send + Sync> = protocol.into_arc_any();
let protocol_ref = Arc::downcast(protocol_any).ok()?;
Some(protocol_ref)
}

/// Returns the registered protocol handler for an ALPN as a `dyn Protocol`.
pub fn get_any(&self, alpn: &[u8]) -> Option<Arc<dyn Protocol>> {
let protocols = self.0.read().unwrap();
let protocol: Arc<dyn Protocol> = protocols.get(alpn)?.clone();
Some(protocol)
}

pub(super) fn insert(&self, alpn: &'static [u8], protocol: Arc<dyn Protocol>) {
self.0.write().unwrap().insert(alpn, protocol);
}
}

0 comments on commit db35136

Please sign in to comment.