From 07844031c3e568e34c64a825803c9cd3f91a2035 Mon Sep 17 00:00:00 2001 From: Franz Heinzmann Date: Mon, 5 Aug 2024 18:38:23 +0200 Subject: [PATCH] fix(iroh-blobs): do not hit the network when downloading blobs which are complete (#2586) ## Description Two changes to the downloader: * Never try to download from ourselves. If the only provider node added is our own node, fail with the error "no providers". * The actual download request flow is turned into a generator (while keeping API compatibility for the existing `get_to_db` public function). A new `get_to_db_in_steps` function either runs to completion if the requested data is fully available locally, or yields a `NeedsConn` struct at the point where it needs a network connection to proceed. The `NeedsConn` struct has an `async fn proceed(self, conn: Connection)` method, which must be called with a connection for the actual download to start. This two-step process allows the downloader to check whether we need to dial nodes at all or are already done without doing anything, while emitting the exact same flow of events (because we run the same loop) to the client. To achieve this, `get_to_db` now uses a genawaiter generator internally. This means that the big loop that is the iroh-blobs protocol request flow does not have to be changed at all; the only difference is that instead of calling a closure we yield and resume, which makes this much easier to integrate into an external state machine like the downloader. The changes this requires in the downloader are a bit verbose because the downloader itself is generic over a `Getter`, with one impl for the real getter and a test impl that does not use networking; therefore the new `NeedsConn` state has to be modeled with an additional associated type and trait here. 
This PR adds three tests: * Downloading a missing blob from the local node fails without trying to connect to ourselves * Downloading an existing blob succeeds without trying to download * Downloading an existing collection succeeds without trying to download Closes #2575 Replaces #2576 ## Notes and open questions ## Breaking changes None, only an addition to the public API of iroh_blobs: `iroh_blobs::get::db::get_to_db_in_steps` --------- Co-authored-by: dignifiedquire --- iroh-blobs/src/downloader.rs | 191 ++++++++++++++++------- iroh-blobs/src/downloader/get.rs | 90 ++++++----- iroh-blobs/src/downloader/invariants.rs | 2 +- iroh-blobs/src/downloader/progress.rs | 3 + iroh-blobs/src/downloader/test.rs | 23 ++- iroh-blobs/src/downloader/test/dialer.rs | 7 + iroh-blobs/src/downloader/test/getter.rs | 26 ++- iroh-blobs/src/get/db.rs | 128 +++++++++++++-- iroh-net/src/dialer.rs | 5 + iroh/src/client/blobs.rs | 184 ++++++++++++++++++++++ iroh/src/node/rpc.rs | 98 +++++++----- 11 files changed, 594 insertions(+), 163 deletions(-) diff --git a/iroh-blobs/src/downloader.rs b/iroh-blobs/src/downloader.rs index dd26a8bc6d..21644e8d93 100644 --- a/iroh-blobs/src/downloader.rs +++ b/iroh-blobs/src/downloader.rs @@ -27,8 +27,12 @@ //! requests to a single node is also limited. use std::{ - collections::{hash_map::Entry, HashMap, HashSet}, + collections::{ + hash_map::{self, Entry}, + HashMap, HashSet, + }, fmt, + future::Future, num::NonZeroUsize, sync::{ atomic::{AtomicU64, Ordering}, @@ -46,7 +50,7 @@ use tokio::{ sync::{mpsc, oneshot}, task::JoinSet, }; -use tokio_util::{sync::CancellationToken, time::delay_queue}; +use tokio_util::{either::Either, sync::CancellationToken, time::delay_queue}; use tracing::{debug, error_span, trace, warn, Instrument}; use crate::{ @@ -75,13 +79,15 @@ pub struct IntentId(pub u64); /// Trait modeling a dialer. This allows for IO-less testing. 
pub trait Dialer: Stream)> + Unpin { /// Type of connections returned by the Dialer. - type Connection: Clone; + type Connection: Clone + 'static; /// Dial a node. fn queue_dial(&mut self, node_id: NodeId); /// Get the number of dialing nodes. fn pending_count(&self) -> usize; /// Check if a node is being dialed. fn is_pending(&self, node: NodeId) -> bool; + /// Get the node id of our node. + fn node_id(&self) -> NodeId; } /// Signals what should be done with the request when it fails. @@ -97,20 +103,39 @@ pub enum FailureAction { RetryLater(anyhow::Error), } -/// Future of a get request. -type GetFut = BoxedLocal; +/// Future of a get request, for the checking stage. +type GetStartFut = BoxedLocal, FailureAction>>; +/// Future of a get request, for the downloading stage. +type GetProceedFut = BoxedLocal; /// Trait modelling performing a single request over a connection. This allows for IO-less testing. pub trait Getter { /// Type of connections the Getter requires to perform a download. - type Connection; - /// Return a future that performs the download using the given connection. + type Connection: 'static; + /// Type of the intermediary state returned from [`Self::get`] if a connection is needed. + type NeedsConn: NeedsConn; + /// Returns a future that checks the local store if the request is already complete, returning + /// a struct implementing [`NeedsConn`] if we need a network connection to proceed. fn get( &mut self, kind: DownloadKind, - conn: Self::Connection, progress_sender: BroadcastProgressSender, - ) -> GetFut; + ) -> GetStartFut; +} + +/// Trait modelling the intermediary state when a connection is needed to proceed. +pub trait NeedsConn: std::fmt::Debug + 'static { + /// Proceeds the download with the given connection. + fn proceed(self, conn: C) -> GetProceedFut; +} + +/// Output returned from [`Getter::get`]. +#[derive(Debug)] +pub enum GetOutput { + /// The request is already complete in the local store. 
+ Complete(Stats), + /// The request needs a connection to continue. + NeedsConn(N), } /// Concurrency limits for the [`Downloader`]. @@ -280,7 +305,7 @@ pub struct DownloadHandle { receiver: oneshot::Receiver, } -impl std::future::Future for DownloadHandle { +impl Future for DownloadHandle { type Output = ExternalDownloadResult; fn poll( @@ -424,10 +449,12 @@ struct IntentHandlers { } /// Information about a request. -#[derive(Debug, Default)] -struct RequestInfo { +#[derive(Debug)] +struct RequestInfo { /// Registered intents with progress senders and result callbacks. intents: HashMap, + progress_sender: BroadcastProgressSender, + get_state: Option, } /// Information about a request in progress. @@ -529,7 +556,7 @@ struct Service { /// Queue of pending downloads. queue: Queue, /// Information about pending and active requests. - requests: HashMap, + requests: HashMap>, /// State of running downloads. active_requests: HashMap, /// Tasks for currently running downloads. @@ -666,48 +693,85 @@ impl, D: Dialer> Service { on_progress: progress, }; - // early exit if no providers. 
- if nodes.is_empty() && self.providers.get_candidates(&kind.hash()).next().is_none() { - self.finalize_download( - kind, - [(intent_id, intent_handlers)].into(), - Err(DownloadError::NoProviders), - ); - return; - } - // add the nodes to the provider map - let updated = self - .providers - .add_hash_with_nodes(kind.hash(), nodes.iter().map(|n| n.node_id)); + // (skip the node id of our own node - we should never attempt to download from ourselves) + let node_ids = nodes + .iter() + .map(|n| n.node_id) + .filter(|node_id| *node_id != self.dialer.node_id()); + let updated = self.providers.add_hash_with_nodes(kind.hash(), node_ids); // queue the transfer (if not running) or attach to transfer progress (if already running) - if self.active_requests.contains_key(&kind) { - // the transfer is already running, so attach the progress sender - if let Some(on_progress) = &intent_handlers.on_progress { - // this is async because it sends the current state over the progress channel - if let Err(err) = self - .progress_tracker - .subscribe(kind, on_progress.clone()) - .await - { - debug!(?err, %kind, "failed to subscribe progress sender to transfer"); + match self.requests.entry(kind) { + hash_map::Entry::Occupied(mut entry) => { + if let Some(on_progress) = &intent_handlers.on_progress { + // this is async because it sends the current state over the progress channel + if let Err(err) = self + .progress_tracker + .subscribe(kind, on_progress.clone()) + .await + { + debug!(?err, %kind, "failed to subscribe progress sender to transfer"); + } } + entry.get_mut().intents.insert(intent_id, intent_handlers); } - } else { - // the transfer is not running. - if updated && self.queue.is_parked(&kind) { - // the transfer is on hold for pending retries, and we added new nodes, so move back to queue. - self.queue.unpark(&kind); - } else if !self.queue.contains(&kind) { - // the transfer is not yet queued: add to queue. 
+ hash_map::Entry::Vacant(entry) => { + tracing::warn!("is new, queue"); + let progress_sender = self.progress_tracker.track( + kind, + intent_handlers + .on_progress + .clone() + .into_iter() + .collect::>(), + ); + + let get_state = match self.getter.get(kind, progress_sender.clone()).await { + Err(_err) => { + self.finalize_download( + kind, + [(intent_id, intent_handlers)].into(), + // TODO: add better error variant? this is only triggered if the local + // store failed with local IO. + Err(DownloadError::DownloadFailed), + ); + return; + } + Ok(GetOutput::Complete(stats)) => { + self.finalize_download( + kind, + [(intent_id, intent_handlers)].into(), + Ok(stats), + ); + return; + } + Ok(GetOutput::NeedsConn(state)) => { + // early exit if no providers. + if self.providers.get_candidates(&kind.hash()).next().is_none() { + self.finalize_download( + kind, + [(intent_id, intent_handlers)].into(), + Err(DownloadError::NoProviders), + ); + return; + } + state + } + }; + entry.insert(RequestInfo { + intents: [(intent_id, intent_handlers)].into_iter().collect(), + progress_sender, + get_state: Some(get_state), + }); self.queue.insert(kind); } } - // store the request info - let request_info = self.requests.entry(kind).or_default(); - request_info.intents.insert(intent_id, intent_handlers); + if updated && self.queue.is_parked(&kind) { + // the transfer is on hold for pending retries, and we added new nodes, so move back to queue. + self.queue.unpark(&kind); + } } /// Cancels a download intent. @@ -860,7 +924,6 @@ impl, D: Dialer> Service { ) { self.progress_tracker.remove(&kind); self.remove_hash_if_not_queued(&kind.hash()); - let result = result.map_err(|_| DownloadError::DownloadFailed); for (_id, handlers) in intents.into_iter() { handlers.on_finish.send(result.clone()).ok(); } @@ -1082,14 +1145,9 @@ impl, D: Dialer> Service { /// Panics if hash is not in self.requests or node is not in self.nodes. 
fn start_download(&mut self, kind: DownloadKind, node: NodeId) { let node_info = self.connected_nodes.get_mut(&node).expect("node exists"); - let request_info = self.requests.get(&kind).expect("hash exists"); - - // create a progress sender and subscribe all intents to the progress sender - let subscribers = request_info - .intents - .values() - .flat_map(|state| state.on_progress.clone()); - let progress_sender = self.progress_tracker.track(kind, subscribers); + let request_info = self.requests.get_mut(&kind).expect("request exists"); + let progress = request_info.progress_sender.clone(); + // .expect("queued state exists"); // create the active request state let cancellation = CancellationToken::new(); @@ -1098,7 +1156,15 @@ impl, D: Dialer> Service { node, }; let conn = node_info.conn.clone(); - let get_fut = self.getter.get(kind, conn, progress_sender); + + // If this is the first provider node we try, we have an initial state + // from starting the generator in Self::handle_queue_new_download. + // If this not the first provider node we try, we have to recreate the generator, because + // we can only resume it once. + let get_state = match request_info.get_state.take() { + Some(state) => Either::Left(async move { Ok(GetOutput::NeedsConn(state)) }), + None => Either::Right(self.getter.get(kind, progress)), + }; let fut = async move { // NOTE: it's an open question if we should do timeouts at this point. Considerations from @Frando: // > at this stage we do not know the size of the download, so the timeout would have @@ -1106,9 +1172,16 @@ impl, D: Dialer> Service { // > this means that a super slow node would block a download from succeeding for a long // > time, while faster nodes could be readily available. // As a conclusion, timeouts should be added only after downloads are known to be bounded + let fut = async move { + match get_state.await? 
{ + GetOutput::Complete(stats) => Ok(stats), + GetOutput::NeedsConn(state) => state.proceed(conn).await, + } + }; + tokio::pin!(fut); let res = tokio::select! { _ = cancellation.cancelled() => Err(FailureAction::AllIntentsDropped), - res = get_fut => res + res = &mut fut => res }; trace!("transfer finished"); @@ -1433,4 +1506,8 @@ impl Dialer for iroh_net::dialer::Dialer { fn is_pending(&self, node: NodeId) -> bool { self.is_pending(node) } + + fn node_id(&self) -> NodeId { + self.endpoint().node_id() + } } diff --git a/iroh-blobs/src/downloader/get.rs b/iroh-blobs/src/downloader/get.rs index e48370d42c..b43cbaba92 100644 --- a/iroh-blobs/src/downloader/get.rs +++ b/iroh-blobs/src/downloader/get.rs @@ -3,18 +3,13 @@ //! [`Connection`]: iroh_net::endpoint::Connection use crate::{ - get::{db::get_to_db, error::GetError}, + get::{db::get_to_db_in_steps, error::GetError}, store::Store, }; use futures_lite::FutureExt; -#[cfg(feature = "metrics")] -use iroh_metrics::{inc, inc_by}; use iroh_net::endpoint; -#[cfg(feature = "metrics")] -use crate::metrics::Metrics; - -use super::{progress::BroadcastProgressSender, DownloadKind, FailureAction, GetFut, Getter}; +use super::{progress::BroadcastProgressSender, DownloadKind, FailureAction, GetStartFut, Getter}; impl From for FailureAction { fn from(e: GetError) -> Self { @@ -39,46 +34,63 @@ pub(crate) struct IoGetter { impl Getter for IoGetter { type Connection = endpoint::Connection; + type NeedsConn = crate::get::db::GetStateNeedsConn; fn get( &mut self, kind: DownloadKind, - conn: Self::Connection, progress_sender: BroadcastProgressSender, - ) -> GetFut { + ) -> GetStartFut { let store = self.store.clone(); - let fut = async move { - let get_conn = || async move { Ok(conn) }; - let res = get_to_db(&store, get_conn, &kind.hash_and_format(), progress_sender).await; - match res { - Ok(stats) => { - #[cfg(feature = "metrics")] - { - let crate::get::Stats { - bytes_written, - bytes_read: _, - elapsed, - } = stats; - - 
inc!(Metrics, downloads_success); - inc_by!(Metrics, download_bytes_total, bytes_written); - inc_by!(Metrics, download_time_total, elapsed.as_millis() as u64); - } - Ok(stats) + async move { + match get_to_db_in_steps(store, kind.hash_and_format(), progress_sender).await { + Err(err) => Err(err.into()), + Ok(crate::get::db::GetState::Complete(stats)) => { + Ok(super::GetOutput::Complete(stats)) } - Err(e) => { - // record metrics according to the error - #[cfg(feature = "metrics")] - { - match &e { - GetError::NotFound(_) => inc!(Metrics, downloads_notfound), - _ => inc!(Metrics, downloads_error), - } - } - Err(e.into()) + Ok(crate::get::db::GetState::NeedsConn(needs_conn)) => { + Ok(super::GetOutput::NeedsConn(needs_conn)) } } - }; - fut.boxed_local() + } + .boxed_local() + } +} + +impl super::NeedsConn for crate::get::db::GetStateNeedsConn { + fn proceed(self, conn: endpoint::Connection) -> super::GetProceedFut { + async move { + let res = self.proceed(conn).await; + #[cfg(feature = "metrics")] + track_metrics(&res); + match res { + Ok(stats) => Ok(stats), + Err(err) => Err(err.into()), + } + } + .boxed_local() + } +} + +#[cfg(feature = "metrics")] +fn track_metrics(res: &Result) { + use crate::metrics::Metrics; + use iroh_metrics::{inc, inc_by}; + match res { + Ok(stats) => { + let crate::get::Stats { + bytes_written, + bytes_read: _, + elapsed, + } = stats; + + inc!(Metrics, downloads_success); + inc_by!(Metrics, download_bytes_total, *bytes_written); + inc_by!(Metrics, download_time_total, elapsed.as_millis() as u64); + } + Err(e) => match &e { + GetError::NotFound(_) => inc!(Metrics, downloads_notfound), + _ => inc!(Metrics, downloads_error), + }, } } diff --git a/iroh-blobs/src/downloader/invariants.rs b/iroh-blobs/src/downloader/invariants.rs index e4a2656368..0409e3d922 100644 --- a/iroh-blobs/src/downloader/invariants.rs +++ b/iroh-blobs/src/downloader/invariants.rs @@ -77,8 +77,8 @@ impl, D: Dialer> Service { // check that the count of futures we are 
polling for downloads is consistent with the // number of requests assert_eq!( - self.in_progress_downloads.len(), self.active_requests.len(), + self.in_progress_downloads.len(), "active_requests and in_progress_downloads are out of sync" ); // check that the count of requests per peer matches the number of requests that have that diff --git a/iroh-blobs/src/downloader/progress.rs b/iroh-blobs/src/downloader/progress.rs index eac80985d5..60ded0e7a5 100644 --- a/iroh-blobs/src/downloader/progress.rs +++ b/iroh-blobs/src/downloader/progress.rs @@ -103,6 +103,7 @@ struct Inner { impl Inner { fn subscribe(&mut self, subscriber: ProgressSubscriber) -> DownloadProgress { + tracing::warn!(state=?self.state, "subscribe! emit initial"); let msg = DownloadProgress::InitialState(self.state.clone()); self.subscribers.push(subscriber); msg @@ -136,7 +137,9 @@ impl ProgressSender for BroadcastProgressSender { // making sure that the lock is not held across an await point. let futs = { let mut inner = self.shared.lock(); + tracing::warn!(?msg, state_pre=?inner.state, "send to {}", inner.subscribers.len()); inner.on_progress(msg.clone()); + tracing::warn!(state_post=?inner.state, "send"); let futs = inner .subscribers .iter_mut() diff --git a/iroh-blobs/src/downloader/test.rs b/iroh-blobs/src/downloader/test.rs index 2e734eaf3b..b2bd4c751a 100644 --- a/iroh-blobs/src/downloader/test.rs +++ b/iroh-blobs/src/downloader/test.rs @@ -9,7 +9,10 @@ use futures_util::future::FutureExt; use iroh_net::key::SecretKey; use crate::{ - get::{db::BlobId, progress::TransferState}, + get::{ + db::BlobId, + progress::{BlobProgress, TransferState}, + }, util::{ local_pool::LocalPool, progress::{AsyncChannelProgressSender, IdGenerator}, @@ -286,16 +289,26 @@ async fn concurrent_progress() { let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_b_tx); let handle_b = downloader.queue(req).await; - start_tx.send(()).unwrap(); - let mut state_a = TransferState::new(hash); let mut 
state_b = TransferState::new(hash); let mut state_c = TransferState::new(hash); + let prog0_b = prog_b_rx.recv().await.unwrap(); + assert!(matches!( + prog0_b, + DownloadProgress::InitialState(state) if state.root.hash == hash && state.root.progress == BlobProgress::Pending, + )); + + start_tx.send(()).unwrap(); + let prog1_a = prog_a_rx.recv().await.unwrap(); let prog1_b = prog_b_rx.recv().await.unwrap(); - assert!(matches!(prog1_a, DownloadProgress::Found { hash, size: 100, ..} if hash == hash)); - assert!(matches!(prog1_b, DownloadProgress::Found { hash, size: 100, ..} if hash == hash)); + assert!( + matches!(prog1_a, DownloadProgress::Found { hash: found_hash, size: 100, ..} if found_hash == hash) + ); + assert!( + matches!(prog1_b, DownloadProgress::Found { hash: found_hash, size: 100, ..} if found_hash == hash) + ); state_a.on_progress(prog1_a); state_b.on_progress(prog1_b); diff --git a/iroh-blobs/src/downloader/test/dialer.rs b/iroh-blobs/src/downloader/test/dialer.rs index d099552a11..fc5a939959 100644 --- a/iroh-blobs/src/downloader/test/dialer.rs +++ b/iroh-blobs/src/downloader/test/dialer.rs @@ -21,6 +21,8 @@ struct TestingDialerInner { dial_duration: Duration, /// Fn deciding if a dial is successful. 
dial_outcome: Box bool + Send + Sync + 'static>, + /// Our own node id + node_id: NodeId, } impl Default for TestingDialerInner { @@ -31,6 +33,7 @@ impl Default for TestingDialerInner { dial_history: Vec::default(), dial_duration: Duration::from_millis(10), dial_outcome: Box::new(|_| true), + node_id: NodeId::from_bytes(&[0u8; 32]).unwrap(), } } } @@ -55,6 +58,10 @@ impl Dialer for TestingDialer { fn is_pending(&self, node: NodeId) -> bool { self.0.read().dialing.contains(&node) } + + fn node_id(&self) -> NodeId { + self.0.read().node_id + } } impl Stream for TestingDialer { diff --git a/iroh-blobs/src/downloader/test/getter.rs b/iroh-blobs/src/downloader/test/getter.rs index 397f1134f1..c3686a71c4 100644 --- a/iroh-blobs/src/downloader/test/getter.rs +++ b/iroh-blobs/src/downloader/test/getter.rs @@ -3,9 +3,12 @@ use futures_lite::{future::Boxed as BoxFuture, FutureExt}; use parking_lot::RwLock; +use crate::downloader; + use super::*; -#[derive(Default, Clone)] +#[derive(Default, Clone, derive_more::Debug)] +#[debug("TestingGetter")] pub(super) struct TestingGetter(Arc>); pub(super) type RequestHandlerFn = Arc< @@ -34,14 +37,29 @@ impl Getter for TestingGetter { // since for testing we don't need a real connection, just keep track of what peer is the // request being sent to type Connection = NodeId; + type NeedsConn = GetStateNeedsConn; fn get( &mut self, kind: DownloadKind, - peer: NodeId, progress_sender: BroadcastProgressSender, - ) -> GetFut { - let mut inner = self.0.write(); + ) -> GetStartFut { + std::future::ready(Ok(downloader::GetOutput::NeedsConn(GetStateNeedsConn( + self.clone(), + kind, + progress_sender, + )))) + .boxed_local() + } +} + +#[derive(Debug)] +pub(super) struct GetStateNeedsConn(TestingGetter, DownloadKind, BroadcastProgressSender); + +impl downloader::NeedsConn for GetStateNeedsConn { + fn proceed(self, peer: NodeId) -> super::GetProceedFut { + let GetStateNeedsConn(getter, kind, progress_sender) = self; + let mut inner = 
getter.0.write(); inner.request_history.push((kind, peer)); let request_duration = inner.request_duration; let handler = inner.request_handler.clone(); diff --git a/iroh-blobs/src/get/db.rs b/iroh-blobs/src/get/db.rs index 08ef2f82c7..afcdea6972 100644 --- a/iroh-blobs/src/get/db.rs +++ b/iroh-blobs/src/get/db.rs @@ -3,12 +3,18 @@ use std::future::Future; use std::io; use std::num::NonZeroU64; +use std::pin::Pin; use futures_lite::StreamExt; +use genawaiter::{ + rc::{Co, Gen}, + GeneratorState, +}; use iroh_base::hash::Hash; use iroh_base::rpc::RpcError; use iroh_net::endpoint::Connection; use serde::{Deserialize, Serialize}; +use tokio::sync::oneshot; use crate::hashseq::parse_hash_seq; use crate::protocol::RangeSpec; @@ -34,6 +40,9 @@ use bao_tree::{ChunkNum, ChunkRanges}; use iroh_io::AsyncSliceReader; use tracing::trace; +type GetGenerator = Gen>>>>; +type GetFuture = Pin> + 'static>>; + /// Get a blob or collection into a store. /// /// This considers data that is already in the store, and will only request @@ -50,12 +59,105 @@ pub async fn get_to_db< db: &D, get_conn: C, hash_and_format: &HashAndFormat, - sender: impl ProgressSender + IdGenerator, + progress_sender: impl ProgressSender + IdGenerator, +) -> Result { + match get_to_db_in_steps(db.clone(), *hash_and_format, progress_sender).await? { + GetState::Complete(res) => Ok(res), + GetState::NeedsConn(state) => { + let conn = get_conn().await.map_err(GetError::Io)?; + state.proceed(conn).await + } + } +} + +/// Get a blob or collection into a store, yielding if a connection is needed. +/// +/// This checks a get request against a local store, and returns [`GetState`], +/// which is either `Complete` in case the requested data is fully available in the local store, or +/// `NeedsConn`, once a connection is needed to proceed downloading the missing data. +/// +/// In the latter case, call [`GetStateNeedsConn::proceed`] with a connection to a provider to +/// proceed with the download. 
+/// +/// Progress reporting works in the same way as documented in [`get_to_db`]. +pub async fn get_to_db_in_steps< + D: BaoStore, + P: ProgressSender + IdGenerator, +>( + db: D, + hash_and_format: HashAndFormat, + progress_sender: P, +) -> Result { + let mut gen: GetGenerator = genawaiter::rc::Gen::new(move |co| { + let fut = async move { producer(co, &db, &hash_and_format, progress_sender).await }; + let fut: GetFuture = Box::pin(fut); + fut + }); + match gen.async_resume().await { + GeneratorState::Yielded(Yield::NeedConn(reply)) => { + Ok(GetState::NeedsConn(GetStateNeedsConn(gen, reply))) + } + GeneratorState::Complete(res) => res.map(GetState::Complete), + } +} + +/// Intermediary state returned from [`get_to_db_in_steps`] for a download request that needs a +/// connection to proceed. +#[derive(derive_more::Debug)] +#[debug("GetStateNeedsConn")] +pub struct GetStateNeedsConn(GetGenerator, oneshot::Sender); + +impl GetStateNeedsConn { + /// Proceed with the download by providing a connection to a provider. + pub async fn proceed(mut self, conn: Connection) -> Result { + self.1.send(conn).expect("receiver is not dropped"); + match self.0.async_resume().await { + GeneratorState::Yielded(y) => match y { + Yield::NeedConn(_) => panic!("NeedsConn may only be yielded once"), + }, + GeneratorState::Complete(res) => res, + } + } +} + +/// Output of [`get_to_db_in_steps`]. +#[derive(Debug)] +pub enum GetState { + /// The requested data is completely available in the local store, no network requests are + /// needed. + Complete(Stats), + /// The requested data is not fully available in the local store, we need a connection to + /// proceed. + /// + /// Once a connection is available, call [`GetStateNeedsConn::proceed`] to continue. 
+ NeedsConn(GetStateNeedsConn), +} + +struct GetCo(Co); + +impl GetCo { + async fn get_conn(&self) -> Connection { + let (tx, rx) = oneshot::channel(); + self.0.yield_(Yield::NeedConn(tx)).await; + rx.await.expect("sender may not be dropped") + } +} + +enum Yield { + NeedConn(oneshot::Sender), +} + +async fn producer( + co: Co, + db: &D, + hash_and_format: &HashAndFormat, + progress: impl ProgressSender + IdGenerator, ) -> Result { let HashAndFormat { hash, format } = hash_and_format; + let co = GetCo(co); match format { - BlobFormat::Raw => get_blob(db, get_conn, hash, sender).await, - BlobFormat::HashSeq => get_hash_seq(db, get_conn, hash, sender).await, + BlobFormat::Raw => get_blob(db, co, hash, progress).await, + BlobFormat::HashSeq => get_hash_seq(db, co, hash, progress).await, } } @@ -63,9 +165,9 @@ pub async fn get_to_db< /// /// We need to create our own files and handle the case where an outboard /// is not needed. -async fn get_blob F, F: Future>>( +async fn get_blob( db: &D, - get_conn: C, + co: GetCo, hash: &Hash, progress: impl ProgressSender + IdGenerator, ) -> Result { @@ -100,7 +202,7 @@ async fn get_blob F, F: Future F, F: Future { // full request - let conn = get_conn().await.map_err(GetError::Io)?; + let conn = co.get_conn().await; let request = get::fsm::start(conn, GetRequest::single(*hash)); // create a new bidi stream let connected = request.next().await?; @@ -299,13 +401,9 @@ async fn blob_infos(db: &D, hash_seq: &[Hash]) -> io::Result F, - F: Future>, ->( +async fn get_hash_seq( db: &D, - get_conn: C, + co: GetCo, root_hash: &Hash, sender: impl ProgressSender + IdGenerator, ) -> Result { @@ -364,7 +462,7 @@ async fn get_hash_seq< .collect::>(); log!("requesting chunks {:?}", missing_iter); let request = GetRequest::new(*root_hash, RangeSpecSeq::from_ranges(missing_iter)); - let conn = get_conn().await.map_err(GetError::Io)?; + let conn = co.get_conn().await; let request = get::fsm::start(conn, request); // create a new bidi stream let 
connected = request.next().await?; @@ -410,7 +508,7 @@ async fn get_hash_seq< _ => { tracing::debug!("don't have collection - doing full download"); // don't have the collection, so probably got nothing - let conn = get_conn().await.map_err(GetError::Io)?; + let conn = co.get_conn().await; let request = get::fsm::start(conn, GetRequest::all(*root_hash)); // create a new bidi stream let connected = request.next().await?; diff --git a/iroh-net/src/dialer.rs b/iroh-net/src/dialer.rs index 7a7685d97b..8c37b08c08 100644 --- a/iroh-net/src/dialer.rs +++ b/iroh-net/src/dialer.rs @@ -99,6 +99,11 @@ impl Dialer { pub fn pending_count(&self) -> usize { self.pending_dials.len() } + + /// Returns a reference to the endpoint used in this dialer. + pub fn endpoint(&self) -> &Endpoint { + &self.endpoint + } } impl Stream for Dialer { diff --git a/iroh/src/client/blobs.rs b/iroh/src/client/blobs.rs index 3151c3fb1f..04e544e8b1 100644 --- a/iroh/src/client/blobs.rs +++ b/iroh/src/client/blobs.rs @@ -944,7 +944,10 @@ mod tests { use super::*; use anyhow::Context as _; + use iroh_blobs::hashseq::HashSeq; + use iroh_net::NodeId; use rand::RngCore; + use testresult::TestResult; use tokio::io::AsyncWriteExt; #[tokio::test] @@ -1248,4 +1251,185 @@ mod tests { Ok(()) } + + /// Download a existing blob from oneself + #[tokio::test] + async fn test_blob_get_self_existing() -> TestResult<()> { + let _guard = iroh_test::logging::setup(); + + let node = crate::node::Node::memory().spawn().await?; + let node_id = node.node_id(); + let client = node.client(); + + let AddOutcome { hash, size, .. } = client.blobs().add_bytes("foo").await?; + + // Direct + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::Raw, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Direct, + }, + ) + .await? 
+ .await?; + + assert_eq!(res.local_size, size); + assert_eq!(res.downloaded_size, 0); + + // Queued + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::Raw, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Queued, + }, + ) + .await? + .await?; + + assert_eq!(res.local_size, size); + assert_eq!(res.downloaded_size, 0); + + Ok(()) + } + + /// Download a missing blob from oneself + #[tokio::test] + async fn test_blob_get_self_missing() -> TestResult<()> { + let _guard = iroh_test::logging::setup(); + + let node = crate::node::Node::memory().spawn().await?; + let node_id = node.node_id(); + let client = node.client(); + + let hash = Hash::from_bytes([0u8; 32]); + + // Direct + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::Raw, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Direct, + }, + ) + .await? + .await; + assert!(res.is_err()); + assert_eq!( + res.err().unwrap().to_string().as_str(), + "No nodes to download from provided" + ); + + // Queued + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::Raw, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Queued, + }, + ) + .await? + .await; + assert!(res.is_err()); + assert_eq!( + res.err().unwrap().to_string().as_str(), + "No provider nodes found" + ); + + Ok(()) + } + + /// Download a existing collection. Check that things succeed and no download is performed. + #[tokio::test] + async fn test_blob_get_existing_collection() -> TestResult<()> { + let _guard = iroh_test::logging::setup(); + + let node = crate::node::Node::memory().spawn().await?; + // We use a nonexisting node id because we just want to check that this succeeds without + // hitting the network. 
+ let node_id = NodeId::from_bytes(&[0u8; 32])?; + let client = node.client(); + + let mut collection = Collection::default(); + let mut tags = Vec::new(); + let mut size = 0; + for value in ["iroh", "is", "cool"] { + let import_outcome = client.blobs().add_bytes(value).await.context("add bytes")?; + collection.push(value.to_string(), import_outcome.hash); + tags.push(import_outcome.tag); + size += import_outcome.size; + } + + let (hash, _tag) = client + .blobs() + .create_collection(collection, SetTagOption::Auto, tags) + .await?; + + // load the hashseq and collection header manually to calculate our expected size + let hashseq_bytes = client.blobs().read_to_bytes(hash).await?; + size += hashseq_bytes.len() as u64; + let hashseq = HashSeq::try_from(hashseq_bytes)?; + let collection_header_bytes = client + .blobs() + .read_to_bytes(hashseq.into_iter().next().expect("header to exist")) + .await?; + size += collection_header_bytes.len() as u64; + + // Direct + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::HashSeq, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Direct, + }, + ) + .await? + .await + .context("direct (download)")?; + + assert_eq!(res.local_size, size); + assert_eq!(res.downloaded_size, 0); + + // Queued + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::HashSeq, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Queued, + }, + ) + .await? 
+ .await + .context("queued")?; + + assert_eq!(res.local_size, size); + assert_eq!(res.downloaded_size, 0); + + Ok(()) + } } diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index 467e91d402..e51e233ce8 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -3,12 +3,11 @@ use std::io; use std::sync::{Arc, Mutex}; use std::time::Duration; -use anyhow::{anyhow, ensure, Result}; +use anyhow::{anyhow, Result}; use futures_buffered::BufferedStreamExt; use futures_lite::{Stream, StreamExt}; use genawaiter::sync::{Co, Gen}; use iroh_base::rpc::{RpcError, RpcResult}; -use iroh_blobs::downloader::{DownloadRequest, Downloader}; use iroh_blobs::export::ExportProgress; use iroh_blobs::format::collection::Collection; use iroh_blobs::get::db::DownloadProgress; @@ -18,6 +17,10 @@ use iroh_blobs::util::local_pool::LocalPoolHandle; use iroh_blobs::util::progress::{AsyncChannelProgressSender, ProgressSender}; use iroh_blobs::util::SetTagOption; use iroh_blobs::BlobFormat; +use iroh_blobs::{ + downloader::{DownloadRequest, Downloader}, + get::db::GetState, +}; use iroh_blobs::{ provider::AddProgress, store::{Store as BaoStore, ValidateProgress}, @@ -1191,6 +1194,7 @@ async fn download_queued( Ok(stats) } +#[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))] async fn download_direct_from_nodes( db: &D, endpoint: Endpoint, @@ -1201,51 +1205,61 @@ async fn download_direct_from_nodes( where D: BaoStore, { - ensure!(!nodes.is_empty(), "No nodes to download from provided."); let mut last_err = None; - for node in nodes { - let node_id = node.node_id; - match download_direct( - db, - endpoint.clone(), - hash_and_format, - node, - progress.clone(), - ) - .await + let mut remaining_nodes = nodes.len(); + let mut nodes_iter = nodes.into_iter(); + 'outer: loop { + match iroh_blobs::get::db::get_to_db_in_steps(db.clone(), hash_and_format, progress.clone()) + .await? 
{ - Ok(stats) => return Ok(stats), - Err(err) => { - debug!(?err, node = &node_id.fmt_short(), "Download failed"); - last_err = Some(err) + GetState::Complete(stats) => return Ok(stats), + GetState::NeedsConn(needs_conn) => { + let (conn, node_id) = 'inner: loop { + match nodes_iter.next() { + None => break 'outer, + Some(node) => { + remaining_nodes -= 1; + let node_id = node.node_id; + if node_id == endpoint.node_id() { + debug!( + ?remaining_nodes, + "skip node {} (it is the node id of ourselves)", + node_id.fmt_short() + ); + continue 'inner; + } + match endpoint.connect(node, iroh_blobs::protocol::ALPN).await { + Ok(conn) => break 'inner (conn, node_id), + Err(err) => { + debug!( + ?remaining_nodes, + "failed to connect to {}: {err}", + node_id.fmt_short() + ); + continue 'inner; + } + } + } + } + }; + match needs_conn.proceed(conn).await { + Ok(stats) => return Ok(stats), + Err(err) => { + warn!( + ?remaining_nodes, + "failed to download from {}: {err}", + node_id.fmt_short() + ); + last_err = Some(err); + } + } } } } - Err(last_err.unwrap()) -} - -async fn download_direct( - db: &D, - endpoint: Endpoint, - hash_and_format: HashAndFormat, - node: NodeAddr, - progress: AsyncChannelProgressSender, -) -> Result -where - D: BaoStore, -{ - let get_conn = { - let progress = progress.clone(); - move || async move { - let conn = endpoint.connect(node, iroh_blobs::protocol::ALPN).await?; - progress.send(DownloadProgress::Connected).await?; - Ok(conn) - } - }; - - let res = iroh_blobs::get::db::get_to_db(db, get_conn, &hash_and_format, progress).await; - - res.map_err(Into::into) + match last_err { + Some(err) => Err(err.into()), + None => Err(anyhow!("No nodes to download from provided")), + } } fn docs_disabled() -> RpcError {