diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 74b7449ece5..d7950abed8f 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -3,27 +3,27 @@ //! A node is a server that serves various protocols. //! //! To shut down the node, call [`Node::shutdown`]. -use std::fmt::Debug; -use std::net::SocketAddr; use std::path::Path; use std::sync::Arc; +use std::{collections::BTreeSet, net::SocketAddr}; +use std::{fmt::Debug, time::Duration}; use anyhow::{anyhow, Result}; use futures_lite::StreamExt; use iroh_base::key::PublicKey; -use iroh_blobs::downloader::Downloader; -use iroh_blobs::store::Store as BaoStore; +use iroh_blobs::store::{GcMarkEvent, GcSweepEvent, Store as BaoStore}; +use iroh_blobs::{downloader::Downloader, protocol::Closed}; +use iroh_docs::actor::SyncHandle; use iroh_docs::engine::Engine; use iroh_gossip::net::Gossip; use iroh_net::key::SecretKey; use iroh_net::Endpoint; use iroh_net::{endpoint::DirectAddrsStream, util::SharedAbortingJoinHandle}; -use quic_rpc::transport::flume::FlumeConnection; -use quic_rpc::RpcClient; +use quic_rpc::{RpcServer, ServiceEndpoint}; +use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; -use tracing::debug; -use iroh_docs::actor::SyncHandle; +use tracing::{debug, error, info, warn}; use crate::{client::RpcService, node::protocol::ProtocolMap}; @@ -49,7 +49,6 @@ pub use protocol::ProtocolHandler; #[derive(Debug, Clone)] pub struct Node { inner: Arc>, - client: crate::client::MemIroh, task: SharedAbortingJoinHandle<()>, protocols: Arc, } @@ -57,12 +56,12 @@ pub struct Node { #[derive(derive_more::Debug)] struct NodeInner { db: D, - sync: DocsEngine, + docs: DocsEngine, endpoint: Endpoint, gossip: Gossip, secret_key: SecretKey, cancel_token: CancellationToken, - controller: FlumeConnection, + client: crate::client::MemIroh, #[debug("rt")] rt: LocalPoolHandle, downloader: Downloader, @@ -134,14 +133,9 @@ impl Node { self.inner.secret_key.public() } - /// Returns a handle that can be used to do RPC calls to the node internally. - pub fn controller(&self) -> crate::client::MemRpcClient { - RpcClient::new(self.inner.controller.clone()) - } - /// Return a client to control this node over an in-memory channel. pub fn client(&self) -> &crate::client::MemIroh { - &self.client + &self.inner.client } /// Returns a referenc to the used `LocalPoolHandle`. @@ -185,6 +179,7 @@ impl Node { /// Expose sync pub fn sync_handle(&self) -> &SyncHandle { &self.inner.sync.sync + } /// Returns a protocol handler for an ALPN. /// @@ -199,11 +194,11 @@ impl std::ops::Deref for Node { type Target = crate::client::MemIroh; fn deref(&self) -> &Self::Target { - &self.client + &self.inner.client } } -impl NodeInner { +impl NodeInner { async fn local_endpoint_addresses(&self) -> Result> { let endpoints = self .endpoint @@ -213,6 +208,243 @@ impl NodeInner { .ok_or(anyhow!("no endpoints found"))?; Ok(endpoints.into_iter().map(|x| x.addr).collect()) } + + async fn run( + self: Arc, + external_rpc: impl ServiceEndpoint, + internal_rpc: impl ServiceEndpoint, + protocols: Arc, + gc_policy: GcPolicy, + gc_done_callback: Option>, + ) { + let (ipv4, ipv6) = self.endpoint.bound_sockets(); + debug!( + "listening at: {}{}", + ipv4, + ipv6.map(|addr| format!(" and {addr}")).unwrap_or_default() + ); + debug!("rpc listening at: {:?}", external_rpc.local_addr()); + + let mut join_set = JoinSet::new(); + + // Setup the RPC servers. 
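+        // `external_rpc` serves requests from remote clients, while `internal_rpc` backs
+        // the node's in-memory `client`; both accept loops are polled in the select! below.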
+ let external_rpc = RpcServer::new(external_rpc); + let internal_rpc = RpcServer::new(internal_rpc); + + // TODO(frando): I think this is not needed as we do the same in a task just below. + // forward the initial endpoints to the gossip protocol. + // it may happen the the first endpoint update callback is missed because the gossip cell + // is only initialized once the endpoint is fully bound + if let Some(direct_addresses) = self.endpoint.direct_addresses().next().await { + debug!(me = ?self.endpoint.node_id(), "gossip initial update: {direct_addresses:?}"); + self.gossip.update_direct_addresses(&direct_addresses).ok(); + } + + // Spawn a task for the garbage collection. + if let GcPolicy::Interval(gc_period) = gc_policy { + let inner = self.clone(); + let handle = self + .rt + .spawn_pinned(move || inner.run_gc_loop(gc_period, gc_done_callback)); + // We cannot spawn tasks that run on the local pool directly into the join set, + // so instead we create a new task that supervises the local task. + join_set.spawn({ + async move { + if let Err(err) = handle.await { + return Err(anyhow::Error::from(err)); + } + Ok(()) + } + }); + } + + // Spawn a task that updates the gossip endpoints. + let inner = self.clone(); + join_set.spawn(async move { + let mut stream = inner.endpoint.direct_addresses(); + while let Some(eps) = stream.next().await { + if let Err(err) = inner.gossip.update_direct_addresses(&eps) { + warn!("Failed to update direct addresses for gossip: {err:?}"); + } + } + warn!("failed to retrieve local endpoints"); + Ok(()) + }); + + loop { + tokio::select! { + biased; + _ = self.cancel_token.cancelled() => { + break; + }, + // handle rpc requests. This will do nothing if rpc is not configured, since + // accept is just a pending future. + request = external_rpc.accept() => { + match request { + Ok((msg, chan)) => { + rpc::Handler::spawn_rpc_request(self.clone(), &mut join_set, msg, chan); + } + Err(e) => { + info!("rpc request error: {:?}", e); + } + } + }, + // handle internal rpc requests. + request = internal_rpc.accept() => { + match request { + Ok((msg, chan)) => { + rpc::Handler::spawn_rpc_request(self.clone(), &mut join_set, msg, chan); + } + Err(e) => { + info!("internal rpc request error: {:?}", e); + } + } + }, + // handle incoming p2p connections. + Some(connecting) = self.endpoint.accept() => { + let protocols = protocols.clone(); + join_set.spawn(async move { + handle_connection(connecting, protocols).await; + Ok(()) + }); + }, + // handle task terminations and quit on panics. + res = join_set.join_next(), if !join_set.is_empty() => { + if let Some(Err(err)) = res { + error!("Task failed: {err:?}"); + break; + } + }, + else => break, + } + } + + self.shutdown(protocols).await; + + // Abort remaining tasks. + join_set.shutdown().await; + } + + async fn shutdown(&self, protocols: Arc) { + // Shutdown the different parts of the node concurrently. + let error_code = Closed::ProviderTerminating; + // We ignore all errors during shutdown. + let _ = tokio::join!( + // Close the endpoint. + // Closing the Endpoint is the equivalent of calling Connection::close on all + // connections: Operations will immediately fail with ConnectionError::LocallyClosed. + // All streams are interrupted, this is not graceful. + self.endpoint + .clone() + .close(error_code.into(), error_code.reason()), + // Shutdown sync engine. + self.docs.shutdown(), + // Shutdown blobs store engine. + self.db.shutdown(), + // Shutdown protocol handlers. 
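+            // Each registered protocol handler is asked to clean up its own state.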
+ protocols.shutdown(), + ); + } + + async fn run_gc_loop( + self: Arc, + gc_period: Duration, + done_cb: Option>, + ) { + tracing::info!("Starting GC task with interval {:?}", gc_period); + let db = &self.db; + let docs = &self.docs; + let mut live = BTreeSet::new(); + 'outer: loop { + if let Err(cause) = db.gc_start().await { + tracing::debug!( + "unable to notify the db of GC start: {cause}. Shutting down GC loop." + ); + break; + } + // do delay before the two phases of GC + tokio::time::sleep(gc_period).await; + tracing::debug!("Starting GC"); + live.clear(); + + let doc_hashes = match docs.sync.content_hashes().await { + Ok(hashes) => hashes, + Err(err) => { + tracing::warn!("Error getting doc hashes: {}", err); + continue 'outer; + } + }; + for hash in doc_hashes { + match hash { + Ok(hash) => { + live.insert(hash); + } + Err(err) => { + tracing::error!("Error getting doc hash: {}", err); + continue 'outer; + } + } + } + + tracing::debug!("Starting GC mark phase"); + let mut stream = db.gc_mark(&mut live); + while let Some(item) = stream.next().await { + match item { + GcMarkEvent::CustomDebug(text) => { + tracing::debug!("{}", text); + } + GcMarkEvent::CustomWarning(text, _) => { + tracing::warn!("{}", text); + } + GcMarkEvent::Error(err) => { + tracing::error!("Fatal error during GC mark {}", err); + continue 'outer; + } + } + } + drop(stream); + + tracing::debug!("Starting GC sweep phase"); + let mut stream = db.gc_sweep(&live); + while let Some(item) = stream.next().await { + match item { + GcSweepEvent::CustomDebug(text) => { + tracing::debug!("{}", text); + } + GcSweepEvent::CustomWarning(text, _) => { + tracing::warn!("{}", text); + } + GcSweepEvent::Error(err) => { + tracing::error!("Fatal error during GC mark {}", err); + continue 'outer; + } + } + } + if let Some(ref cb) = done_cb { + cb(); + } + } + } +} + +async fn handle_connection( + mut connecting: iroh_net::endpoint::Connecting, + protocols: Arc, +) { + let alpn = match connecting.alpn().await { + Ok(alpn) => alpn, + Err(err) => { + warn!("Ignoring connection: invalid handshake: {:?}", err); + return; + } + }; + let Some(handler) = protocols.get(&alpn) else { + warn!("Ignoring connection: unsupported ALPN protocol"); + return; + }; + if let Err(err) = handler.accept(connecting).await { + warn!("Handling incoming connection ended with error: {err}"); + } } /// Wrapper around [`Engine`] so that we can implement our RPC methods directly. @@ -238,7 +470,7 @@ mod tests { use crate::{ client::blobs::{AddOutcome, WrapOption}, - rpc_protocol::{BlobAddPathRequest, BlobAddPathResponse, SetTagOption}, + rpc_protocol::SetTagOption, }; use super::*; @@ -299,18 +531,17 @@ mod tests { let _got_hash = tokio::time::timeout(Duration::from_secs(1), async move { let mut stream = node - .controller() - .server_streaming(BlobAddPathRequest { - path: Path::new(env!("CARGO_MANIFEST_DIR")).join("README.md"), - in_place: false, - tag: SetTagOption::Auto, - wrap: WrapOption::NoWrap, - }) + .blobs() + .add_from_path( + Path::new(env!("CARGO_MANIFEST_DIR")).join("README.md"), + false, + SetTagOption::Auto, + WrapOption::NoWrap, + ) .await?; - while let Some(item) = stream.next().await { - let BlobAddPathResponse(progress) = item?; - match progress { + while let Some(progress) = stream.next().await { + match progress? { AddProgress::AllDone { hash, .. 
} => { return Ok(hash); } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 5a266127cf7..a3719b4503a 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -1,5 +1,4 @@ use std::{ - collections::BTreeSet, net::{Ipv4Addr, SocketAddrV4}, path::{Path, PathBuf}, sync::Arc, @@ -11,8 +10,7 @@ use futures_lite::StreamExt; use iroh_base::key::SecretKey; use iroh_blobs::{ downloader::Downloader, - protocol::Closed, - store::{GcMarkEvent, GcSweepEvent, Map, Store as BaoStore}, + store::{Map, Store as BaoStore}, }; use iroh_docs::engine::{DefaultAuthorStorage, Engine}; use iroh_docs::net::DOCS_ALPN; @@ -27,12 +25,11 @@ use quic_rpc::{ transport::{ flume::FlumeServerEndpoint, misc::DummyServerEndpoint, quinn::QuinnServerEndpoint, }, - RpcServer, ServiceEndpoint, + ServiceEndpoint, }; use serde::{Deserialize, Serialize}; -use tokio::task::JoinSet; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; -use tracing::{debug, error, error_span, info, trace, warn, Instrument}; +use tracing::{debug, error_span, trace, Instrument}; use crate::{ client::RPC_ALPN, @@ -44,7 +41,7 @@ use crate::{ util::{fs::load_secret_key, path::IrohPaths}, }; -use super::{rpc, rpc_status::RpcStatus, DocsEngine, Node, NodeInner}; +use super::{rpc_status::RpcStatus, DocsEngine, Node, NodeInner}; /// Default bind address for the node. /// 11204 is "iroh" in leetspeak @@ -105,6 +102,18 @@ pub enum StorageConfig { Persistent(PathBuf), } +impl StorageConfig { + fn default_author_storage(&self) -> DefaultAuthorStorage { + match self { + StorageConfig::Persistent(ref root) => { + let path = IrohPaths::DefaultAuthor.with_root(root); + DefaultAuthorStorage::Persistent(path) + } + StorageConfig::Mem => DefaultAuthorStorage::Mem, + } + } +} + /// Configuration for node discovery. #[derive(Debug, Default)] pub enum DiscoveryConfig { @@ -397,59 +406,60 @@ where trace!("building node"); let lp = LocalPoolHandle::new(num_cpus::get()); - let mut transport_config = quinn::TransportConfig::default(); - transport_config - .max_concurrent_bidi_streams(MAX_STREAMS.try_into()?) - .max_concurrent_uni_streams(0u32.into()); - - let discovery: Option> = match self.node_discovery { - DiscoveryConfig::None => None, - DiscoveryConfig::Custom(discovery) => Some(discovery), - DiscoveryConfig::Default => { - let discovery = ConcurrentDiscovery::from_services(vec![ - // Enable DNS discovery by default - Box::new(DnsDiscovery::n0_dns()), - // Enable pkarr publishing by default - Box::new(PkarrPublisher::n0_dns(self.secret_key.clone())), - ]); - Some(Box::new(discovery)) - } - }; + let endpoint = { + let mut transport_config = quinn::TransportConfig::default(); + transport_config + .max_concurrent_bidi_streams(MAX_STREAMS.try_into()?) 
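+                // Incoming uni-directional streams are not used by the built-in protocols, so allow none.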
+ .max_concurrent_uni_streams(0u32.into()); + + let discovery: Option> = match self.node_discovery { + DiscoveryConfig::None => None, + DiscoveryConfig::Custom(discovery) => Some(discovery), + DiscoveryConfig::Default => { + let discovery = ConcurrentDiscovery::from_services(vec![ + // Enable DNS discovery by default + Box::new(DnsDiscovery::n0_dns()), + // Enable pkarr publishing by default + Box::new(PkarrPublisher::n0_dns(self.secret_key.clone())), + ]); + Some(Box::new(discovery)) + } + }; - let endpoint = Endpoint::builder() - .secret_key(self.secret_key.clone()) - .proxy_from_env() - .keylog(self.keylog) - .transport_config(transport_config) - .concurrent_connections(MAX_CONNECTIONS) - .relay_mode(self.relay_mode); - let endpoint = match discovery { - Some(discovery) => endpoint.discovery(discovery), - None => endpoint, - }; - let endpoint = match self.dns_resolver { - Some(resolver) => endpoint.dns_resolver(resolver), - None => endpoint, - }; + let endpoint = Endpoint::builder() + .secret_key(self.secret_key.clone()) + .proxy_from_env() + .keylog(self.keylog) + .transport_config(transport_config) + .concurrent_connections(MAX_CONNECTIONS) + .relay_mode(self.relay_mode); + let endpoint = match discovery { + Some(discovery) => endpoint.discovery(discovery), + None => endpoint, + }; + let endpoint = match self.dns_resolver { + Some(resolver) => endpoint.dns_resolver(resolver), + None => endpoint, + }; - #[cfg(any(test, feature = "test-utils"))] - let endpoint = - endpoint.insecure_skip_relay_cert_verify(self.insecure_skip_relay_cert_verify); + #[cfg(any(test, feature = "test-utils"))] + let endpoint = + endpoint.insecure_skip_relay_cert_verify(self.insecure_skip_relay_cert_verify); - let endpoint = match self.storage { - StorageConfig::Persistent(ref root) => { - let peers_data_path = IrohPaths::PeerData.with_root(root); - endpoint.peers_data_path(peers_data_path) - } - StorageConfig::Mem => endpoint, + let endpoint = match self.storage { + StorageConfig::Persistent(ref root) => { + let peers_data_path = IrohPaths::PeerData.with_root(root); + endpoint.peers_data_path(peers_data_path) + } + StorageConfig::Mem => endpoint, + }; + let bind_port = self.bind_port.unwrap_or(DEFAULT_BIND_PORT); + endpoint.bind(bind_port).await? 
}; - let bind_port = self.bind_port.unwrap_or(DEFAULT_BIND_PORT); - let endpoint = endpoint.bind(bind_port).await?; - trace!("created quinn endpoint"); - - let cancel_token = CancellationToken::new(); + trace!("created endpoint"); let addr = endpoint.node_addr().await?; + trace!("endpoint address: {addr:?}"); // initialize the gossip protocol let gossip = Gossip::from_endpoint(endpoint.clone(), Default::default(), &addr.info); @@ -458,235 +468,47 @@ where let downloader = Downloader::new(self.blobs_store.clone(), endpoint.clone(), lp.clone()); // load or create the default author for documents - let default_author_storage = match self.storage { - StorageConfig::Persistent(ref root) => { - let path = IrohPaths::DefaultAuthor.with_root(root); - DefaultAuthorStorage::Persistent(path) - } - StorageConfig::Mem => DefaultAuthorStorage::Mem, - }; - // spawn the docs engine - let sync = Engine::spawn( - endpoint.clone(), - gossip.clone(), - self.docs_store, - self.blobs_store.clone(), - downloader.clone(), - default_author_storage, - ) - .await?; - let sync = DocsEngine(sync); + let docs = DocsEngine( + Engine::spawn( + endpoint.clone(), + gossip.clone(), + self.docs_store, + self.blobs_store.clone(), + downloader.clone(), + self.storage.default_author_storage(), + ) + .await?, + ); // Initialize the internal RPC connection. let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); - debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); let inner = Arc::new(NodeInner { - db: self.blobs_store.clone(), - sync, - endpoint: endpoint.clone(), + db: self.blobs_store, + docs, + endpoint, secret_key: self.secret_key, - controller, - cancel_token, + client, + cancel_token: CancellationToken::new(), rt: lp, downloader, gossip, }); - let node = ProtocolBuilder { + let protocol_builder = ProtocolBuilder { inner, - client, protocols: Default::default(), internal_rpc, + external_rpc: self.rpc_endpoint, gc_policy: self.gc_policy, gc_done_callback: self.gc_done_callback, - rpc_endpoint: self.rpc_endpoint, }; - let node = node.register_iroh_protocols(); + let protocol_builder = protocol_builder.register_iroh_protocols(); - Ok(node) - } - - async fn run( - inner: Arc>, - rpc: E, - internal_rpc: impl ServiceEndpoint, - protocols: Arc, - mut join_set: JoinSet>, - ) { - let endpoint = inner.endpoint.clone(); - - let handler = rpc::Handler { - inner: inner.clone(), - }; - let rpc = RpcServer::new(rpc); - let internal_rpc = RpcServer::new(internal_rpc); - let (ipv4, ipv6) = endpoint.bound_sockets(); - debug!( - "listening at: {}{}", - ipv4, - ipv6.map(|addr| format!(" and {addr}")).unwrap_or_default() - ); - - let cancel_token = handler.inner.cancel_token.clone(); - - // forward the initial endpoints to the gossip protocol. - // it may happen the the first endpoint update callback is missed because the gossip cell - // is only initialized once the endpoint is fully bound - if let Some(direct_addresses) = endpoint.direct_addresses().next().await { - debug!(me = ?endpoint.node_id(), "gossip initial update: {direct_addresses:?}"); - inner.gossip.update_direct_addresses(&direct_addresses).ok(); - } - - loop { - tokio::select! { - biased; - _ = cancel_token.cancelled() => { - break; - }, - // handle rpc requests. This will do nothing if rpc is not configured, since - // accept is just a pending future. 
- request = rpc.accept() => { - match request { - Ok((msg, chan)) => { - handler.handle_rpc_request(msg, chan); - } - Err(e) => { - info!("rpc request error: {:?}", e); - } - } - }, - // handle internal rpc requests. - request = internal_rpc.accept() => { - match request { - Ok((msg, chan)) => { - handler.handle_rpc_request(msg, chan); - } - Err(e) => { - info!("internal rpc request error: {:?}", e); - } - } - }, - // handle incoming p2p connections. - Some(connecting) = endpoint.accept() => { - let protocols = protocols.clone(); - join_set.spawn(async move { - handle_connection(connecting, protocols).await; - Ok(()) - }); - }, - // handle task terminations and quit on panics. - res = join_set.join_next(), if !join_set.is_empty() => { - if let Some(Err(err)) = res { - error!("Task failed: {err:?}"); - break; - } - }, - else => break, - } - } - - // Shutdown the different parts of the node concurrently. - let error_code = Closed::ProviderTerminating; - // We ignore all errors during shutdown. - let _ = tokio::join!( - // Close the endpoint. - // Closing the Endpoint is the equivalent of calling Connection::close on all - // connections: Operations will immediately fail with ConnectionError::LocallyClosed. - // All streams are interrupted, this is not graceful. - endpoint.close(error_code.into(), error_code.reason()), - // Shutdown sync engine. - inner.sync.shutdown(), - // Shutdown blobs store engine. - inner.db.shutdown(), - // Shutdown protocol handlers. - protocols.shutdown(), - ); - - // Abort remaining tasks. - join_set.shutdown().await; - } - - async fn gc_loop( - db: D, - ds: DocsEngine, - gc_period: Duration, - done_cb: Option>, - ) { - let mut live = BTreeSet::new(); - tracing::debug!("GC loop starting {:?}", gc_period); - 'outer: loop { - if let Err(cause) = db.gc_start().await { - tracing::debug!( - "unable to notify the db of GC start: {cause}. Shutting down GC loop." 
- ); - break; - } - // do delay before the two phases of GC - tokio::time::sleep(gc_period).await; - tracing::debug!("Starting GC"); - live.clear(); - - let doc_hashes = match ds.sync.content_hashes().await { - Ok(hashes) => hashes, - Err(err) => { - tracing::warn!("Error getting doc hashes: {}", err); - continue 'outer; - } - }; - for hash in doc_hashes { - match hash { - Ok(hash) => { - live.insert(hash); - } - Err(err) => { - tracing::error!("Error getting doc hash: {}", err); - continue 'outer; - } - } - } - - tracing::debug!("Starting GC mark phase"); - let mut stream = db.gc_mark(&mut live); - while let Some(item) = stream.next().await { - match item { - GcMarkEvent::CustomDebug(text) => { - tracing::debug!("{}", text); - } - GcMarkEvent::CustomWarning(text, _) => { - tracing::warn!("{}", text); - } - GcMarkEvent::Error(err) => { - tracing::error!("Fatal error during GC mark {}", err); - continue 'outer; - } - } - } - drop(stream); - - tracing::debug!("Starting GC sweep phase"); - let mut stream = db.gc_sweep(&live); - while let Some(item) = stream.next().await { - match item { - GcSweepEvent::CustomDebug(text) => { - tracing::debug!("{}", text); - } - GcSweepEvent::CustomWarning(text, _) => { - tracing::warn!("{}", text); - } - GcSweepEvent::Error(err) => { - tracing::error!("Fatal error during GC mark {}", err); - continue 'outer; - } - } - } - if let Some(ref cb) = done_cb { - cb(); - } - } + Ok(protocol_builder) } } @@ -701,9 +523,8 @@ where #[derive(derive_more::Debug)] pub struct ProtocolBuilder { inner: Arc>, - client: crate::client::MemIroh, internal_rpc: FlumeServerEndpoint, - rpc_endpoint: E, + external_rpc: E, protocols: ProtocolMap, #[debug("callback")] gc_done_callback: Option>, @@ -769,7 +590,7 @@ impl> ProtocolBuilde /// Note that RPC calls performed with the client will not complete until the node is /// spawned. pub fn client(&self) -> &crate::client::MemIroh { - &self.client + &self.inner.client } /// Returns the [`Endpoint`] of the node. @@ -817,7 +638,7 @@ impl> ProtocolBuilde self = self.accept(GOSSIP_ALPN, Arc::new(gossip)); // Register docs. - let docs = self.inner.sync.clone(); + let docs = self.inner.docs.clone(); self = self.accept(DOCS_ALPN, Arc::new(docs)); self @@ -827,97 +648,63 @@ impl> ProtocolBuilde pub async fn spawn(self) -> Result> { let Self { inner, - client, internal_rpc, - rpc_endpoint, + external_rpc, protocols, gc_done_callback, gc_policy, } = self; let protocols = Arc::new(protocols); - let protocols_clone = protocols.clone(); - - // Create the actual spawn future in an async block so that we can shutdown the protocols in case of - // error. - let node_fut = async move { - let mut join_set = JoinSet::new(); - - // Spawn a task for the garbage collection. - if let GcPolicy::Interval(gc_period) = gc_policy { - tracing::info!("Starting GC task with interval {:?}", gc_period); - let lp = inner.rt.clone(); - let docs = inner.sync.clone(); - let blobs_store = inner.db.clone(); - let handle = lp.spawn_pinned(move || { - Builder::::gc_loop(blobs_store, docs, gc_period, gc_done_callback) - }); - // We cannot spawn tasks that run on the local pool directly into the join set, - // so instead we create a new task that supervises the local task. - join_set.spawn(async move { - if let Err(err) = handle.await { - return Err(anyhow::Error::from(err)); - } - Ok(()) - }); - } - - // Spawn a task that updates the gossip endpoints. 
- let mut stream = inner.endpoint.direct_addresses(); - let gossip = inner.gossip.clone(); - join_set.spawn(async move { - while let Some(eps) = stream.next().await { - if let Err(err) = gossip.update_direct_addresses(&eps) { - warn!("Failed to update direct addresses for gossip: {err:?}"); - } - } - warn!("failed to retrieve local endpoints"); - Ok(()) - }); - - // Update the endpoint with our alpns. - let alpns = protocols - .alpns() - .map(|alpn| alpn.to_vec()) - .collect::>(); - inner.endpoint.set_alpns(alpns)?; - - // Spawn the main task and store it in the node for structured termination in shutdown. - let task = tokio::task::spawn( - Builder::run( - inner.clone(), - rpc_endpoint, - internal_rpc, - protocols.clone(), - join_set, - ) - .instrument(error_span!("node", me=%inner.endpoint.node_id().fmt_short())), - ); - - let node = Node { - inner, - client, - protocols, - task: task.into(), - }; - - // Wait for a single endpoint update, to make sure - // we found some endpoints - tokio::time::timeout(ENDPOINT_WAIT, node.endpoint().direct_addresses().next()) - .await - .context("waiting for endpoint")? - .context("no endpoints")?; + let node_id = inner.endpoint.node_id(); + + // Update the endpoint with our alpns. + let alpns = protocols + .alpns() + .map(|alpn| alpn.to_vec()) + .collect::>(); + if let Err(err) = inner.endpoint.set_alpns(alpns) { + inner.shutdown(protocols).await; + return Err(err); + } - Ok(node) + // Spawn the main task and store it in the node for structured termination in shutdown. + let fut = inner + .clone() + .run( + external_rpc, + internal_rpc, + protocols.clone(), + gc_policy, + gc_done_callback, + ) + .instrument(error_span!("node", me=%node_id.fmt_short())); + let task = tokio::task::spawn(fut); + + let node = Node { + inner, + protocols, + task: task.into(), }; - match node_fut.await { - Ok(node) => Ok(node), - Err(err) => { - // Shutdown the protocols in case of error. - protocols_clone.shutdown().await; - Err(err) + // Wait for a single direct address update, to make sure + // we found at least one direct address. + let wait_for_endpoints = { + let node = node.clone(); + async move { + tokio::time::timeout(ENDPOINT_WAIT, node.endpoint().direct_addresses().next()) + .await + .context("waiting for endpoint")? 
+ .context("no endpoints")?; + Ok(()) } + }; + + if let Err(err) = wait_for_endpoints.await { + node.shutdown().await.ok(); + return Err(err); } + + Ok(node) } } @@ -936,26 +723,6 @@ impl Default for GcPolicy { } } -async fn handle_connection( - mut connecting: iroh_net::endpoint::Connecting, - protocols: Arc, -) { - let alpn = match connecting.alpn().await { - Ok(alpn) => alpn, - Err(err) => { - warn!("Ignoring connection: invalid handshake: {:?}", err); - return; - } - }; - let Some(handler) = protocols.get(&alpn) else { - warn!("Ignoring connection: unsupported ALPN protocol"); - return; - }; - if let Err(err) = handler.accept(connecting).await { - warn!("Handling incoming connection ended with error: {err}"); - } -} - const DEFAULT_RPC_PORT: u16 = 0x1337; const MAX_RPC_CONNECTIONS: u32 = 16; const MAX_RPC_STREAMS: u32 = 1024; diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index 89d9d5fd9f0..697b6d63cd5 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -29,8 +29,9 @@ use quic_rpc::{ server::{RpcChannel, RpcServerError}, ServiceEndpoint, }; +use tokio::task::JoinSet; use tokio_util::task::LocalPoolHandle; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; use crate::client::blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}; use crate::client::tags::TagInfo; @@ -65,234 +66,235 @@ pub(crate) struct Handler { pub(crate) inner: Arc>, } +impl Handler { + pub fn new(inner: Arc>) -> Self { + Self { inner } + } +} + impl Handler { - pub(crate) fn handle_rpc_request>( - &self, + pub(crate) fn spawn_rpc_request>( + inner: Arc>, + join_set: &mut JoinSet>, msg: Request, chan: RpcChannel, ) { - let handler = self.clone(); - tokio::task::spawn(async move { - use Request::*; - debug!("handling rpc request: {msg}"); - match msg { - NodeWatch(msg) => chan.server_streaming(msg, handler, Self::node_watch).await, - NodeStatus(msg) => chan.rpc(msg, handler, Self::node_status).await, - NodeId(msg) => chan.rpc(msg, handler, Self::node_id).await, - NodeAddr(msg) => chan.rpc(msg, handler, Self::node_addr).await, - NodeRelay(msg) => chan.rpc(msg, handler, Self::node_relay).await, - NodeShutdown(msg) => chan.rpc(msg, handler, Self::node_shutdown).await, - NodeStats(msg) => chan.rpc(msg, handler, Self::node_stats).await, - NodeConnections(msg) => { - chan.server_streaming(msg, handler, Self::node_connections) - .await - } - NodeConnectionInfo(msg) => chan.rpc(msg, handler, Self::node_connection_info).await, - BlobList(msg) => chan.server_streaming(msg, handler, Self::blob_list).await, - BlobListIncomplete(msg) => { - chan.server_streaming(msg, handler, Self::blob_list_incomplete) - .await - } - CreateCollection(msg) => chan.rpc(msg, handler, Self::create_collection).await, - ListTags(msg) => { - chan.server_streaming(msg, handler, Self::blob_list_tags) - .await - } - DeleteTag(msg) => chan.rpc(msg, handler, Self::blob_delete_tag).await, - BlobDeleteBlob(msg) => chan.rpc(msg, handler, Self::blob_delete_blob).await, - BlobAddPath(msg) => { - chan.server_streaming(msg, handler, Self::blob_add_from_path) - .await - } - BlobDownload(msg) => { - chan.server_streaming(msg, handler, Self::blob_download) - .await - } - BlobExport(msg) => chan.server_streaming(msg, handler, Self::blob_export).await, - BlobValidate(msg) => { - chan.server_streaming(msg, handler, Self::blob_validate) - .await - } - BlobFsck(msg) => { - chan.server_streaming(msg, handler, Self::blob_consistency_check) - .await - } - BlobReadAt(msg) => { - chan.server_streaming(msg, handler, Self::blob_read_at) 
- .await - } - BlobAddStream(msg) => { - chan.bidi_streaming(msg, handler, Self::blob_add_stream) - .await - } - BlobAddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage), - AuthorList(msg) => { - chan.server_streaming(msg, handler, |handler, req| { - handler.inner.sync.author_list(req) - }) - .await - } - AuthorCreate(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_create(req).await - }) - .await - } - AuthorImport(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_import(req).await - }) - .await - } - AuthorExport(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_export(req).await - }) - .await - } - AuthorDelete(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_delete(req).await - }) - .await - } - AuthorGetDefault(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_default(req) - }) - .await - } - AuthorSetDefault(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.author_set_default(req).await - }) - .await - } - DocOpen(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_open(req).await - }) - .await - } - DocClose(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_close(req).await - }) - .await - } - DocStatus(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_status(req).await - }) - .await - } - DocList(msg) => { - chan.server_streaming(msg, handler, |handler, req| { - handler.inner.sync.doc_list(req) - }) - .await - } - DocCreate(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_create(req).await - }) - .await - } - DocDrop(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_drop(req).await - }) - .await - } - DocImport(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_import(req).await - }) - .await - } - DocSet(msg) => { - let bao_store = handler.inner.db.clone(); - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_set(&bao_store, req).await - }) - .await - } - DocImportFile(msg) => { - chan.server_streaming(msg, handler, Self::doc_import_file) - .await - } - DocExportFile(msg) => { - chan.server_streaming(msg, handler, Self::doc_export_file) - .await - } - DocDel(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_del(req).await - }) - .await - } - DocSetHash(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_set_hash(req).await - }) - .await - } - DocGet(msg) => { - chan.server_streaming(msg, handler, |handler, req| { - handler.inner.sync.doc_get_many(req) - }) - .await - } - DocGetExact(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_get_exact(req).await - }) - .await - } - DocStartSync(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_start_sync(req).await - }) + let handler = Self::new(inner); + join_set.spawn(async move { + if let Err(err) = handler.handle_rpc_request(msg, chan).await { + warn!("rpc request handler error: {err:?}"); + } + Ok(()) + }); + } + + pub(crate) async fn handle_rpc_request>( + self, + msg: Request, + chan: RpcChannel, + ) -> Result<(), RpcServerError> { + use Request::*; + debug!("handling rpc request: 
{msg}"); + match msg { + NodeWatch(msg) => chan.server_streaming(msg, self, Self::node_watch).await, + NodeStatus(msg) => chan.rpc(msg, self, Self::node_status).await, + NodeId(msg) => chan.rpc(msg, self, Self::node_id).await, + NodeAddr(msg) => chan.rpc(msg, self, Self::node_addr).await, + NodeRelay(msg) => chan.rpc(msg, self, Self::node_relay).await, + NodeShutdown(msg) => chan.rpc(msg, self, Self::node_shutdown).await, + NodeStats(msg) => chan.rpc(msg, self, Self::node_stats).await, + NodeConnections(msg) => { + chan.server_streaming(msg, self, Self::node_connections) .await - } - DocLeave(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_leave(req).await - }) + } + NodeConnectionInfo(msg) => chan.rpc(msg, self, Self::node_connection_info).await, + BlobList(msg) => chan.server_streaming(msg, self, Self::blob_list).await, + BlobListIncomplete(msg) => { + chan.server_streaming(msg, self, Self::blob_list_incomplete) .await - } - DocShare(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_share(req).await - }) + } + CreateCollection(msg) => chan.rpc(msg, self, Self::create_collection).await, + ListTags(msg) => chan.server_streaming(msg, self, Self::blob_list_tags).await, + DeleteTag(msg) => chan.rpc(msg, self, Self::blob_delete_tag).await, + BlobDeleteBlob(msg) => chan.rpc(msg, self, Self::blob_delete_blob).await, + BlobAddPath(msg) => { + chan.server_streaming(msg, self, Self::blob_add_from_path) .await - } - DocSubscribe(msg) => { - chan.try_server_streaming(msg, handler, |handler, req| async move { - handler.inner.sync.doc_subscribe(req).await - }) + } + BlobDownload(msg) => chan.server_streaming(msg, self, Self::blob_download).await, + BlobExport(msg) => chan.server_streaming(msg, self, Self::blob_export).await, + BlobValidate(msg) => chan.server_streaming(msg, self, Self::blob_validate).await, + BlobFsck(msg) => { + chan.server_streaming(msg, self, Self::blob_consistency_check) .await - } - DocSetDownloadPolicy(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_set_download_policy(req).await - }) + } + BlobReadAt(msg) => chan.server_streaming(msg, self, Self::blob_read_at).await, + BlobAddStream(msg) => chan.bidi_streaming(msg, self, Self::blob_add_stream).await, + BlobAddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage), + AuthorList(msg) => { + chan.server_streaming(msg, self, |handler, req| { + handler.inner.docs.author_list(req) + }) + .await + } + AuthorCreate(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.author_create(req).await + }) + .await + } + AuthorImport(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.author_import(req).await + }) + .await + } + AuthorExport(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.author_export(req).await + }) + .await + } + AuthorDelete(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.author_delete(req).await + }) + .await + } + AuthorGetDefault(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.author_default(req) + }) + .await + } + AuthorSetDefault(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.author_set_default(req).await + }) + .await + } + DocOpen(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_open(req).await + }) + .await + } + DocClose(msg) => { + chan.rpc(msg, self, |handler, req| 
async move { + handler.inner.docs.doc_close(req).await + }) + .await + } + DocStatus(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_status(req).await + }) + .await + } + DocList(msg) => { + chan.server_streaming(msg, self, |handler, req| handler.inner.docs.doc_list(req)) .await - } - DocGetDownloadPolicy(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_get_download_policy(req).await - }) + } + DocCreate(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_create(req).await + }) + .await + } + DocDrop(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_drop(req).await + }) + .await + } + DocImport(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_import(req).await + }) + .await + } + DocSet(msg) => { + let bao_store = self.inner.db.clone(); + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_set(&bao_store, req).await + }) + .await + } + DocImportFile(msg) => { + chan.server_streaming(msg, self, Self::doc_import_file) .await - } - DocGetSyncPeers(msg) => { - chan.rpc(msg, handler, |handler, req| async move { - handler.inner.sync.doc_get_sync_peers(req).await - }) + } + DocExportFile(msg) => { + chan.server_streaming(msg, self, Self::doc_export_file) .await - } } - }); + DocDel(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_del(req).await + }) + .await + } + DocSetHash(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_set_hash(req).await + }) + .await + } + DocGet(msg) => { + chan.server_streaming(msg, self, |handler, req| { + handler.inner.docs.doc_get_many(req) + }) + .await + } + DocGetExact(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_get_exact(req).await + }) + .await + } + DocStartSync(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_start_sync(req).await + }) + .await + } + DocLeave(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_leave(req).await + }) + .await + } + DocShare(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_share(req).await + }) + .await + } + DocSubscribe(msg) => { + chan.try_server_streaming(msg, self, |handler, req| async move { + handler.inner.docs.doc_subscribe(req).await + }) + .await + } + DocSetDownloadPolicy(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_set_download_policy(req).await + }) + .await + } + DocGetDownloadPolicy(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_get_download_policy(req).await + }) + .await + } + DocGetSyncPeers(msg) => { + chan.rpc(msg, self, |handler, req| async move { + handler.inner.docs.doc_get_sync_peers(req).await + }) + .await + } + } } fn rt(&self) -> LocalPoolHandle { @@ -518,7 +520,7 @@ impl Handler { let hash_and_format = temp_tag.inner(); let HashAndFormat { hash, .. } = *hash_and_format; self.inner - .sync + .docs .doc_set_hash(DocSetHashRequest { doc_id, author_id,