diff --git a/iroh-blobs/src/downloader.rs b/iroh-blobs/src/downloader.rs index dd26a8bc6d..21644e8d93 100644 --- a/iroh-blobs/src/downloader.rs +++ b/iroh-blobs/src/downloader.rs @@ -27,8 +27,12 @@ //! requests to a single node is also limited. use std::{ - collections::{hash_map::Entry, HashMap, HashSet}, + collections::{ + hash_map::{self, Entry}, + HashMap, HashSet, + }, fmt, + future::Future, num::NonZeroUsize, sync::{ atomic::{AtomicU64, Ordering}, @@ -46,7 +50,7 @@ use tokio::{ sync::{mpsc, oneshot}, task::JoinSet, }; -use tokio_util::{sync::CancellationToken, time::delay_queue}; +use tokio_util::{either::Either, sync::CancellationToken, time::delay_queue}; use tracing::{debug, error_span, trace, warn, Instrument}; use crate::{ @@ -75,13 +79,15 @@ pub struct IntentId(pub u64); /// Trait modeling a dialer. This allows for IO-less testing. pub trait Dialer: Stream)> + Unpin { /// Type of connections returned by the Dialer. - type Connection: Clone; + type Connection: Clone + 'static; /// Dial a node. fn queue_dial(&mut self, node_id: NodeId); /// Get the number of dialing nodes. fn pending_count(&self) -> usize; /// Check if a node is being dialed. fn is_pending(&self, node: NodeId) -> bool; + /// Get the node id of our node. + fn node_id(&self) -> NodeId; } /// Signals what should be done with the request when it fails. @@ -97,20 +103,39 @@ pub enum FailureAction { RetryLater(anyhow::Error), } -/// Future of a get request. -type GetFut = BoxedLocal; +/// Future of a get request, for the checking stage. +type GetStartFut = BoxedLocal, FailureAction>>; +/// Future of a get request, for the downloading stage. +type GetProceedFut = BoxedLocal; /// Trait modelling performing a single request over a connection. This allows for IO-less testing. pub trait Getter { /// Type of connections the Getter requires to perform a download. - type Connection; - /// Return a future that performs the download using the given connection. + type Connection: 'static; + /// Type of the intermediary state returned from [`Self::get`] if a connection is needed. + type NeedsConn: NeedsConn; + /// Returns a future that checks the local store if the request is already complete, returning + /// a struct implementing [`NeedsConn`] if we need a network connection to proceed. fn get( &mut self, kind: DownloadKind, - conn: Self::Connection, progress_sender: BroadcastProgressSender, - ) -> GetFut; + ) -> GetStartFut; +} + +/// Trait modelling the intermediary state when a connection is needed to proceed. +pub trait NeedsConn: std::fmt::Debug + 'static { + /// Proceeds the download with the given connection. + fn proceed(self, conn: C) -> GetProceedFut; +} + +/// Output returned from [`Getter::get`]. +#[derive(Debug)] +pub enum GetOutput { + /// The request is already complete in the local store. + Complete(Stats), + /// The request needs a connection to continue. + NeedsConn(N), } /// Concurrency limits for the [`Downloader`]. @@ -280,7 +305,7 @@ pub struct DownloadHandle { receiver: oneshot::Receiver, } -impl std::future::Future for DownloadHandle { +impl Future for DownloadHandle { type Output = ExternalDownloadResult; fn poll( @@ -424,10 +449,12 @@ struct IntentHandlers { } /// Information about a request. -#[derive(Debug, Default)] -struct RequestInfo { +#[derive(Debug)] +struct RequestInfo { /// Registered intents with progress senders and result callbacks. intents: HashMap, + progress_sender: BroadcastProgressSender, + get_state: Option, } /// Information about a request in progress. @@ -529,7 +556,7 @@ struct Service { /// Queue of pending downloads. queue: Queue, /// Information about pending and active requests. - requests: HashMap, + requests: HashMap>, /// State of running downloads. active_requests: HashMap, /// Tasks for currently running downloads. @@ -666,48 +693,85 @@ impl, D: Dialer> Service { on_progress: progress, }; - // early exit if no providers. - if nodes.is_empty() && self.providers.get_candidates(&kind.hash()).next().is_none() { - self.finalize_download( - kind, - [(intent_id, intent_handlers)].into(), - Err(DownloadError::NoProviders), - ); - return; - } - // add the nodes to the provider map - let updated = self - .providers - .add_hash_with_nodes(kind.hash(), nodes.iter().map(|n| n.node_id)); + // (skip the node id of our own node - we should never attempt to download from ourselves) + let node_ids = nodes + .iter() + .map(|n| n.node_id) + .filter(|node_id| *node_id != self.dialer.node_id()); + let updated = self.providers.add_hash_with_nodes(kind.hash(), node_ids); // queue the transfer (if not running) or attach to transfer progress (if already running) - if self.active_requests.contains_key(&kind) { - // the transfer is already running, so attach the progress sender - if let Some(on_progress) = &intent_handlers.on_progress { - // this is async because it sends the current state over the progress channel - if let Err(err) = self - .progress_tracker - .subscribe(kind, on_progress.clone()) - .await - { - debug!(?err, %kind, "failed to subscribe progress sender to transfer"); + match self.requests.entry(kind) { + hash_map::Entry::Occupied(mut entry) => { + if let Some(on_progress) = &intent_handlers.on_progress { + // this is async because it sends the current state over the progress channel + if let Err(err) = self + .progress_tracker + .subscribe(kind, on_progress.clone()) + .await + { + debug!(?err, %kind, "failed to subscribe progress sender to transfer"); + } } + entry.get_mut().intents.insert(intent_id, intent_handlers); } - } else { - // the transfer is not running. - if updated && self.queue.is_parked(&kind) { - // the transfer is on hold for pending retries, and we added new nodes, so move back to queue. - self.queue.unpark(&kind); - } else if !self.queue.contains(&kind) { - // the transfer is not yet queued: add to queue. + hash_map::Entry::Vacant(entry) => { + tracing::warn!("is new, queue"); + let progress_sender = self.progress_tracker.track( + kind, + intent_handlers + .on_progress + .clone() + .into_iter() + .collect::>(), + ); + + let get_state = match self.getter.get(kind, progress_sender.clone()).await { + Err(_err) => { + self.finalize_download( + kind, + [(intent_id, intent_handlers)].into(), + // TODO: add better error variant? this is only triggered if the local + // store failed with local IO. + Err(DownloadError::DownloadFailed), + ); + return; + } + Ok(GetOutput::Complete(stats)) => { + self.finalize_download( + kind, + [(intent_id, intent_handlers)].into(), + Ok(stats), + ); + return; + } + Ok(GetOutput::NeedsConn(state)) => { + // early exit if no providers. + if self.providers.get_candidates(&kind.hash()).next().is_none() { + self.finalize_download( + kind, + [(intent_id, intent_handlers)].into(), + Err(DownloadError::NoProviders), + ); + return; + } + state + } + }; + entry.insert(RequestInfo { + intents: [(intent_id, intent_handlers)].into_iter().collect(), + progress_sender, + get_state: Some(get_state), + }); self.queue.insert(kind); } } - // store the request info - let request_info = self.requests.entry(kind).or_default(); - request_info.intents.insert(intent_id, intent_handlers); + if updated && self.queue.is_parked(&kind) { + // the transfer is on hold for pending retries, and we added new nodes, so move back to queue. + self.queue.unpark(&kind); + } } /// Cancels a download intent. @@ -860,7 +924,6 @@ impl, D: Dialer> Service { ) { self.progress_tracker.remove(&kind); self.remove_hash_if_not_queued(&kind.hash()); - let result = result.map_err(|_| DownloadError::DownloadFailed); for (_id, handlers) in intents.into_iter() { handlers.on_finish.send(result.clone()).ok(); } @@ -1082,14 +1145,9 @@ impl, D: Dialer> Service { /// Panics if hash is not in self.requests or node is not in self.nodes. fn start_download(&mut self, kind: DownloadKind, node: NodeId) { let node_info = self.connected_nodes.get_mut(&node).expect("node exists"); - let request_info = self.requests.get(&kind).expect("hash exists"); - - // create a progress sender and subscribe all intents to the progress sender - let subscribers = request_info - .intents - .values() - .flat_map(|state| state.on_progress.clone()); - let progress_sender = self.progress_tracker.track(kind, subscribers); + let request_info = self.requests.get_mut(&kind).expect("request exists"); + let progress = request_info.progress_sender.clone(); + // .expect("queued state exists"); // create the active request state let cancellation = CancellationToken::new(); @@ -1098,7 +1156,15 @@ impl, D: Dialer> Service { node, }; let conn = node_info.conn.clone(); - let get_fut = self.getter.get(kind, conn, progress_sender); + + // If this is the first provider node we try, we have an initial state + // from starting the generator in Self::handle_queue_new_download. + // If this not the first provider node we try, we have to recreate the generator, because + // we can only resume it once. + let get_state = match request_info.get_state.take() { + Some(state) => Either::Left(async move { Ok(GetOutput::NeedsConn(state)) }), + None => Either::Right(self.getter.get(kind, progress)), + }; let fut = async move { // NOTE: it's an open question if we should do timeouts at this point. Considerations from @Frando: // > at this stage we do not know the size of the download, so the timeout would have @@ -1106,9 +1172,16 @@ impl, D: Dialer> Service { // > this means that a super slow node would block a download from succeeding for a long // > time, while faster nodes could be readily available. // As a conclusion, timeouts should be added only after downloads are known to be bounded + let fut = async move { + match get_state.await? { + GetOutput::Complete(stats) => Ok(stats), + GetOutput::NeedsConn(state) => state.proceed(conn).await, + } + }; + tokio::pin!(fut); let res = tokio::select! { _ = cancellation.cancelled() => Err(FailureAction::AllIntentsDropped), - res = get_fut => res + res = &mut fut => res }; trace!("transfer finished"); @@ -1433,4 +1506,8 @@ impl Dialer for iroh_net::dialer::Dialer { fn is_pending(&self, node: NodeId) -> bool { self.is_pending(node) } + + fn node_id(&self) -> NodeId { + self.endpoint().node_id() + } } diff --git a/iroh-blobs/src/downloader/get.rs b/iroh-blobs/src/downloader/get.rs index e48370d42c..b43cbaba92 100644 --- a/iroh-blobs/src/downloader/get.rs +++ b/iroh-blobs/src/downloader/get.rs @@ -3,18 +3,13 @@ //! [`Connection`]: iroh_net::endpoint::Connection use crate::{ - get::{db::get_to_db, error::GetError}, + get::{db::get_to_db_in_steps, error::GetError}, store::Store, }; use futures_lite::FutureExt; -#[cfg(feature = "metrics")] -use iroh_metrics::{inc, inc_by}; use iroh_net::endpoint; -#[cfg(feature = "metrics")] -use crate::metrics::Metrics; - -use super::{progress::BroadcastProgressSender, DownloadKind, FailureAction, GetFut, Getter}; +use super::{progress::BroadcastProgressSender, DownloadKind, FailureAction, GetStartFut, Getter}; impl From for FailureAction { fn from(e: GetError) -> Self { @@ -39,46 +34,63 @@ pub(crate) struct IoGetter { impl Getter for IoGetter { type Connection = endpoint::Connection; + type NeedsConn = crate::get::db::GetStateNeedsConn; fn get( &mut self, kind: DownloadKind, - conn: Self::Connection, progress_sender: BroadcastProgressSender, - ) -> GetFut { + ) -> GetStartFut { let store = self.store.clone(); - let fut = async move { - let get_conn = || async move { Ok(conn) }; - let res = get_to_db(&store, get_conn, &kind.hash_and_format(), progress_sender).await; - match res { - Ok(stats) => { - #[cfg(feature = "metrics")] - { - let crate::get::Stats { - bytes_written, - bytes_read: _, - elapsed, - } = stats; - - inc!(Metrics, downloads_success); - inc_by!(Metrics, download_bytes_total, bytes_written); - inc_by!(Metrics, download_time_total, elapsed.as_millis() as u64); - } - Ok(stats) + async move { + match get_to_db_in_steps(store, kind.hash_and_format(), progress_sender).await { + Err(err) => Err(err.into()), + Ok(crate::get::db::GetState::Complete(stats)) => { + Ok(super::GetOutput::Complete(stats)) } - Err(e) => { - // record metrics according to the error - #[cfg(feature = "metrics")] - { - match &e { - GetError::NotFound(_) => inc!(Metrics, downloads_notfound), - _ => inc!(Metrics, downloads_error), - } - } - Err(e.into()) + Ok(crate::get::db::GetState::NeedsConn(needs_conn)) => { + Ok(super::GetOutput::NeedsConn(needs_conn)) } } - }; - fut.boxed_local() + } + .boxed_local() + } +} + +impl super::NeedsConn for crate::get::db::GetStateNeedsConn { + fn proceed(self, conn: endpoint::Connection) -> super::GetProceedFut { + async move { + let res = self.proceed(conn).await; + #[cfg(feature = "metrics")] + track_metrics(&res); + match res { + Ok(stats) => Ok(stats), + Err(err) => Err(err.into()), + } + } + .boxed_local() + } +} + +#[cfg(feature = "metrics")] +fn track_metrics(res: &Result) { + use crate::metrics::Metrics; + use iroh_metrics::{inc, inc_by}; + match res { + Ok(stats) => { + let crate::get::Stats { + bytes_written, + bytes_read: _, + elapsed, + } = stats; + + inc!(Metrics, downloads_success); + inc_by!(Metrics, download_bytes_total, *bytes_written); + inc_by!(Metrics, download_time_total, elapsed.as_millis() as u64); + } + Err(e) => match &e { + GetError::NotFound(_) => inc!(Metrics, downloads_notfound), + _ => inc!(Metrics, downloads_error), + }, } } diff --git a/iroh-blobs/src/downloader/invariants.rs b/iroh-blobs/src/downloader/invariants.rs index e4a2656368..0409e3d922 100644 --- a/iroh-blobs/src/downloader/invariants.rs +++ b/iroh-blobs/src/downloader/invariants.rs @@ -77,8 +77,8 @@ impl, D: Dialer> Service { // check that the count of futures we are polling for downloads is consistent with the // number of requests assert_eq!( - self.in_progress_downloads.len(), self.active_requests.len(), + self.in_progress_downloads.len(), "active_requests and in_progress_downloads are out of sync" ); // check that the count of requests per peer matches the number of requests that have that diff --git a/iroh-blobs/src/downloader/progress.rs b/iroh-blobs/src/downloader/progress.rs index eac80985d5..60ded0e7a5 100644 --- a/iroh-blobs/src/downloader/progress.rs +++ b/iroh-blobs/src/downloader/progress.rs @@ -103,6 +103,7 @@ struct Inner { impl Inner { fn subscribe(&mut self, subscriber: ProgressSubscriber) -> DownloadProgress { + tracing::warn!(state=?self.state, "subscribe! emit initial"); let msg = DownloadProgress::InitialState(self.state.clone()); self.subscribers.push(subscriber); msg @@ -136,7 +137,9 @@ impl ProgressSender for BroadcastProgressSender { // making sure that the lock is not held across an await point. let futs = { let mut inner = self.shared.lock(); + tracing::warn!(?msg, state_pre=?inner.state, "send to {}", inner.subscribers.len()); inner.on_progress(msg.clone()); + tracing::warn!(state_post=?inner.state, "send"); let futs = inner .subscribers .iter_mut() diff --git a/iroh-blobs/src/downloader/test.rs b/iroh-blobs/src/downloader/test.rs index 2e734eaf3b..b2bd4c751a 100644 --- a/iroh-blobs/src/downloader/test.rs +++ b/iroh-blobs/src/downloader/test.rs @@ -9,7 +9,10 @@ use futures_util::future::FutureExt; use iroh_net::key::SecretKey; use crate::{ - get::{db::BlobId, progress::TransferState}, + get::{ + db::BlobId, + progress::{BlobProgress, TransferState}, + }, util::{ local_pool::LocalPool, progress::{AsyncChannelProgressSender, IdGenerator}, @@ -286,16 +289,26 @@ async fn concurrent_progress() { let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_b_tx); let handle_b = downloader.queue(req).await; - start_tx.send(()).unwrap(); - let mut state_a = TransferState::new(hash); let mut state_b = TransferState::new(hash); let mut state_c = TransferState::new(hash); + let prog0_b = prog_b_rx.recv().await.unwrap(); + assert!(matches!( + prog0_b, + DownloadProgress::InitialState(state) if state.root.hash == hash && state.root.progress == BlobProgress::Pending, + )); + + start_tx.send(()).unwrap(); + let prog1_a = prog_a_rx.recv().await.unwrap(); let prog1_b = prog_b_rx.recv().await.unwrap(); - assert!(matches!(prog1_a, DownloadProgress::Found { hash, size: 100, ..} if hash == hash)); - assert!(matches!(prog1_b, DownloadProgress::Found { hash, size: 100, ..} if hash == hash)); + assert!( + matches!(prog1_a, DownloadProgress::Found { hash: found_hash, size: 100, ..} if found_hash == hash) + ); + assert!( + matches!(prog1_b, DownloadProgress::Found { hash: found_hash, size: 100, ..} if found_hash == hash) + ); state_a.on_progress(prog1_a); state_b.on_progress(prog1_b); diff --git a/iroh-blobs/src/downloader/test/dialer.rs b/iroh-blobs/src/downloader/test/dialer.rs index d099552a11..fc5a939959 100644 --- a/iroh-blobs/src/downloader/test/dialer.rs +++ b/iroh-blobs/src/downloader/test/dialer.rs @@ -21,6 +21,8 @@ struct TestingDialerInner { dial_duration: Duration, /// Fn deciding if a dial is successful. dial_outcome: Box bool + Send + Sync + 'static>, + /// Our own node id + node_id: NodeId, } impl Default for TestingDialerInner { @@ -31,6 +33,7 @@ impl Default for TestingDialerInner { dial_history: Vec::default(), dial_duration: Duration::from_millis(10), dial_outcome: Box::new(|_| true), + node_id: NodeId::from_bytes(&[0u8; 32]).unwrap(), } } } @@ -55,6 +58,10 @@ impl Dialer for TestingDialer { fn is_pending(&self, node: NodeId) -> bool { self.0.read().dialing.contains(&node) } + + fn node_id(&self) -> NodeId { + self.0.read().node_id + } } impl Stream for TestingDialer { diff --git a/iroh-blobs/src/downloader/test/getter.rs b/iroh-blobs/src/downloader/test/getter.rs index 397f1134f1..c3686a71c4 100644 --- a/iroh-blobs/src/downloader/test/getter.rs +++ b/iroh-blobs/src/downloader/test/getter.rs @@ -3,9 +3,12 @@ use futures_lite::{future::Boxed as BoxFuture, FutureExt}; use parking_lot::RwLock; +use crate::downloader; + use super::*; -#[derive(Default, Clone)] +#[derive(Default, Clone, derive_more::Debug)] +#[debug("TestingGetter")] pub(super) struct TestingGetter(Arc>); pub(super) type RequestHandlerFn = Arc< @@ -34,14 +37,29 @@ impl Getter for TestingGetter { // since for testing we don't need a real connection, just keep track of what peer is the // request being sent to type Connection = NodeId; + type NeedsConn = GetStateNeedsConn; fn get( &mut self, kind: DownloadKind, - peer: NodeId, progress_sender: BroadcastProgressSender, - ) -> GetFut { - let mut inner = self.0.write(); + ) -> GetStartFut { + std::future::ready(Ok(downloader::GetOutput::NeedsConn(GetStateNeedsConn( + self.clone(), + kind, + progress_sender, + )))) + .boxed_local() + } +} + +#[derive(Debug)] +pub(super) struct GetStateNeedsConn(TestingGetter, DownloadKind, BroadcastProgressSender); + +impl downloader::NeedsConn for GetStateNeedsConn { + fn proceed(self, peer: NodeId) -> super::GetProceedFut { + let GetStateNeedsConn(getter, kind, progress_sender) = self; + let mut inner = getter.0.write(); inner.request_history.push((kind, peer)); let request_duration = inner.request_duration; let handler = inner.request_handler.clone(); diff --git a/iroh-blobs/src/get/db.rs b/iroh-blobs/src/get/db.rs index 08ef2f82c7..afcdea6972 100644 --- a/iroh-blobs/src/get/db.rs +++ b/iroh-blobs/src/get/db.rs @@ -3,12 +3,18 @@ use std::future::Future; use std::io; use std::num::NonZeroU64; +use std::pin::Pin; use futures_lite::StreamExt; +use genawaiter::{ + rc::{Co, Gen}, + GeneratorState, +}; use iroh_base::hash::Hash; use iroh_base::rpc::RpcError; use iroh_net::endpoint::Connection; use serde::{Deserialize, Serialize}; +use tokio::sync::oneshot; use crate::hashseq::parse_hash_seq; use crate::protocol::RangeSpec; @@ -34,6 +40,9 @@ use bao_tree::{ChunkNum, ChunkRanges}; use iroh_io::AsyncSliceReader; use tracing::trace; +type GetGenerator = Gen>>>>; +type GetFuture = Pin> + 'static>>; + /// Get a blob or collection into a store. /// /// This considers data that is already in the store, and will only request @@ -50,12 +59,105 @@ pub async fn get_to_db< db: &D, get_conn: C, hash_and_format: &HashAndFormat, - sender: impl ProgressSender + IdGenerator, + progress_sender: impl ProgressSender + IdGenerator, +) -> Result { + match get_to_db_in_steps(db.clone(), *hash_and_format, progress_sender).await? { + GetState::Complete(res) => Ok(res), + GetState::NeedsConn(state) => { + let conn = get_conn().await.map_err(GetError::Io)?; + state.proceed(conn).await + } + } +} + +/// Get a blob or collection into a store, yielding if a connection is needed. +/// +/// This checks a get request against a local store, and returns [`GetState`], +/// which is either `Complete` in case the requested data is fully available in the local store, or +/// `NeedsConn`, once a connection is needed to proceed downloading the missing data. +/// +/// In the latter case, call [`GetStateNeedsConn::proceed`] with a connection to a provider to +/// proceed with the download. +/// +/// Progress reporting works in the same way as documented in [`get_to_db`]. +pub async fn get_to_db_in_steps< + D: BaoStore, + P: ProgressSender + IdGenerator, +>( + db: D, + hash_and_format: HashAndFormat, + progress_sender: P, +) -> Result { + let mut gen: GetGenerator = genawaiter::rc::Gen::new(move |co| { + let fut = async move { producer(co, &db, &hash_and_format, progress_sender).await }; + let fut: GetFuture = Box::pin(fut); + fut + }); + match gen.async_resume().await { + GeneratorState::Yielded(Yield::NeedConn(reply)) => { + Ok(GetState::NeedsConn(GetStateNeedsConn(gen, reply))) + } + GeneratorState::Complete(res) => res.map(GetState::Complete), + } +} + +/// Intermediary state returned from [`get_to_db_in_steps`] for a download request that needs a +/// connection to proceed. +#[derive(derive_more::Debug)] +#[debug("GetStateNeedsConn")] +pub struct GetStateNeedsConn(GetGenerator, oneshot::Sender); + +impl GetStateNeedsConn { + /// Proceed with the download by providing a connection to a provider. + pub async fn proceed(mut self, conn: Connection) -> Result { + self.1.send(conn).expect("receiver is not dropped"); + match self.0.async_resume().await { + GeneratorState::Yielded(y) => match y { + Yield::NeedConn(_) => panic!("NeedsConn may only be yielded once"), + }, + GeneratorState::Complete(res) => res, + } + } +} + +/// Output of [`get_to_db_in_steps`]. +#[derive(Debug)] +pub enum GetState { + /// The requested data is completely available in the local store, no network requests are + /// needed. + Complete(Stats), + /// The requested data is not fully available in the local store, we need a connection to + /// proceed. + /// + /// Once a connection is available, call [`GetStateNeedsConn::proceed`] to continue. + NeedsConn(GetStateNeedsConn), +} + +struct GetCo(Co); + +impl GetCo { + async fn get_conn(&self) -> Connection { + let (tx, rx) = oneshot::channel(); + self.0.yield_(Yield::NeedConn(tx)).await; + rx.await.expect("sender may not be dropped") + } +} + +enum Yield { + NeedConn(oneshot::Sender), +} + +async fn producer( + co: Co, + db: &D, + hash_and_format: &HashAndFormat, + progress: impl ProgressSender + IdGenerator, ) -> Result { let HashAndFormat { hash, format } = hash_and_format; + let co = GetCo(co); match format { - BlobFormat::Raw => get_blob(db, get_conn, hash, sender).await, - BlobFormat::HashSeq => get_hash_seq(db, get_conn, hash, sender).await, + BlobFormat::Raw => get_blob(db, co, hash, progress).await, + BlobFormat::HashSeq => get_hash_seq(db, co, hash, progress).await, } } @@ -63,9 +165,9 @@ pub async fn get_to_db< /// /// We need to create our own files and handle the case where an outboard /// is not needed. -async fn get_blob F, F: Future>>( +async fn get_blob( db: &D, - get_conn: C, + co: GetCo, hash: &Hash, progress: impl ProgressSender + IdGenerator, ) -> Result { @@ -100,7 +202,7 @@ async fn get_blob F, F: Future F, F: Future { // full request - let conn = get_conn().await.map_err(GetError::Io)?; + let conn = co.get_conn().await; let request = get::fsm::start(conn, GetRequest::single(*hash)); // create a new bidi stream let connected = request.next().await?; @@ -299,13 +401,9 @@ async fn blob_infos(db: &D, hash_seq: &[Hash]) -> io::Result F, - F: Future>, ->( +async fn get_hash_seq( db: &D, - get_conn: C, + co: GetCo, root_hash: &Hash, sender: impl ProgressSender + IdGenerator, ) -> Result { @@ -364,7 +462,7 @@ async fn get_hash_seq< .collect::>(); log!("requesting chunks {:?}", missing_iter); let request = GetRequest::new(*root_hash, RangeSpecSeq::from_ranges(missing_iter)); - let conn = get_conn().await.map_err(GetError::Io)?; + let conn = co.get_conn().await; let request = get::fsm::start(conn, request); // create a new bidi stream let connected = request.next().await?; @@ -410,7 +508,7 @@ async fn get_hash_seq< _ => { tracing::debug!("don't have collection - doing full download"); // don't have the collection, so probably got nothing - let conn = get_conn().await.map_err(GetError::Io)?; + let conn = co.get_conn().await; let request = get::fsm::start(conn, GetRequest::all(*root_hash)); // create a new bidi stream let connected = request.next().await?; diff --git a/iroh-net/src/dialer.rs b/iroh-net/src/dialer.rs index 7a7685d97b..8c37b08c08 100644 --- a/iroh-net/src/dialer.rs +++ b/iroh-net/src/dialer.rs @@ -99,6 +99,11 @@ impl Dialer { pub fn pending_count(&self) -> usize { self.pending_dials.len() } + + /// Returns a reference to the endpoint used in this dialer. + pub fn endpoint(&self) -> &Endpoint { + &self.endpoint + } } impl Stream for Dialer { diff --git a/iroh/src/client/blobs.rs b/iroh/src/client/blobs.rs index 3151c3fb1f..04e544e8b1 100644 --- a/iroh/src/client/blobs.rs +++ b/iroh/src/client/blobs.rs @@ -944,7 +944,10 @@ mod tests { use super::*; use anyhow::Context as _; + use iroh_blobs::hashseq::HashSeq; + use iroh_net::NodeId; use rand::RngCore; + use testresult::TestResult; use tokio::io::AsyncWriteExt; #[tokio::test] @@ -1248,4 +1251,185 @@ mod tests { Ok(()) } + + /// Download a existing blob from oneself + #[tokio::test] + async fn test_blob_get_self_existing() -> TestResult<()> { + let _guard = iroh_test::logging::setup(); + + let node = crate::node::Node::memory().spawn().await?; + let node_id = node.node_id(); + let client = node.client(); + + let AddOutcome { hash, size, .. } = client.blobs().add_bytes("foo").await?; + + // Direct + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::Raw, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Direct, + }, + ) + .await? + .await?; + + assert_eq!(res.local_size, size); + assert_eq!(res.downloaded_size, 0); + + // Queued + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::Raw, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Queued, + }, + ) + .await? + .await?; + + assert_eq!(res.local_size, size); + assert_eq!(res.downloaded_size, 0); + + Ok(()) + } + + /// Download a missing blob from oneself + #[tokio::test] + async fn test_blob_get_self_missing() -> TestResult<()> { + let _guard = iroh_test::logging::setup(); + + let node = crate::node::Node::memory().spawn().await?; + let node_id = node.node_id(); + let client = node.client(); + + let hash = Hash::from_bytes([0u8; 32]); + + // Direct + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::Raw, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Direct, + }, + ) + .await? + .await; + assert!(res.is_err()); + assert_eq!( + res.err().unwrap().to_string().as_str(), + "No nodes to download from provided" + ); + + // Queued + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::Raw, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Queued, + }, + ) + .await? + .await; + assert!(res.is_err()); + assert_eq!( + res.err().unwrap().to_string().as_str(), + "No provider nodes found" + ); + + Ok(()) + } + + /// Download a existing collection. Check that things succeed and no download is performed. + #[tokio::test] + async fn test_blob_get_existing_collection() -> TestResult<()> { + let _guard = iroh_test::logging::setup(); + + let node = crate::node::Node::memory().spawn().await?; + // We use a nonexisting node id because we just want to check that this succeeds without + // hitting the network. + let node_id = NodeId::from_bytes(&[0u8; 32])?; + let client = node.client(); + + let mut collection = Collection::default(); + let mut tags = Vec::new(); + let mut size = 0; + for value in ["iroh", "is", "cool"] { + let import_outcome = client.blobs().add_bytes(value).await.context("add bytes")?; + collection.push(value.to_string(), import_outcome.hash); + tags.push(import_outcome.tag); + size += import_outcome.size; + } + + let (hash, _tag) = client + .blobs() + .create_collection(collection, SetTagOption::Auto, tags) + .await?; + + // load the hashseq and collection header manually to calculate our expected size + let hashseq_bytes = client.blobs().read_to_bytes(hash).await?; + size += hashseq_bytes.len() as u64; + let hashseq = HashSeq::try_from(hashseq_bytes)?; + let collection_header_bytes = client + .blobs() + .read_to_bytes(hashseq.into_iter().next().expect("header to exist")) + .await?; + size += collection_header_bytes.len() as u64; + + // Direct + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::HashSeq, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Direct, + }, + ) + .await? + .await + .context("direct (download)")?; + + assert_eq!(res.local_size, size); + assert_eq!(res.downloaded_size, 0); + + // Queued + let res = client + .blobs() + .download_with_opts( + hash, + DownloadOptions { + format: BlobFormat::HashSeq, + nodes: vec![node_id.into()], + tag: SetTagOption::Auto, + mode: DownloadMode::Queued, + }, + ) + .await? + .await + .context("queued")?; + + assert_eq!(res.local_size, size); + assert_eq!(res.downloaded_size, 0); + + Ok(()) + } } diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index 467e91d402..e51e233ce8 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -3,12 +3,11 @@ use std::io; use std::sync::{Arc, Mutex}; use std::time::Duration; -use anyhow::{anyhow, ensure, Result}; +use anyhow::{anyhow, Result}; use futures_buffered::BufferedStreamExt; use futures_lite::{Stream, StreamExt}; use genawaiter::sync::{Co, Gen}; use iroh_base::rpc::{RpcError, RpcResult}; -use iroh_blobs::downloader::{DownloadRequest, Downloader}; use iroh_blobs::export::ExportProgress; use iroh_blobs::format::collection::Collection; use iroh_blobs::get::db::DownloadProgress; @@ -18,6 +17,10 @@ use iroh_blobs::util::local_pool::LocalPoolHandle; use iroh_blobs::util::progress::{AsyncChannelProgressSender, ProgressSender}; use iroh_blobs::util::SetTagOption; use iroh_blobs::BlobFormat; +use iroh_blobs::{ + downloader::{DownloadRequest, Downloader}, + get::db::GetState, +}; use iroh_blobs::{ provider::AddProgress, store::{Store as BaoStore, ValidateProgress}, @@ -1191,6 +1194,7 @@ async fn download_queued( Ok(stats) } +#[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))] async fn download_direct_from_nodes( db: &D, endpoint: Endpoint, @@ -1201,51 +1205,61 @@ async fn download_direct_from_nodes( where D: BaoStore, { - ensure!(!nodes.is_empty(), "No nodes to download from provided."); let mut last_err = None; - for node in nodes { - let node_id = node.node_id; - match download_direct( - db, - endpoint.clone(), - hash_and_format, - node, - progress.clone(), - ) - .await + let mut remaining_nodes = nodes.len(); + let mut nodes_iter = nodes.into_iter(); + 'outer: loop { + match iroh_blobs::get::db::get_to_db_in_steps(db.clone(), hash_and_format, progress.clone()) + .await? { - Ok(stats) => return Ok(stats), - Err(err) => { - debug!(?err, node = &node_id.fmt_short(), "Download failed"); - last_err = Some(err) + GetState::Complete(stats) => return Ok(stats), + GetState::NeedsConn(needs_conn) => { + let (conn, node_id) = 'inner: loop { + match nodes_iter.next() { + None => break 'outer, + Some(node) => { + remaining_nodes -= 1; + let node_id = node.node_id; + if node_id == endpoint.node_id() { + debug!( + ?remaining_nodes, + "skip node {} (it is the node id of ourselves)", + node_id.fmt_short() + ); + continue 'inner; + } + match endpoint.connect(node, iroh_blobs::protocol::ALPN).await { + Ok(conn) => break 'inner (conn, node_id), + Err(err) => { + debug!( + ?remaining_nodes, + "failed to connect to {}: {err}", + node_id.fmt_short() + ); + continue 'inner; + } + } + } + } + }; + match needs_conn.proceed(conn).await { + Ok(stats) => return Ok(stats), + Err(err) => { + warn!( + ?remaining_nodes, + "failed to download from {}: {err}", + node_id.fmt_short() + ); + last_err = Some(err); + } + } } } } - Err(last_err.unwrap()) -} - -async fn download_direct( - db: &D, - endpoint: Endpoint, - hash_and_format: HashAndFormat, - node: NodeAddr, - progress: AsyncChannelProgressSender, -) -> Result -where - D: BaoStore, -{ - let get_conn = { - let progress = progress.clone(); - move || async move { - let conn = endpoint.connect(node, iroh_blobs::protocol::ALPN).await?; - progress.send(DownloadProgress::Connected).await?; - Ok(conn) - } - }; - - let res = iroh_blobs::get::db::get_to_db(db, get_conn, &hash_and_format, progress).await; - - res.map_err(Into::into) + match last_err { + Some(err) => Err(err.into()), + None => Err(anyhow!("No nodes to download from provided")), + } } fn docs_disabled() -> RpcError {