diff --git a/iroh-blobs/src/store/traits.rs b/iroh-blobs/src/store/traits.rs index e0ec3e6b39..2a91d1c0f3 100644 --- a/iroh-blobs/src/store/traits.rs +++ b/iroh-blobs/src/store/traits.rs @@ -295,7 +295,7 @@ pub trait ReadableStore: Map { } /// The mutable part of a Bao store. -pub trait Store: ReadableStore + MapMut { +pub trait Store: ReadableStore + MapMut + std::fmt::Debug { /// This trait method imports a file from a local path. /// /// `data` is the path to the file. diff --git a/iroh-docs/src/engine.rs b/iroh-docs/src/engine.rs index b5345b0bea..c0867b644d 100644 --- a/iroh-docs/src/engine.rs +++ b/iroh-docs/src/engine.rs @@ -197,7 +197,7 @@ impl Engine { /// Handle an incoming iroh-docs connection. pub async fn handle_connection( &self, - conn: iroh_net::endpoint::Connecting, + conn: iroh_net::endpoint::Connection, ) -> anyhow::Result<()> { self.to_live_actor .send(ToLiveActor::HandleConnection { conn }) diff --git a/iroh-docs/src/engine/live.rs b/iroh-docs/src/engine/live.rs index 366379f4a3..86c7cedaba 100644 --- a/iroh-docs/src/engine/live.rs +++ b/iroh-docs/src/engine/live.rs @@ -76,7 +76,7 @@ pub enum ToLiveActor { reply: sync::oneshot::Sender>, }, HandleConnection { - conn: iroh_net::endpoint::Connecting, + conn: iroh_net::endpoint::Connection, }, AcceptSyncRequest { namespace: NamespaceId, @@ -749,7 +749,7 @@ impl LiveActor { } #[instrument("accept", skip_all)] - pub async fn handle_connection(&mut self, conn: iroh_net::endpoint::Connecting) { + pub async fn handle_connection(&mut self, conn: iroh_net::endpoint::Connection) { let to_actor_tx = self.sync_actor_tx.clone(); let accept_request_cb = move |namespace, peer| { let to_actor_tx = to_actor_tx.clone(); diff --git a/iroh-docs/src/net.rs b/iroh-docs/src/net.rs index a3f90032e1..cc29d3ec59 100644 --- a/iroh-docs/src/net.rs +++ b/iroh-docs/src/net.rs @@ -106,7 +106,7 @@ pub enum AcceptOutcome { /// Handle an iroh-docs connection and sync all shared documents in the replica store. pub async fn handle_connection( sync: SyncHandle, - connecting: iroh_net::endpoint::Connecting, + connection: iroh_net::endpoint::Connection, accept_cb: F, ) -> Result where @@ -114,7 +114,6 @@ where Fut: Future, { let t_start = Instant::now(); - let connection = connecting.await.map_err(AcceptError::connect)?; let peer = get_remote_node_id(&connection).map_err(AcceptError::connect)?; let (mut send_stream, mut recv_stream) = connection .accept_bi() diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index 5462a5f2ff..a8b92488f7 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -32,10 +32,10 @@ iroh-io = { version = "0.6.0", features = ["stats"] } iroh-metrics = { version = "0.18.0", path = "../iroh-metrics", optional = true } iroh-net = { version = "0.18.0", path = "../iroh-net" } num_cpus = { version = "1.15.0" } +once_cell = "1.17.0" portable-atomic = "1" iroh-docs = { version = "0.18.0", path = "../iroh-docs" } iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" } -once_cell = "1.18.0" parking_lot = "0.12.1" postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] } quic-rpc = { version = "0.10.0", default-features = false, features = ["flume-transport", "quinn-transport"] } diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs index c973b22063..d8a5a54ac0 100644 --- a/iroh/examples/custom-protocol.rs +++ b/iroh/examples/custom-protocol.rs @@ -31,7 +31,9 @@ async fn main() -> Result<()> { let args = Cli::parse(); // create a new node let node = iroh::node::Node::memory() - .accept(EXAMPLE_ALPN, |node| ExampleProto::build(node)) + .accept(EXAMPLE_ALPN, |node| { + Box::pin(async move { Ok(ExampleProto::build(node)) }) + }) .spawn() .await?; diff --git a/iroh/src/client/authors.rs b/iroh/src/client/authors.rs index e6bddbb494..7cdd44ce72 100644 --- a/iroh/src/client/authors.rs +++ b/iroh/src/client/authors.rs @@ -40,7 +40,7 @@ where /// /// The default author can be set with [`Self::set_default`]. pub async fn default(&self) -> Result { - let res = self.rpc.rpc(AuthorGetDefaultRequest).await?; + let res = self.rpc.rpc(AuthorGetDefaultRequest).await??; Ok(res.author_id) } diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 36cf4705a9..0915915989 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use std::net::SocketAddr; use std::path::Path; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use anyhow::{anyhow, Result}; use futures_lite::StreamExt; @@ -14,12 +14,11 @@ use iroh_base::key::PublicKey; use iroh_blobs::downloader::Downloader; use iroh_blobs::store::Store as BaoStore; use iroh_docs::engine::Engine; -use iroh_net::util::AbortingJoinHandle; use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint}; use once_cell::sync::OnceCell; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; -use tokio::task::JoinHandle; +use tokio::task::{JoinHandle, JoinSet}; use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; @@ -31,7 +30,7 @@ mod protocol; mod rpc; mod rpc_status; -pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig}; +pub use self::builder::{Builder, DiscoveryConfig, DocsStorage, GcPolicy, StorageConfig}; pub use self::rpc_status::RpcStatus; pub use protocol::Protocol; @@ -60,12 +59,10 @@ struct NodeInner { secret_key: SecretKey, cancel_token: CancellationToken, controller: FlumeConnection, - #[allow(dead_code)] - gc_task: Option>, #[debug("rt")] rt: LocalPoolHandle, - pub(crate) sync: DocsEngine, downloader: Downloader, + tasks: Mutex>>, } /// In memory node. @@ -156,7 +153,11 @@ impl Node { /// Returns the protocol handler for a alpn. pub fn get_protocol(&self, alpn: &[u8]) -> Option> { - self.protocols.get(alpn) + self.protocols.get::

(alpn) + } + + fn downloader(&self) -> &Downloader { + &self.inner.downloader } /// Aborts the node. @@ -171,8 +172,10 @@ impl Node { self.inner.cancel_token.cancel(); if let Ok(mut task) = Arc::try_unwrap(self.task) { - let task = task.take().expect("cannot be empty"); - task.await?; + task.take().expect("cannot be empty").await?; + } + if let Some(mut tasks) = self.inner.tasks.lock().unwrap().take() { + tasks.abort_all(); } Ok(()) } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index d23732a08c..aab60ad6d2 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -7,7 +7,7 @@ use std::{ }; use anyhow::{bail, Context, Result}; -use futures_lite::StreamExt; +use futures_lite::{future::Boxed, StreamExt}; use iroh_base::key::SecretKey; use iroh_blobs::{ downloader::Downloader, @@ -28,12 +28,16 @@ use quic_rpc::{ RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; +use tokio::task::JoinSet; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ client::RPC_ALPN, - node::{protocol::ProtocolMap, Protocol}, + node::{ + protocol::{BlobsProtocol, ProtocolMap}, + Protocol, + }, rpc_protocol::RpcService, util::{fs::load_secret_key, path::IrohPaths}, }; @@ -57,9 +61,18 @@ const MAX_STREAMS: u64 = 10; type ProtocolBuilders = Vec<( &'static [u8], - Box) -> Arc + Send + 'static>, + Box) -> Boxed>> + Send + 'static>, )>; +/// Storage backend for documents. +#[derive(Debug, Clone)] +pub enum DocsStorage { + /// In-memory storage. + Memory, + /// File-based persistent storage. + Persistent(PathBuf), +} + /// Builder for the [`Node`]. /// /// You must supply a blob store and a document store. @@ -89,7 +102,7 @@ where gc_policy: GcPolicy, dns_resolver: Option, node_discovery: DiscoveryConfig, - docs_store: iroh_docs::store::fs::Store, + docs_store: Option, protocols: ProtocolBuilders, #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, @@ -139,7 +152,7 @@ impl Default for Builder { dns_resolver: None, rpc_endpoint: Default::default(), gc_policy: GcPolicy::Disabled, - docs_store: iroh_docs::store::Store::memory(), + docs_store: Some(DocsStorage::Memory), protocols: Default::default(), node_discovery: Default::default(), #[cfg(any(test, feature = "test-utils"))] @@ -153,7 +166,7 @@ impl Builder { /// Creates a new builder for [`Node`] using the given databases. pub fn with_db_and_store( blobs_store: D, - docs_store: iroh_docs::store::Store, + docs_store: DocsStorage, storage: StorageConfig, ) -> Self { Self { @@ -166,7 +179,7 @@ impl Builder { dns_resolver: None, rpc_endpoint: Default::default(), gc_policy: GcPolicy::Disabled, - docs_store, + docs_store: Some(docs_store), node_discovery: Default::default(), protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] @@ -193,8 +206,7 @@ where let blobs_store = iroh_blobs::store::fs::Store::load(&blob_dir) .await .with_context(|| format!("Failed to load iroh database from {}", blob_dir.display()))?; - let docs_store = - iroh_docs::store::fs::Store::persistent(IrohPaths::DocsDatabase.with_root(root))?; + let docs_store = DocsStorage::Persistent(IrohPaths::DocsDatabase.with_root(root)); let v0 = blobs_store .import_flat_store(iroh_blobs::store::fs::FlatStorePaths { @@ -230,7 +242,7 @@ where relay_mode: self.relay_mode, dns_resolver: self.dns_resolver, gc_policy: self.gc_policy, - docs_store, + docs_store: Some(docs_store), node_discovery: self.node_discovery, protocols: Default::default(), #[cfg(any(test, feature = "test-utils"))] @@ -296,6 +308,12 @@ where self } + /// Disables documents support on this node completely. + pub fn disable_docs(mut self) -> Self { + self.docs_store = None; + self + } + /// Sets the relay servers to assist in establishing connectivity. /// /// Relay servers are used to discover other nodes by `PublicKey` and also help @@ -359,7 +377,7 @@ where pub fn accept( mut self, alpn: &'static [u8], - protocol: impl FnOnce(Node) -> Arc + Send + 'static, + protocol: impl FnOnce(Node) -> Boxed>> + Send + 'static, ) -> Self { self.protocols.push((alpn, Box::new(protocol))); self @@ -387,10 +405,68 @@ where /// This will create the underlying network server and spawn a tokio task accepting /// connections. The returned [`Node`] can be used to control the task as well as /// get information about it. - pub async fn spawn(self) -> Result> { + pub async fn spawn(mut self) -> Result> { + // Register the core iroh protocols. + // Register blobs. + let lp = LocalPoolHandle::new(num_cpus::get()); + let blobs_proto = BlobsProtocol::new(self.blobs_store.clone(), lp.clone()); + self = self.accept(iroh_blobs::protocol::ALPN, move |_node| { + Box::pin(async move { + let blobs: Arc = Arc::new(blobs_proto); + Ok(blobs) + }) + }); + + // Register gossip. + self = self.accept(GOSSIP_ALPN, |node| { + Box::pin(async move { + let addr = node.endpoint().my_addr().await?; + let gossip = + Gossip::from_endpoint(node.endpoint().clone(), Default::default(), &addr.info); + let gossip: Arc = Arc::new(gossip); + Ok(gossip) + }) + }); + + if let Some(docs_store) = &self.docs_store { + // register the docs protocol. + let docs_store = match docs_store { + DocsStorage::Memory => iroh_docs::store::fs::Store::memory(), + DocsStorage::Persistent(path) => iroh_docs::store::fs::Store::persistent(path)?, + }; + // load or create the default author for documents + let default_author_storage = match self.storage { + StorageConfig::Persistent(ref root) => { + let path = IrohPaths::DefaultAuthor.with_root(root); + DefaultAuthorStorage::Persistent(path) + } + StorageConfig::Mem => DefaultAuthorStorage::Mem, + }; + let blobs_store = self.blobs_store.clone(); + self = self.accept(DOCS_ALPN, |node| { + Box::pin(async move { + let gossip = node + .get_protocol::(GOSSIP_ALPN) + .context("gossip not found")?; + let sync = Engine::spawn( + node.endpoint().clone(), + (*gossip).clone(), + docs_store, + blobs_store, + node.downloader().clone(), + default_author_storage, + ) + .await?; + let sync = DocsEngine(sync); + let sync: Arc = Arc::new(sync); + Ok(sync) + }) + }); + } + // We clone the blob store to shut it down in case the node fails to spawn. let blobs_store = self.blobs_store.clone(); - match self.spawn_inner().await { + match self.spawn_inner(lp).await { Ok(node) => Ok(node), Err(err) => { debug!("failed to spawn node, shutting down"); @@ -400,9 +476,8 @@ where } } - async fn spawn_inner(mut self) -> Result> { + async fn spawn_inner(mut self, lp: LocalPoolHandle) -> Result> { trace!("spawning node"); - let lp = LocalPoolHandle::new(num_cpus::get()); let mut transport_config = quinn::TransportConfig::default(); transport_config @@ -465,47 +540,12 @@ where debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); - let addr = endpoint.my_addr().await?; - - // initialize the gossip protocol - let gossip = Gossip::from_endpoint(endpoint.clone(), Default::default(), &addr.info); + let blobs_store = self.blobs_store.clone(); + let mut tasks = JoinSet::new(); // initialize the downloader let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone()); - // load or create the default author for documents - let default_author_storage = match self.storage { - StorageConfig::Persistent(ref root) => { - let path = IrohPaths::DefaultAuthor.with_root(root); - DefaultAuthorStorage::Persistent(path) - } - StorageConfig::Mem => DefaultAuthorStorage::Mem, - }; - - // spawn the docs engine - let sync = Engine::spawn( - endpoint.clone(), - gossip.clone(), - self.docs_store, - self.blobs_store.clone(), - downloader.clone(), - default_author_storage, - ) - .await?; - let sync_db = sync.sync.clone(); - let sync = DocsEngine(sync); - - let gc_task = if let GcPolicy::Interval(gc_period) = self.gc_policy { - tracing::info!("Starting GC task with interval {:?}", gc_period); - let db = self.blobs_store.clone(); - let gc_done_callback = self.gc_done_callback.take(); - - let task = - lp.spawn_pinned(move || Self::gc_loop(db, sync_db, gc_period, gc_done_callback)); - Some(task.into()) - } else { - None - }; let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); @@ -517,9 +557,8 @@ where secret_key: self.secret_key, controller, cancel_token, - gc_task, + tasks: Default::default(), rt: lp.clone(), - sync, downloader, }); @@ -531,47 +570,45 @@ where }; for (alpn, p) in self.protocols { - let protocol = p(node.clone()); + let protocol = p(node.clone()).await?; protocols.insert(alpn, protocol); } let task = { let protocols = protocols.clone(); - let gossip = gossip.clone(); - let handler = rpc::Handler { - inner: inner.clone(), - }; let me = endpoint.node_id().fmt_short(); - let ep = endpoint.clone(); + let inner = inner.clone(); tokio::task::spawn( - async move { - Self::run( - ep, - protocols, - handler, - self.rpc_endpoint, - internal_rpc, - gossip, - ) - .await - } - .instrument(error_span!("node", %me)), + async move { Self::run(inner, protocols, self.rpc_endpoint, internal_rpc).await } + .instrument(error_span!("node", %me)), ) }; - node.task.set(task).expect("was empty"); + if let GcPolicy::Interval(gc_period) = self.gc_policy { + tracing::info!("Starting GC task with interval {:?}", gc_period); + let db = blobs_store.clone(); + let gc_done_callback = self.gc_done_callback.take(); + let sync = protocols.get::(DOCS_ALPN); + + tasks.spawn_local(Self::gc_loop(db, sync, gc_period, gc_done_callback)); + } + // spawn a task that updates the gossip endpoints. - // TODO: track task let mut stream = endpoint.local_endpoints(); - tokio::task::spawn(async move { - while let Some(eps) = stream.next().await { - if let Err(err) = gossip.update_endpoints(&eps) { - warn!("Failed to update gossip endpoints: {err:?}"); + let gossip = protocols.get::(GOSSIP_ALPN); + if let Some(gossip) = gossip { + tasks.spawn(async move { + while let Some(eps) = stream.next().await { + if let Err(err) = gossip.update_endpoints(&eps) { + warn!("Failed to update gossip endpoints: {err:?}"); + } } - } - warn!("failed to retrieve local endpoints"); - }); + warn!("failed to retrieve local endpoints"); + }); + } + + *(node.inner.tasks.lock().unwrap()) = Some(tasks); // Wait for a single endpoint update, to make sure // we found some endpoints @@ -585,13 +622,17 @@ where #[allow(clippy::too_many_arguments)] async fn run( - server: Endpoint, + inner: Arc>, protocols: ProtocolMap, - handler: rpc::Handler, rpc: E, internal_rpc: impl ServiceEndpoint, - gossip: Gossip, ) { + let server = inner.endpoint.clone(); + let docs = protocols.get::(DOCS_ALPN); + let handler = rpc::Handler { + inner: inner.clone(), + docs, + }; let rpc = RpcServer::new(rpc); let internal_rpc = RpcServer::new(internal_rpc); let (ipv4, ipv6) = server.local_addr(); @@ -603,13 +644,16 @@ where let cancel_token = handler.inner.cancel_token.clone(); - // forward our initial endpoints to the gossip protocol - // it may happen the the first endpoint update callback is missed because the gossip cell - // is only initialized once the endpoint is fully bound - if let Some(local_endpoints) = server.local_endpoints().next().await { - debug!(me = ?server.node_id(), "gossip initial update: {local_endpoints:?}"); - gossip.update_endpoints(&local_endpoints).ok(); + if let Some(gossip) = protocols.get::(GOSSIP_ALPN) { + // forward our initial endpoints to the gossip protocol + // it may happen the the first endpoint update callback is missed because the gossip cell + // is only initialized once the endpoint is fully bound + if let Some(local_endpoints) = server.local_endpoints().next().await { + debug!(me = ?server.node_id(), "gossip initial update: {local_endpoints:?}"); + gossip.update_endpoints(&local_endpoints).ok(); + } } + loop { tokio::select! { biased; @@ -617,9 +661,23 @@ where // clean shutdown of the blobs db to close the write transaction handler.inner.db.shutdown().await; - if let Err(err) = handler.inner.sync.shutdown().await { - warn!("sync shutdown error: {:?}", err); + // We cannot hold the RwLockReadGuard over an await point, + // so we have to manually loop, clone each protocol, and drop the read guard + // before awaiting shutdown. + let mut i = 0; + loop { + let protocol = { + let protocols = protocols.read(); + if let Some(protocol) = protocols.values().nth(i) { + protocol.clone() + } else { + break; + } + }; + protocol.shutdown().await; + i += 1; } + break }, // handle rpc requests. This will do nothing if rpc is not configured, since @@ -654,12 +712,11 @@ where continue; } }; - let gossip = gossip.clone(); - let inner = handler.inner.clone(); - let sync = handler.inner.sync.clone(); let protocols = protocols.clone(); - tokio::task::spawn(async move { - if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync, protocols).await { + let mut tasks_guard = inner.tasks.lock().unwrap(); + let tasks = tasks_guard.as_mut().expect("only empty after shutdown"); + tasks.spawn(async move { + if let Err(err) = handle_connection(connecting, alpn, protocols).await { warn!("Handling incoming connection ended with error: {err}"); } }); @@ -681,7 +738,7 @@ where async fn gc_loop( db: D, - ds: iroh_docs::actor::SyncHandle, + ds: Option>, gc_period: Duration, done_cb: Option>, ) { @@ -698,22 +755,24 @@ where tokio::time::sleep(gc_period).await; tracing::debug!("Starting GC"); live.clear(); - let doc_hashes = match ds.content_hashes().await { - Ok(hashes) => hashes, - Err(err) => { - tracing::warn!("Error getting doc hashes: {}", err); - continue 'outer; - } - }; - for hash in doc_hashes { - match hash { - Ok(hash) => { - live.insert(hash); - } + if let Some(ds) = &ds { + let doc_hashes = match ds.sync.content_hashes().await { + Ok(hashes) => hashes, Err(err) => { - tracing::error!("Error getting doc hash: {}", err); + tracing::warn!("Error getting doc hashes: {}", err); continue 'outer; } + }; + for hash in doc_hashes { + match hash { + Ok(hash) => { + live.insert(hash); + } + Err(err) => { + tracing::error!("Error getting doc hash: {}", err); + continue 'outer; + } + } } } @@ -773,37 +832,16 @@ impl Default for GcPolicy { } } -// TODO: Restructure this code to not take all these arguments. -#[allow(clippy::too_many_arguments)] -async fn handle_connection( +async fn handle_connection( connecting: iroh_net::endpoint::Connecting, alpn: String, - node: Arc>, - gossip: Gossip, - sync: DocsEngine, protocols: ProtocolMap, ) -> Result<()> { - match alpn.as_bytes() { - GOSSIP_ALPN => gossip.handle_connection(connecting.await?).await?, - DOCS_ALPN => sync.handle_connection(connecting).await?, - alpn if alpn == iroh_blobs::protocol::ALPN => { - let connection = connecting.await?; - iroh_blobs::provider::handle_connection( - connection, - node.db.clone(), - MockEventSender, - node.rt.clone(), - ) - .await - } - alpn => { - let protocol = protocols.get_any(alpn); - if let Some(protocol) = protocol { - protocol.accept(connecting).await?; - } else { - bail!("ignoring connection: unsupported ALPN protocol"); - } - } + let protocol = protocols.get_any(alpn.as_bytes()).clone(); + if let Some(protocol) = protocol { + protocol.accept(connecting).await?; + } else { + bail!("ignoring connection: unsupported ALPN protocol"); } Ok(()) } @@ -855,12 +893,3 @@ fn make_rpc_endpoint( Ok((rpc_endpoint, actual_rpc_port)) } - -#[derive(Debug, Clone)] -struct MockEventSender; - -impl iroh_blobs::provider::EventSender for MockEventSender { - fn send(&self, _event: iroh_blobs::provider::Event) -> futures_lite::future::Boxed<()> { - Box::pin(std::future::ready(())) - } -} diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index 139ebbda8a..3a099620fe 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -9,12 +9,19 @@ use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use iroh_net::endpoint::Connecting; +use crate::node::DocsEngine; + /// Trait for iroh protocol handlers. pub trait Protocol: Send + Sync + IntoArcAny + fmt::Debug + 'static { /// Handle an incoming connection. /// /// This runs on a freshly spawned tokio task so this can be long-running. fn accept(self: Arc, conn: Connecting) -> BoxedFuture>; + + /// Called when the node shuts down. + fn shutdown(self: Arc) -> BoxedFuture<()> { + Box::pin(async move {}) + } } /// Helper trait to facilite casting from `Arc` to `Arc`. @@ -30,8 +37,6 @@ impl IntoArcAny for T { } } -/// Map of registered protocol handlers. -#[allow(clippy::type_complexity)] #[derive(Debug, Clone, Default)] pub struct ProtocolMap(Arc>>>); @@ -55,4 +60,58 @@ impl ProtocolMap { pub(super) fn insert(&self, alpn: &'static [u8], protocol: Arc) { self.0.write().unwrap().insert(alpn, protocol); } + + pub(super) fn read( + &self, + ) -> std::sync::RwLockReadGuard>> { + self.0.read().unwrap() + } +} + +#[derive(Debug)] +pub(crate) struct BlobsProtocol { + rt: tokio_util::task::LocalPoolHandle, + store: S, +} + +impl BlobsProtocol { + pub fn new(store: S, rt: tokio_util::task::LocalPoolHandle) -> Self { + Self { rt, store } + } +} + +impl Protocol for BlobsProtocol { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { + iroh_blobs::provider::handle_connection( + conn.await?, + self.store.clone(), + MockEventSender, + self.rt.clone(), + ) + .await; + Ok(()) + }) + } +} + +#[derive(Debug, Clone)] +struct MockEventSender; + +impl iroh_blobs::provider::EventSender for MockEventSender { + fn send(&self, _event: iroh_blobs::provider::Event) -> futures_lite::future::Boxed<()> { + Box::pin(std::future::ready(())) + } +} + +impl Protocol for iroh_gossip::net::Gossip { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn.await?).await }) + } +} + +impl Protocol for DocsEngine { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + Box::pin(async move { self.handle_connection(conn.await?).await }) + } } diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index 6382b50d6a..92dfade8fb 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -7,7 +7,7 @@ use anyhow::{anyhow, ensure, Result}; use futures_buffered::BufferedStreamExt; use futures_lite::{Stream, StreamExt}; use genawaiter::sync::{Co, Gen}; -use iroh_base::rpc::RpcResult; +use iroh_base::rpc::{RpcError, RpcResult}; use iroh_blobs::downloader::{DownloadRequest, Downloader}; use iroh_blobs::export::ExportProgress; use iroh_blobs::format::collection::Collection; @@ -32,21 +32,25 @@ use quic_rpc::{ use tokio_util::task::LocalPoolHandle; use tracing::{debug, info}; -use crate::client::blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}; -use crate::client::tags::TagInfo; -use crate::client::NodeStatus; use crate::rpc_protocol::{ BlobAddPathRequest, BlobAddPathResponse, BlobAddStreamRequest, BlobAddStreamResponse, BlobAddStreamUpdate, BlobConsistencyCheckRequest, BlobDeleteBlobRequest, BlobDownloadRequest, BlobDownloadResponse, BlobExportRequest, BlobExportResponse, BlobListIncompleteRequest, BlobListRequest, BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest, - DocExportFileResponse, DocImportFileRequest, DocImportFileResponse, DocSetHashRequest, - ListTagsRequest, NodeAddrRequest, NodeConnectionInfoRequest, NodeConnectionInfoResponse, - NodeConnectionsRequest, NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest, - NodeShutdownRequest, NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest, - NodeWatchResponse, Request, RpcService, SetTagOption, + DocExportFileResponse, DocGetManyResponse, DocImportFileRequest, DocImportFileResponse, + DocListResponse, DocSetHashRequest, ListTagsRequest, NodeAddrRequest, + NodeConnectionInfoRequest, NodeConnectionInfoResponse, NodeConnectionsRequest, + NodeConnectionsResponse, NodeIdRequest, NodeRelayRequest, NodeShutdownRequest, + NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeWatchRequest, NodeWatchResponse, + Request, RpcService, SetTagOption, }; +use crate::{ + client::blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}, + node::DocsEngine, +}; +use crate::{client::tags::TagInfo, node::rpc::docs::ITER_CHANNEL_CAP}; +use crate::{client::NodeStatus, rpc_protocol::AuthorListResponse}; use super::NodeInner; @@ -61,6 +65,7 @@ const RPC_BLOB_GET_CHANNEL_CAP: usize = 2; #[derive(Debug, Clone)] pub(crate) struct Handler { pub(crate) inner: Arc>, + pub(crate) docs: Option>, } impl Handler { @@ -126,92 +131,164 @@ impl Handler { BlobAddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage), AuthorList(msg) => { chan.server_streaming(msg, handler, |handler, req| { - handler.inner.sync.author_list(req) + let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + if let Some(docs) = handler.docs { + docs.author_list(req, tx); + } else { + tx.send(Err(anyhow!("docs are disabled"))) + .expect("has capacity"); + } + rx.into_stream().map(|r| { + r.map(|author_id| AuthorListResponse { author_id }) + .map_err(Into::into) + }) }) .await } AuthorCreate(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_create(req).await + if let Some(docs) = handler.docs { + docs.author_create(req).await + } else { + Err(docs_disabled()) + } }) .await } AuthorImport(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_import(req).await + if let Some(docs) = handler.docs { + docs.author_import(req).await + } else { + Err(docs_disabled()) + } }) .await } AuthorExport(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_export(req).await + if let Some(docs) = handler.docs { + docs.author_export(req).await + } else { + Err(docs_disabled()) + } }) .await } AuthorDelete(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_delete(req).await + if let Some(docs) = handler.docs { + docs.author_delete(req).await + } else { + Err(docs_disabled()) + } }) .await } AuthorGetDefault(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_default(req) + if let Some(docs) = handler.docs { + Ok(docs.author_default(req)) + } else { + Err(docs_disabled()) + } }) .await } AuthorSetDefault(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_set_default(req).await + if let Some(docs) = handler.docs { + docs.author_set_default(req).await + } else { + Err(docs_disabled()) + } }) .await } DocOpen(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_open(req).await + if let Some(docs) = handler.docs { + docs.doc_open(req).await + } else { + Err(docs_disabled()) + } }) .await } DocClose(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_close(req).await + if let Some(docs) = handler.docs { + docs.doc_close(req).await + } else { + Err(docs_disabled()) + } }) .await } DocStatus(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_status(req).await + if let Some(docs) = handler.docs { + docs.doc_status(req).await + } else { + Err(docs_disabled()) + } }) .await } DocList(msg) => { chan.server_streaming(msg, handler, |handler, req| { - handler.inner.sync.doc_list(req) + let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + if let Some(docs) = handler.docs { + docs.doc_list(req, tx); + } else { + tx.send(Err(anyhow!("docs are disabled"))) + .expect("has capacity"); + } + rx.into_stream().map(|r| { + r.map(|(id, capability)| DocListResponse { id, capability }) + .map_err(Into::into) + }) }) .await } DocCreate(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_create(req).await + if let Some(docs) = handler.docs { + docs.doc_create(req).await + } else { + Err(docs_disabled()) + } }) .await } DocDrop(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_drop(req).await + if let Some(docs) = handler.docs { + docs.doc_drop(req).await + } else { + Err(docs_disabled()) + } }) .await } DocImport(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_import(req).await + if let Some(docs) = handler.docs { + docs.doc_import(req).await + } else { + Err(docs_disabled()) + } }) .await } DocSet(msg) => { let bao_store = handler.inner.db.clone(); chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_set(&bao_store, req).await + if let Some(docs) = handler.docs { + docs.doc_set(&bao_store, req).await + } else { + Err(docs_disabled()) + } }) .await } @@ -225,67 +302,117 @@ impl Handler { } DocDel(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_del(req).await + if let Some(docs) = handler.docs { + docs.doc_del(req).await + } else { + Err(docs_disabled()) + } }) .await } DocSetHash(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_set_hash(req).await + if let Some(docs) = handler.docs { + docs.doc_set_hash(req).await + } else { + Err(docs_disabled()) + } }) .await } DocGet(msg) => { chan.server_streaming(msg, handler, |handler, req| { - handler.inner.sync.doc_get_many(req) + let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + if let Some(docs) = handler.docs { + docs.doc_get_many(req, tx); + } else { + tx.send(Err(anyhow!("docs are disabled"))) + .expect("has capacity"); + } + rx.into_stream().map(|r| { + r.map(|entry| DocGetManyResponse { entry }) + .map_err(Into::into) + }) }) .await } DocGetExact(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_get_exact(req).await + if let Some(docs) = handler.docs { + docs.doc_get_exact(req).await + } else { + Err(docs_disabled()) + } }) .await } DocStartSync(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_start_sync(req).await + if let Some(docs) = handler.docs { + docs.doc_start_sync(req).await + } else { + Err(docs_disabled()) + } }) .await } DocLeave(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_leave(req).await + if let Some(docs) = handler.docs { + docs.doc_leave(req).await + } else { + Err(docs_disabled()) + } }) .await } DocShare(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_share(req).await + if let Some(docs) = handler.docs { + docs.doc_share(req).await + } else { + Err(docs_disabled()) + } }) .await } DocSubscribe(msg) => { chan.try_server_streaming(msg, handler, |handler, req| async move { - handler.inner.sync.doc_subscribe(req).await + if let Some(docs) = handler.docs { + docs.doc_subscribe(req).await + } else { + Err(docs_disabled()) + } }) .await } DocSetDownloadPolicy(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_set_download_policy(req).await + if let Some(docs) = handler.docs { + docs.doc_set_download_policy(req).await + } else { + Err(docs_disabled()) + } }) .await } DocGetDownloadPolicy(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_get_download_policy(req).await + if let Some(docs) = handler.docs { + docs.doc_get_download_policy(req).await + } else { + Err(docs_disabled()) + } }) .await } DocGetSyncPeers(msg) => { chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_get_sync_peers(req).await + if let Some(docs) = handler.docs { + docs.doc_get_sync_peers(req).await + } else { + Err(docs_disabled()) + } }) .await } @@ -463,6 +590,7 @@ impl Handler { msg: DocImportFileRequest, progress: flume::Sender, ) -> anyhow::Result<()> { + let docs = self.docs.ok_or_else(|| anyhow!("docs are disabled"))?; use crate::client::docs::ImportProgress as DocImportProgress; use iroh_blobs::store::ImportMode; use std::collections::BTreeMap; @@ -515,16 +643,14 @@ impl Handler { let hash_and_format = temp_tag.inner(); let HashAndFormat { hash, .. } = *hash_and_format; - self.inner - .sync - .doc_set_hash(DocSetHashRequest { - doc_id, - author_id, - key: key.clone(), - hash, - size, - }) - .await?; + docs.doc_set_hash(DocSetHashRequest { + doc_id, + author_id, + key: key.clone(), + hash, + size, + }) + .await?; drop(temp_tag); progress.send(DocImportProgress::AllDone { key }).await?; Ok(()) @@ -549,6 +675,7 @@ impl Handler { msg: DocExportFileRequest, progress: flume::Sender, ) -> anyhow::Result<()> { + let _docs = self.docs.ok_or_else(|| anyhow!("docs are disabled"))?; let progress = FlumeProgressSender::new(progress); let DocExportFileRequest { entry, path, mode } = msg; let key = bytes::Bytes::from(entry.key().to_vec()); @@ -1118,3 +1245,7 @@ where res.map_err(Into::into) } + +fn docs_disabled() -> RpcError { + anyhow!("docs are disabled").into() +} diff --git a/iroh/src/node/rpc/docs.rs b/iroh/src/node/rpc/docs.rs index a0433a803e..00762945b4 100644 --- a/iroh/src/node/rpc/docs.rs +++ b/iroh/src/node/rpc/docs.rs @@ -3,7 +3,9 @@ use anyhow::anyhow; use futures_lite::Stream; use iroh_blobs::{store::Store as BaoStore, BlobFormat}; -use iroh_docs::{Author, DocTicket, NamespaceSecret}; +use iroh_docs::{ + Author, AuthorId, CapabilityKind, DocTicket, NamespaceId, NamespaceSecret, SignedEntry, +}; use tokio_stream::StreamExt; use crate::client::docs::ShareMode; @@ -11,21 +13,20 @@ use crate::node::DocsEngine; use crate::rpc_protocol::{ AuthorCreateRequest, AuthorCreateResponse, AuthorDeleteRequest, AuthorDeleteResponse, AuthorExportRequest, AuthorExportResponse, AuthorGetDefaultRequest, AuthorGetDefaultResponse, - AuthorImportRequest, AuthorImportResponse, AuthorListRequest, AuthorListResponse, - AuthorSetDefaultRequest, AuthorSetDefaultResponse, DocCloseRequest, DocCloseResponse, - DocCreateRequest, DocCreateResponse, DocDelRequest, DocDelResponse, DocDropRequest, - DocDropResponse, DocGetDownloadPolicyRequest, DocGetDownloadPolicyResponse, DocGetExactRequest, - DocGetExactResponse, DocGetManyRequest, DocGetManyResponse, DocGetSyncPeersRequest, - DocGetSyncPeersResponse, DocImportRequest, DocImportResponse, DocLeaveRequest, - DocLeaveResponse, DocListRequest, DocListResponse, DocOpenRequest, DocOpenResponse, - DocSetDownloadPolicyRequest, DocSetDownloadPolicyResponse, DocSetHashRequest, - DocSetHashResponse, DocSetRequest, DocSetResponse, DocShareRequest, DocShareResponse, - DocStartSyncRequest, DocStartSyncResponse, DocStatusRequest, DocStatusResponse, - DocSubscribeRequest, DocSubscribeResponse, RpcResult, + AuthorImportRequest, AuthorImportResponse, AuthorListRequest, AuthorSetDefaultRequest, + AuthorSetDefaultResponse, DocCloseRequest, DocCloseResponse, DocCreateRequest, + DocCreateResponse, DocDelRequest, DocDelResponse, DocDropRequest, DocDropResponse, + DocGetDownloadPolicyRequest, DocGetDownloadPolicyResponse, DocGetExactRequest, + DocGetExactResponse, DocGetManyRequest, DocGetSyncPeersRequest, DocGetSyncPeersResponse, + DocImportRequest, DocImportResponse, DocLeaveRequest, DocLeaveResponse, DocListRequest, + DocOpenRequest, DocOpenResponse, DocSetDownloadPolicyRequest, DocSetDownloadPolicyResponse, + DocSetHashRequest, DocSetHashResponse, DocSetRequest, DocSetResponse, DocShareRequest, + DocShareResponse, DocStartSyncRequest, DocStartSyncResponse, DocStatusRequest, + DocStatusResponse, DocSubscribeRequest, DocSubscribeResponse, RpcResult, }; /// Capacity for the flume channels to forward sync store iterators to async RPC streams. -const ITER_CHANNEL_CAP: usize = 64; +pub(super) const ITER_CHANNEL_CAP: usize = 64; #[allow(missing_docs)] impl DocsEngine { @@ -57,8 +58,8 @@ impl DocsEngine { pub fn author_list( &self, _req: AuthorListRequest, - ) -> impl Stream> { - let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + tx: flume::Sender>, + ) { let sync = self.sync.clone(); // we need to spawn a task to send our request to the sync handle, because the method // itself must be sync. @@ -68,10 +69,6 @@ impl DocsEngine { tx2.send_async(Err(err)).await.ok(); } }); - rx.into_stream().map(|r| { - r.map(|author_id| AuthorListResponse { author_id }) - .map_err(Into::into) - }) } pub async fn author_import(&self, req: AuthorImportRequest) -> RpcResult { @@ -108,8 +105,12 @@ impl DocsEngine { Ok(DocDropResponse {}) } - pub fn doc_list(&self, _req: DocListRequest) -> impl Stream> { - let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); + pub fn doc_list( + &self, + _req: DocListRequest, + tx: flume::Sender>, + ) { + // let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); let sync = self.sync.clone(); // we need to spawn a task to send our request to the sync handle, because the method // itself must be sync. @@ -119,10 +120,6 @@ impl DocsEngine { tx2.send_async(Err(err)).await.ok(); } }); - rx.into_stream().map(|r| { - r.map(|(id, capability)| DocListResponse { id, capability }) - .map_err(Into::into) - }) } pub async fn doc_open(&self, req: DocOpenRequest) -> RpcResult { @@ -249,9 +246,9 @@ impl DocsEngine { pub fn doc_get_many( &self, req: DocGetManyRequest, - ) -> impl Stream> { + tx: flume::Sender>, + ) { let DocGetManyRequest { doc_id, query } = req; - let (tx, rx) = flume::bounded(ITER_CHANNEL_CAP); let sync = self.sync.clone(); // we need to spawn a task to send our request to the sync handle, because the method // itself must be sync. @@ -261,10 +258,6 @@ impl DocsEngine { tx2.send_async(Err(err)).await.ok(); } }); - rx.into_stream().map(|r| { - r.map(|entry| DocGetManyResponse { entry }) - .map_err(Into::into) - }) } pub async fn doc_get_exact(&self, req: DocGetExactRequest) -> RpcResult { diff --git a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs index 8fe71e7d6a..8334590a11 100644 --- a/iroh/src/rpc_protocol.rs +++ b/iroh/src/rpc_protocol.rs @@ -439,7 +439,7 @@ pub struct AuthorCreateResponse { pub struct AuthorGetDefaultRequest; impl RpcMsg for AuthorGetDefaultRequest { - type Response = AuthorGetDefaultResponse; + type Response = RpcResult; } /// Response for [`AuthorGetDefaultRequest`] @@ -1153,7 +1153,7 @@ pub enum Response { AuthorList(RpcResult), AuthorCreate(RpcResult), - AuthorGetDefault(AuthorGetDefaultResponse), + AuthorGetDefault(RpcResult), AuthorSetDefault(RpcResult), AuthorImport(RpcResult), AuthorExport(RpcResult), diff --git a/iroh/tests/gc.rs b/iroh/tests/gc.rs index dcca0893b5..e032691df9 100644 --- a/iroh/tests/gc.rs +++ b/iroh/tests/gc.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::Result; use bao_tree::{blake3, io::sync::Outboard, ChunkRanges}; use bytes::Bytes; -use iroh::node::{self, Node}; +use iroh::node::{self, DocsStorage, Node}; use rand::RngCore; use iroh_blobs::{ @@ -41,17 +41,19 @@ async fn wrap_in_node(bao_store: S, gc_period: Duration) -> (Node, flume:: where S: iroh_blobs::store::Store, { - let doc_store = iroh_docs::store::Store::memory(); let (gc_send, gc_recv) = flume::unbounded(); - let node = - node::Builder::with_db_and_store(bao_store, doc_store, iroh::node::StorageConfig::Mem) - .gc_policy(iroh::node::GcPolicy::Interval(gc_period)) - .register_gc_done_cb(Box::new(move || { - gc_send.send(()).ok(); - })) - .spawn() - .await - .unwrap(); + let node = node::Builder::with_db_and_store( + bao_store, + DocsStorage::Memory, + iroh::node::StorageConfig::Mem, + ) + .gc_policy(iroh::node::GcPolicy::Interval(gc_period)) + .register_gc_done_cb(Box::new(move || { + gc_send.send(()).ok(); + })) + .spawn() + .await + .unwrap(); (node, gc_recv) } diff --git a/iroh/tests/provide.rs b/iroh/tests/provide.rs index 13376273dd..7b9abf9648 100644 --- a/iroh/tests/provide.rs +++ b/iroh/tests/provide.rs @@ -8,7 +8,7 @@ use std::{ use anyhow::{Context, Result}; use bytes::Bytes; use futures_lite::FutureExt; -use iroh::node::Builder; +use iroh::node::{Builder, DocsStorage}; use iroh_base::node_addr::AddrInfoOptions; use iroh_net::{defaults::default_relay_map, key::SecretKey, NodeAddr, NodeId}; use quic_rpc::transport::misc::DummyServerEndpoint; @@ -40,8 +40,8 @@ async fn dial(secret_key: SecretKey, peer: NodeAddr) -> anyhow::Result(db: D) -> Builder { - let store = iroh_docs::store::Store::memory(); - iroh::node::Builder::with_db_and_store(db, store, iroh::node::StorageConfig::Mem).bind_port(0) + iroh::node::Builder::with_db_and_store(db, DocsStorage::Memory, iroh::node::StorageConfig::Mem) + .bind_port(0) } #[tokio::test]