diff --git a/Cargo.lock b/Cargo.lock index 0e3eb408033..64caa4cf939 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1817,6 +1817,15 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "hashlink" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "692eaaf7f7607518dd3cef090f1474b61edc5301d8012f09579920df68b725ee" +dependencies = [ + "hashbrown 0.14.3", +] + [[package]] name = "hdrhistogram" version = "7.5.4" @@ -2406,9 +2415,11 @@ dependencies = [ "futures", "futures-buffered", "genawaiter", + "hashlink", "hex", "http-body 0.4.6", "iroh-base", + "iroh-bytes", "iroh-io", "iroh-metrics", "iroh-net", @@ -2508,6 +2519,7 @@ dependencies = [ "http 1.1.0", "iroh-metrics", "iroh-net", + "iroh-test", "lru", "parking_lot", "pkarr", diff --git a/iroh-base/src/hash.rs b/iroh-base/src/hash.rs index a6e4e82eea9..81ab8206f2b 100644 --- a/iroh-base/src/hash.rs +++ b/iroh-base/src/hash.rs @@ -7,7 +7,7 @@ use bao_tree::blake3; use postcard::experimental::max_size::MaxSize; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; -use crate::base32::{parse_array_hex_or_base32, HexOrBase32ParseError}; +use crate::base32::{self, parse_array_hex_or_base32, HexOrBase32ParseError}; /// Hash type used throughout. #[derive(PartialEq, Eq, Copy, Clone, Hash)] @@ -54,6 +54,12 @@ impl Hash { pub fn to_hex(&self) -> String { self.0.to_hex().to_string() } + + /// Convert to a base32 string limited to the first 10 bytes for a friendly string + /// representation of the hash. + pub fn fmt_short(&self) -> String { + base32::fmt_short(self.as_bytes()) + } } impl AsRef<[u8]> for Hash { @@ -173,7 +179,18 @@ impl MaxSize for Hash { /// A format identifier #[derive( - Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default, Debug, MaxSize, + Clone, + Copy, + PartialEq, + Eq, + PartialOrd, + Ord, + Serialize, + Deserialize, + Default, + Debug, + MaxSize, + Hash, )] pub enum BlobFormat { /// Raw blob @@ -205,7 +222,7 @@ impl BlobFormat { } /// A hash and format pair -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, MaxSize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, MaxSize, Hash)] pub struct HashAndFormat { /// The hash pub hash: Hash, @@ -289,6 +306,11 @@ mod redb_support { } impl HashAndFormat { + /// Create a new hash and format pair. + pub fn new(hash: Hash, format: BlobFormat) -> Self { + Self { hash, format } + } + /// Create a new hash and format pair, using the default (raw) format. pub fn raw(hash: Hash) -> Self { Self { diff --git a/iroh-bytes/Cargo.toml b/iroh-bytes/Cargo.toml index bf9b4f95cbc..07a42c68e48 100644 --- a/iroh-bytes/Cargo.toml +++ b/iroh-bytes/Cargo.toml @@ -25,6 +25,7 @@ flume = "0.11" futures = "0.3.25" futures-buffered = "0.2.4" genawaiter = { version = "0.99.1", features = ["futures03"] } +hashlink = { version = "0.9.0", optional = true } hex = "0.4.3" iroh-base = { version = "0.14.0", features = ["redb"], path = "../iroh-base" } iroh-io = { version = "0.6.0", features = ["stats"] } @@ -51,6 +52,7 @@ tracing-futures = "0.2.5" [dev-dependencies] http-body = "0.4.5" +iroh-bytes = { path = ".", features = ["downloader"] } iroh-test = { path = "../iroh-test" } proptest = "1.0.0" serde_json = "1.0.107" @@ -63,8 +65,8 @@ tempfile = "3.10.0" [features] default = ["fs-store"] +downloader = ["iroh-net", "parking_lot", "tokio-util/time", "hashlink"] fs-store = ["reflink-copy", "redb", "redb_v1", "tempfile"] -downloader = ["iroh-net", "parking_lot", "tokio-util/time"] metrics = ["iroh-metrics"] [[example]] diff --git a/iroh-bytes/src/downloader.rs b/iroh-bytes/src/downloader.rs index 1963bb306be..ea9cf3138b2 100644 --- a/iroh-bytes/src/downloader.rs +++ b/iroh-bytes/src/downloader.rs @@ -2,12 +2,12 @@ //! //! The [`Downloader`] interacts with four main components to this end. //! - [`Dialer`]: Used to queue opening connections to nodes we need to perform downloads. -//! - [`ProviderMap`]: Where the downloader obtains information about nodes that could be +//! - `ProviderMap`: Where the downloader obtains information about nodes that could be //! used to perform a download. //! - [`Store`]: Where data is stored. //! //! Once a download request is received, the logic is as follows: -//! 1. The [`ProviderMap`] is queried for nodes. From these nodes some are selected +//! 1. The `ProviderMap` is queried for nodes. From these nodes some are selected //! prioritizing connected nodes with lower number of active requests. If no useful node is //! connected, or useful connected nodes have no capacity to perform the request, a connection //! attempt is started using the [`Dialer`]. @@ -27,18 +27,20 @@ //! requests to a single node is also limited. use std::{ - collections::{hash_map::Entry, HashMap, VecDeque}, + collections::{hash_map::Entry, HashMap, HashSet}, + fmt, num::NonZeroUsize, sync::{ atomic::{AtomicU64, Ordering}, Arc, }, + time::Duration, }; -use crate::{get::Stats, protocol::RangeSpecSeq, store::Store, Hash, HashAndFormat}; -use bao_tree::ChunkRanges; use futures::{future::LocalBoxFuture, FutureExt, StreamExt}; -use iroh_net::{MagicEndpoint, NodeId}; +use hashlink::LinkedHashSet; +use iroh_base::hash::{BlobFormat, Hash, HashAndFormat}; +use iroh_net::{MagicEndpoint, NodeAddr, NodeId}; use tokio::{ sync::{mpsc, oneshot}, task::JoinSet, @@ -46,22 +48,34 @@ use tokio::{ use tokio_util::{sync::CancellationToken, task::LocalPoolHandle, time::delay_queue}; use tracing::{debug, error_span, trace, warn, Instrument}; +use crate::{ + get::{db::DownloadProgress, Stats}, + store::Store, + util::{progress::ProgressSender, SetTagOption, TagSet}, + TempTag, +}; + mod get; mod invariants; +mod progress; mod test; -/// Delay added to a request when it's first received. -const INITIAL_REQUEST_DELAY: std::time::Duration = std::time::Duration::from_millis(500); -/// Number of retries initially assigned to a request. -const INITIAL_RETRY_COUNT: u8 = 4; +use self::progress::{BroadcastProgressSender, ProgressSubscriber, ProgressTracker}; + +// TODO: In which cases should we retry downloads? +// /// Number of retries for connecting to a node. +// const INITIAL_RETRY_COUNT: u8 = 4; +// /// Initial delay when reconnecting to a node. +// const INITIAL_RETRY_DELAY: Duration = Duration::from_millis(500); + /// Duration for which we keep nodes connected after they were last useful to us. -const IDLE_PEER_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); +const IDLE_PEER_TIMEOUT: Duration = Duration::from_secs(10); /// Capacity of the channel used to communicate between the [`Downloader`] and the [`Service`]. const SERVICE_CHANNEL_CAPACITY: usize = 128; -/// Download identifier. -// Mainly for readability. -pub type Id = u64; +/// Identifier for a download intent. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, derive_more::Display)] +pub struct IntentId(pub u64); /// Trait modeling a dialer. This allows for IO-less testing. pub trait Dialer: @@ -80,7 +94,9 @@ pub trait Dialer: /// Signals what should be done with the request when it fails. #[derive(Debug)] pub enum FailureAction { - /// An error occurred that prevents the request from being retried at all. + /// The request was cancelled by us. + Cancelled, + /// An error ocurred that prevents the request from being retried at all. AbortRequest(anyhow::Error), /// An error occurred that suggests the node should not be used in general. DropPeer(anyhow::Error), @@ -89,14 +105,19 @@ pub enum FailureAction { } /// Future of a get request. -type GetFut = LocalBoxFuture<'static, Result>; +type GetFut = LocalBoxFuture<'static, InternalDownloadResult>; /// Trait modelling performing a single request over a connection. This allows for IO-less testing. pub trait Getter { /// Type of connections the Getter requires to perform a download. type Connection; /// Return a future that performs the download using the given connection. - fn get(&mut self, kind: DownloadKind, conn: Self::Connection) -> GetFut; + fn get( + &mut self, + kind: DownloadKind, + conn: Self::Connection, + progress_sender: BroadcastProgressSender, + ) -> GetFut; } /// Concurrency limits for the [`Downloader`]. @@ -108,6 +129,8 @@ pub struct ConcurrencyLimits { pub max_concurrent_requests_per_node: usize, /// Maximum number of open connections the service maintains. pub max_open_connections: usize, + /// Maximum number of nodes to dial concurrently for a single request. + pub max_concurrent_dials_per_hash: usize, } impl Default for ConcurrencyLimits { @@ -117,6 +140,7 @@ impl Default for ConcurrencyLimits { max_concurrent_requests: 50, max_concurrent_requests_per_node: 4, max_open_connections: 25, + max_concurrent_dials_per_hash: 5, } } } @@ -136,67 +160,136 @@ impl ConcurrencyLimits { fn at_connections_capacity(&self, active_connections: usize) -> bool { active_connections >= self.max_open_connections } + + /// Checks if the maximum number of concurrent dials per hash has been reached. + fn at_dials_per_hash_capacity(&self, concurrent_dials: usize) -> bool { + concurrent_dials >= self.max_concurrent_dials_per_hash + } } -/// Download requests the [`Downloader`] handles. -#[derive(Debug, PartialEq, Eq, Hash, Clone)] -pub enum DownloadKind { - /// Download a single blob entirely. - Blob { - /// Blob to be downloaded. - hash: Hash, - }, - /// Download a sequence of hashes entirely. - HashSeq { - /// Hash sequence to be downloaded. - hash: Hash, - }, +/// A download request. +#[derive(Debug, Clone)] +pub struct DownloadRequest { + kind: DownloadKind, + nodes: Vec, + tag: Option, + progress: Option, } -impl DownloadKind { - /// Get the requested hash. - const fn hash(&self) -> &Hash { - match self { - DownloadKind::Blob { hash } | DownloadKind::HashSeq { hash } => hash, +impl DownloadRequest { + /// Create a new download request. + /// + /// The blob will be auto-tagged after the download to prevent it from being garbage collected. + pub fn new( + resource: impl Into, + nodes: impl IntoIterator>, + ) -> Self { + Self { + kind: resource.into(), + nodes: nodes.into_iter().map(|n| n.into()).collect(), + tag: Some(SetTagOption::Auto), + progress: None, } } - /// Get the requested hash and format. - fn hash_and_format(&self) -> HashAndFormat { - match self { - DownloadKind::Blob { hash } => HashAndFormat::raw(*hash), - DownloadKind::HashSeq { hash } => HashAndFormat::hash_seq(*hash), - } + /// Create a new untagged download request. + /// + /// The blob will not be tagged, so only use this if the blob is already protected from garbage + /// collection through other means. + pub fn untagged( + resource: HashAndFormat, + nodes: impl IntoIterator>, + ) -> Self { + let mut r = Self::new(resource, nodes); + r.tag = None; + r } - /// Get the ranges this download is requesting. - // NOTE: necessary to extend downloads to support ranges of blobs ranges of collections. - #[allow(dead_code)] - fn ranges(&self) -> RangeSpecSeq { - match self { - DownloadKind::Blob { .. } => RangeSpecSeq::from_ranges([ChunkRanges::all()]), - DownloadKind::HashSeq { .. } => RangeSpecSeq::all(), - } + /// Set a tag to apply to the blob after download. + pub fn tag(mut self, tag: SetTagOption) -> Self { + self.tag = Some(tag); + self + } + + /// Pass a progress sender to receive progress updates. + pub fn progress_sender(mut self, sender: ProgressSubscriber) -> Self { + self.progress = Some(sender); + self } } -// For readability. In the future we might care about some data reporting on a successful download -// or kind of failure in the error case. -type DownloadResult = anyhow::Result<()>; +/// The kind of resource to download. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy, derive_more::From, derive_more::Into)] +pub struct DownloadKind(HashAndFormat); + +impl DownloadKind { + /// Get the hash of this download + pub const fn hash(&self) -> Hash { + self.0.hash + } + + /// Get the format of this download + pub const fn format(&self) -> BlobFormat { + self.0.format + } + + /// Get the [`HashAndFormat`] pair of this download + pub const fn hash_and_format(&self) -> HashAndFormat { + self.0 + } + + /// Switch from [`BlobFormat::Raw`] to [`BlobFormat::HashSeq`] and vice-versa. + fn with_format_switched(&self) -> Self { + let format = match self.format() { + BlobFormat::Raw => BlobFormat::HashSeq, + BlobFormat::HashSeq => BlobFormat::Raw, + }; + Self(HashAndFormat::new(self.hash(), format)) + } +} + +impl fmt::Display for DownloadKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}:{:?}", self.0.hash.fmt_short(), self.0.format) + } +} + +/// The result of a download request, as returned to the application code. +type ExternalDownloadResult = Result; + +/// The result of a download request, as used in this module. +type InternalDownloadResult = Result; + +/// Error returned when a download could not be completed. +#[derive(Debug, Clone, thiserror::Error)] +pub enum DownloadError { + /// Failed to download from any provider + #[error("Failed to complete download")] + DownloadFailed, + /// The download was cancelled by us + #[error("Download cancelled by us")] + Cancelled, + /// No provider nodes found + #[error("No provider nodes found")] + NoProviders, + /// Failed to receive response from service. + #[error("Failed to receive response from download service")] + ActorClosed, +} /// Handle to interact with a download request. #[derive(Debug)] pub struct DownloadHandle { /// Id used to identify the request in the [`Downloader`]. - id: Id, + id: IntentId, /// Kind of download. kind: DownloadKind, /// Receiver to retrieve the return value of this download. - receiver: oneshot::Receiver, + receiver: oneshot::Receiver, } impl std::future::Future for DownloadHandle { - type Output = DownloadResult; + type Output = ExternalDownloadResult; fn poll( mut self: std::pin::Pin<&mut Self>, @@ -207,7 +300,7 @@ impl std::future::Future for DownloadHandle { // from the middle match self.receiver.poll_unpin(cx) { Ready(Ok(result)) => Ready(result), - Ready(Err(recv_err)) => Ready(Err(anyhow::anyhow!("oneshot error: {recv_err}"))), + Ready(Err(_recv_err)) => Ready(Err(DownloadError::ActorClosed)), Pending => Pending, } } @@ -234,9 +327,11 @@ impl Downloader { let create_future = move || { let concurrency_limits = ConcurrencyLimits::default(); - let getter = get::IoGetter { store }; + let getter = get::IoGetter { + store: store.clone(), + }; - let service = Service::new(getter, dialer, concurrency_limits, msg_rx); + let service = Service::new(store, getter, dialer, concurrency_limits, msg_rx); service.run().instrument(error_span!("downloader", %me)) }; @@ -248,20 +343,19 @@ impl Downloader { } /// Queue a download. - pub async fn queue(&mut self, kind: DownloadKind, nodes: Vec) -> DownloadHandle { - let id = self.next_id.fetch_add(1, Ordering::SeqCst); - + pub async fn queue(&self, request: DownloadRequest) -> DownloadHandle { + let kind = request.kind; + let intent_id = IntentId(self.next_id.fetch_add(1, Ordering::SeqCst)); let (sender, receiver) = oneshot::channel(); let handle = DownloadHandle { - id, - kind: kind.clone(), + id: intent_id, + kind, receiver, }; let msg = Message::Queue { - kind, - id, - sender, - nodes, + on_finish: sender, + request, + intent_id, }; // if this fails polling the handle will fail as well since the sender side of the oneshot // will be dropped @@ -274,13 +368,13 @@ impl Downloader { /// Cancel a download. // NOTE: receiving the handle ensures an intent can't be cancelled twice - pub async fn cancel(&mut self, handle: DownloadHandle) { + pub async fn cancel(&self, handle: DownloadHandle) { let DownloadHandle { id, kind, receiver: _, } = handle; - let msg = Message::Cancel { id, kind }; + let msg = Message::CancelIntent { id, kind }; if let Err(send_err) = self.msg_tx.send(msg).await { let msg = send_err.0; debug!(?msg, "cancel not sent"); @@ -288,8 +382,11 @@ impl Downloader { } /// Declare that certains nodes can be used to download a hash. - pub async fn nodes_have(&mut self, hash: Hash, nodes: Vec) { - let msg = Message::PeersHave { hash, nodes }; + /// + /// Note that this does not start a download, but only provides new nodes to already queued + /// downloads. Use [`Self::queue`] to queue a download. + pub async fn nodes_have(&mut self, hash: Hash, nodes: Vec) { + let msg = Message::NodesHave { hash, nodes }; if let Err(send_err) = self.msg_tx.send(msg).await { let msg = send_err.0; debug!(?msg, "nodes have not been sent") @@ -297,136 +394,82 @@ impl Downloader { } } -/// A node and its role with regard to a hash. -#[derive(Debug, Clone, Copy)] -pub struct NodeInfo { - node_id: NodeId, - role: Role, -} - -impl NodeInfo { - /// Create a new [`NodeInfo`] from its parts. - pub fn new(node_id: NodeId, role: Role) -> Self { - Self { node_id, role } - } -} - -impl From<(NodeId, Role)> for NodeInfo { - fn from((node_id, role): (NodeId, Role)) -> Self { - Self { node_id, role } - } -} - -/// The role of a node with regard to a download intent. -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -pub enum Role { - /// We have information that this node has the requested blob. - Provider, - /// We do not have information if this node has the requested blob. - Candidate, -} - -impl PartialOrd for Role { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} -impl Ord for Role { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - match (self, other) { - (Role::Provider, Role::Provider) => std::cmp::Ordering::Equal, - (Role::Candidate, Role::Candidate) => std::cmp::Ordering::Equal, - (Role::Provider, Role::Candidate) => std::cmp::Ordering::Greater, - (Role::Candidate, Role::Provider) => std::cmp::Ordering::Less, - } - } -} - /// Messages the service can receive. #[derive(derive_more::Debug)] enum Message { /// Queue a download intent. Queue { - kind: DownloadKind, - id: Id, + request: DownloadRequest, #[debug(skip)] - sender: oneshot::Sender, - nodes: Vec, + on_finish: oneshot::Sender, + intent_id: IntentId, }, + /// Declare that nodes have a certain hash and can be used for downloading. + NodesHave { hash: Hash, nodes: Vec }, /// Cancel an intent. The associated request will be cancelled when the last intent is /// cancelled. - Cancel { id: Id, kind: DownloadKind }, - /// Declare that nodes have certains hash and can be used for downloading. This feeds the [`ProviderMap`]. - PeersHave { hash: Hash, nodes: Vec }, + CancelIntent { id: IntentId, kind: DownloadKind }, +} + +#[derive(derive_more::Debug)] +struct IntentHandlers { + #[debug("oneshot::Sender")] + on_finish: oneshot::Sender, + on_progress: Option, } -/// Information about a request being processed. +/// Information about a request. +#[derive(Debug, Default)] +struct RequestInfo { + /// Registered intents with progress senders and result callbacks. + intents: HashMap, + /// Tags requested for the blob to be created once the download finishes. + tags: TagSet, +} + +/// Information about a request in progress. #[derive(derive_more::Debug)] struct ActiveRequestInfo { - /// Ids of intents associated with this request. - #[debug("{:?}", intents.keys().collect::>())] - intents: HashMap>, - /// How many times can this request be retried. - remaining_retries: u8, /// Token used to cancel the future doing the request. #[debug(skip)] cancellation: CancellationToken, /// Peer doing this request attempt. node: NodeId, -} - -/// Information about a request that has not started. -#[derive(derive_more::Debug)] -struct PendingRequestInfo { - /// Ids of intents associated with this request. - #[debug("{:?}", intents.keys().collect::>())] - intents: HashMap>, - /// How many times can this request be retried. - remaining_retries: u8, - /// Key to manage the delay associated with this scheduled request. - #[debug(skip)] - delay_key: delay_queue::Key, - /// If this attempt was scheduled with a known potential node, this is stored here to - /// prevent another query to the [`ProviderMap`]. - next_node: Option, + /// Temporary tag to protect the partial blob from being garbage collected. + temp_tag: TempTag, } /// State of the connection to this node. #[derive(derive_more::Debug)] struct ConnectionInfo { /// Connection to this node. - /// - /// If this node was deemed unusable by a request, this will be set to `None`. As a - /// consequence, when evaluating nodes for a download, this node will not be considered. - /// Since nodes are kept for a longer time that they are strictly necessary, this acts as a - /// temporary ban. #[debug(skip)] - conn: Option, + conn: Conn, /// State of this node. - state: PeerState, + state: ConnectedState, } impl ConnectionInfo { /// Create a new idle node. fn new_idle(connection: Conn, drop_key: delay_queue::Key) -> Self { ConnectionInfo { - conn: Some(connection), - state: PeerState::Idle { drop_key }, + conn: connection, + state: ConnectedState::Idle { drop_key }, } } /// Count of active requests for the node. fn active_requests(&self) -> usize { match self.state { - PeerState::Busy { active_requests } => active_requests.get(), - PeerState::Idle { .. } => 0, + ConnectedState::Busy { active_requests } => active_requests.get(), + ConnectedState::Idle { .. } => 0, } } } /// State of a connected node. #[derive(derive_more::Debug)] -enum PeerState { +enum ConnectedState { /// Peer is handling at least one request. Busy { #[debug("{}", active_requests.get())] @@ -439,11 +482,15 @@ enum PeerState { }, } -/// Type that is returned from a download request. -type DownloadRes = (DownloadKind, Result<(), FailureAction>); +#[derive(Debug)] +enum NodeState<'a, Conn> { + Connected(&'a ConnectionInfo), + Dialing, + Disconnected, +} #[derive(Debug)] -struct Service { +struct Service { /// The getter performs individual requests. getter: G, /// Map to query for nodes that we believe have the data we are looking for. @@ -454,23 +501,26 @@ struct Service { concurrency_limits: ConcurrencyLimits, /// Channel to receive messages from the service's handle. msg_rx: mpsc::Receiver, - /// Peers available to use and their relevant information. + /// Active connections nodes: HashMap>, /// Queue to manage dropping nodes. goodbye_nodes_queue: delay_queue::DelayQueue, - /// Requests performed for download intents. Two download requests can produce the same - /// request. This map allows deduplication of efforts. - current_requests: HashMap, - /// Downloads underway. - in_progress_downloads: JoinSet, - /// Requests scheduled to be downloaded at a later time. - scheduled_requests: HashMap, - /// Queue of scheduled requests. - scheduled_request_queue: delay_queue::DelayQueue, + /// Queue of pending downloads. + queue: LinkedHashSet, + /// Information about pending and active requests + requests: HashMap, + /// State of running downloads + active_requests: HashMap, + /// Tasks for currently running transfers. + in_progress_downloads: JoinSet<(DownloadKind, InternalDownloadResult)>, + /// Progress tracker + progress_tracker: ProgressTracker, + /// The [`Store`] where tags are saved after a download completes. + db: DB, } - -impl, D: Dialer> Service { +impl, D: Dialer> Service { fn new( + db: DB, getter: G, dialer: D, concurrency_limits: ConcurrencyLimits, @@ -478,78 +528,80 @@ impl, D: Dialer> Service { ) -> Self { Service { getter, - providers: ProviderMap::default(), dialer, - concurrency_limits, msg_rx, - nodes: HashMap::default(), + concurrency_limits, + nodes: Default::default(), + providers: Default::default(), + requests: Default::default(), goodbye_nodes_queue: delay_queue::DelayQueue::default(), - current_requests: HashMap::default(), + active_requests: Default::default(), in_progress_downloads: Default::default(), - scheduled_requests: HashMap::default(), - scheduled_request_queue: delay_queue::DelayQueue::default(), + progress_tracker: ProgressTracker::new(), + queue: Default::default(), + db, } } /// Main loop for the service. async fn run(mut self) { loop { - // check if we have capacity to dequeue another scheduled request - let at_capacity = self - .concurrency_limits - .at_requests_capacity(self.in_progress_downloads.len()); - + trace!("wait for tick"); tokio::select! { Some((node, conn_result)) = self.dialer.next() => { - trace!("tick: connection ready"); + trace!(node=%node.fmt_short(), "tick: connection ready"); self.on_connection_ready(node, conn_result); } maybe_msg = self.msg_rx.recv() => { trace!(msg=?maybe_msg, "tick: message received"); match maybe_msg { - Some(msg) => self.handle_message(msg), + Some(msg) => self.handle_message(msg).await, None => return self.shutdown().await, } } - Some(res) = self.in_progress_downloads.join_next() => { + Some(res) = self.in_progress_downloads.join_next(), if !self.in_progress_downloads.is_empty() => { match res { Ok((kind, result)) => { - trace!("tick: download completed"); - self.on_download_completed(kind, result); + trace!(%kind, "tick: transfer completed"); + self.on_download_completed(kind, result).await; } - Err(e) => { - warn!("download issue: {:?}", e); + Err(err) => { + warn!(?err, "transfer task panicked"); } } } - Some(expired) = self.scheduled_request_queue.next(), if !at_capacity => { - trace!("tick: scheduled request ready"); - let kind = expired.into_inner(); - let request_info = self.scheduled_requests.remove(&kind).expect("is registered"); - self.on_scheduled_request_ready(kind, request_info); - } Some(expired) = self.goodbye_nodes_queue.next() => { let node = expired.into_inner(); self.nodes.remove(&node); - trace!(%node, "tick: goodbye node"); + trace!(node=%node.fmt_short(), "tick: goodbye node"); } } + + self.process_head(); + #[cfg(any(test, debug_assertions))] self.check_invariants(); } } /// Handle receiving a [`Message`]. - fn handle_message(&mut self, msg: Message) { + /// + // This is called in the actor loop, and only async because subscribing to an existing transfer + // sends the initial state. + async fn handle_message(&mut self, msg: Message) { match msg { Message::Queue { - kind, - id, - sender, - nodes, - } => self.handle_queue_new_download(kind, id, sender, nodes), - Message::Cancel { id, kind } => self.handle_cancel_download(id, kind), - Message::PeersHave { hash, nodes } => self.handle_nodes_have(hash, nodes), + request, + on_finish, + intent_id, + } => { + self.handle_queue_new_download(request, intent_id, on_finish) + .await + } + Message::CancelIntent { id, kind } => self.handle_cancel_download(id, kind).await, + Message::NodesHave { hash, nodes } => self + .providers + .add_nodes_if_hash_exists(hash, nodes.iter().cloned()), } } @@ -557,226 +609,104 @@ impl, D: Dialer> Service { /// /// If this intent maps to a request that already exists, it will be registered with it. If the /// request is new it will be scheduled. - fn handle_queue_new_download( + async fn handle_queue_new_download( &mut self, - kind: DownloadKind, - id: Id, - sender: oneshot::Sender, - nodes: Vec, + request: DownloadRequest, + intent_id: IntentId, + on_finish: oneshot::Sender, ) { - self.providers.add_nodes(*kind.hash(), &nodes); - if let Some(info) = self.current_requests.get_mut(&kind) { - // this intent maps to a download that already exists, simply register it - info.intents.insert(id, sender); - // increasing the retries by one accounts for multiple intents for the same request in - // a conservative way - info.remaining_retries += 1; - return trace!(?kind, ?info, "intent registered with active request"); + let DownloadRequest { + kind, + nodes, + tag, + progress, + } = request; + debug!(%kind, nodes=?nodes.iter().map(|n| n.node_id.fmt_short()).collect::>(), "queue intent"); + + // store the download intent + let intent_handlers = IntentHandlers { + on_finish, + on_progress: progress, + }; + + // early exit if no providers. + if nodes.is_empty() && self.providers.get_candidates(&kind.hash()).next().is_none() { + self.finalize_download( + kind, + [(intent_id, intent_handlers)].into(), + Err(DownloadError::NoProviders), + ); + return; } - let needs_node = self - .scheduled_requests - .get(&kind) - .map(|info| info.next_node.is_none()) - .unwrap_or(true); - - let next_node = needs_node - .then(|| self.get_best_candidate(kind.hash())) - .flatten(); - - // if we are here this request is not active, check if it needs to be scheduled - match self.scheduled_requests.get_mut(&kind) { - Some(info) => { - info.intents.insert(id, sender); - // pre-emptively get a node if we don't already have one - match (info.next_node, next_node) { - // We did not yet have next node, but have a node now. - (None, Some(next_node)) => { - info.next_node = Some(next_node); - } - (Some(_old_next_node), Some(_next_node)) => { - unreachable!("invariant: info.next_node must be none because checked above with needs_node") - } - _ => {} + // add the nodes to the provider map + self.providers + .add_hash_with_nodes(kind.hash(), nodes.iter().map(|n| n.node_id)); + + // queue the transfer (if not running) or attach to transfer progress (if already running) + if self.active_requests.contains_key(&kind) { + // the transfer is already running, so attach the progress sender + if let Some(on_progress) = &intent_handlers.on_progress { + // this is async because it sends the current state over the progress channel + if let Err(err) = self + .progress_tracker + .subscribe(kind, on_progress.clone()) + .await + { + debug!(?err, %kind, "failed to subscribe progress sender to transfer"); } - - // increasing the retries by one accounts for multiple intents for the same request in - // a conservative way - info.remaining_retries += 1; - trace!(?kind, ?info, "intent registered with scheduled request"); - } - None => { - let intents = HashMap::from([(id, sender)]); - self.schedule_request(kind, INITIAL_RETRY_COUNT, next_node, intents) } + } else { + // the transfer is not yet running, so add to queue. + // this is a noop if the transfer is already queued. + self.queue.insert(kind); + } + + // store the request info + let request_info = self.requests.entry(kind).or_default(); + request_info.intents.insert(intent_id, intent_handlers); + if let Some(tag) = &tag { + request_info.tags.insert(tag.clone()); } } - /// Gets the best candidate for a download. + /// Cancels a download intent. /// - /// Peers are selected prioritizing those with an open connection and with capacity for another - /// request, followed by nodes we are currently dialing with capacity for another request. - /// Lastly, nodes not connected and not dialing are considered. + /// This removes the intent from the list of intents for the `kind`. If the removed intent was + /// the last one for the `kind`, this means that the download is no longer needed. In this + /// case, the `kind` will be removed from the list of pending downloads - and, if the download was + /// already started, the download task will be cancelled. /// - /// If the selected candidate is not connected and we have capacity for another connection, a - /// dial is queued. - fn get_best_candidate(&mut self, hash: &Hash) -> Option { - /// Model the state of nodes found in the candidates - #[derive(PartialEq, Eq, Clone, Copy)] - enum ConnState { - Dialing, - Connected(usize), - NotConnected, - } - - impl Ord for ConnState { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - // define the order of preference between candidates as follows: - // - prefer connected nodes to dialing ones - // - prefer dialing nodes to not connected ones - // - prefer nodes with less active requests when connected - use std::cmp::Ordering::*; - match (self, other) { - (ConnState::Dialing, ConnState::Dialing) => Equal, - (ConnState::Dialing, ConnState::Connected(_)) => Less, - (ConnState::Dialing, ConnState::NotConnected) => Greater, - (ConnState::NotConnected, ConnState::Dialing) => Less, - (ConnState::NotConnected, ConnState::Connected(_)) => Less, - (ConnState::NotConnected, ConnState::NotConnected) => Equal, - (ConnState::Connected(_), ConnState::Dialing) => Greater, - (ConnState::Connected(_), ConnState::NotConnected) => Greater, - (ConnState::Connected(a), ConnState::Connected(b)) => match a.cmp(b) { - Less => Greater, // less preferable if greater number of requests - Equal => Equal, // no preference - Greater => Less, // more preferable if less number of requests - }, - } - } - } - - impl PartialOrd for ConnState { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } - } + /// The method is async because it will send a final abort event on the progress sender. + async fn handle_cancel_download(&mut self, intent_id: IntentId, kind: DownloadKind) { + let Entry::Occupied(mut occupied_entry) = self.requests.entry(kind) else { + warn!(%kind, %intent_id, "cancel download called for unknown download"); + return; + }; - // first collect suitable candidates - let mut candidates = self - .providers - .get_candidates(hash) - .filter_map(|(node_id, role)| { - let node = NodeInfo::new(*node_id, *role); - if let Some(info) = self.nodes.get(node_id) { - info.conn.as_ref()?; - let req_count = info.active_requests(); - // filter out nodes at capacity - let has_capacity = !self.concurrency_limits.node_at_request_capacity(req_count); - has_capacity.then_some((node, ConnState::Connected(req_count))) - } else if self.dialer.is_pending(node_id) { - Some((node, ConnState::Dialing)) - } else { - Some((node, ConnState::NotConnected)) - } - }) - .collect::>(); - - // Sort candidates by: - // * Role (Providers > Candidates) - // * ConnState (Connected > Dialing > NotConnected) - candidates.sort_unstable_by_key(|(NodeInfo { role, .. }, state)| (*role, *state)); - - // this is our best node, check if we need to dial it - let (node, state) = candidates.pop()?; - - if let ConnState::NotConnected = state { - if !self.at_connections_capacity() { - // node is not connected, not dialing and concurrency limits allow another connection - debug!(node = %node.node_id, "dialing node"); - self.dialer.queue_dial(node.node_id); - Some(node.node_id) - } else { - trace!(node = %node.node_id, "required node not dialed to maintain concurrency limits"); - None + let request_info = occupied_entry.get_mut(); + if let Some(handlers) = request_info.intents.remove(&intent_id) { + handlers.on_finish.send(Err(DownloadError::Cancelled)).ok(); + + if let Some(sender) = handlers.on_progress { + self.progress_tracker.unsubscribe(&kind, &sender); + sender + .send(DownloadProgress::Abort( + anyhow::Error::from(DownloadError::Cancelled).into(), + )) + .await + .ok(); } - } else { - Some(node.node_id) } - } - /// Cancels the download request. - /// - /// This removes the registered download intent and, depending on its state, it will either - /// remove it from the scheduled requests, or cancel the future. - fn handle_cancel_download(&mut self, id: Id, kind: DownloadKind) { - let hash = *kind.hash(); - let mut download_removed = false; - if let Entry::Occupied(mut occupied_entry) = self.current_requests.entry(kind.clone()) { - // remove the intent from the associated request - let intents = &mut occupied_entry.get_mut().intents; - intents.remove(&id); - // if this was the last intent associated with the request cancel it - if intents.is_empty() { - download_removed = true; + if request_info.intents.is_empty() { + occupied_entry.remove(); + if let Entry::Occupied(occupied_entry) = self.active_requests.entry(kind) { occupied_entry.remove().cancellation.cancel(); + } else { + self.queue.remove(&kind); } - } else if let Entry::Occupied(mut occupied_entry) = self.scheduled_requests.entry(kind) { - // remove the intent from the associated request - let intents = &mut occupied_entry.get_mut().intents; - intents.remove(&id); - // if this was the last intent associated with the request remove it from the schedule - // queue - if intents.is_empty() { - let delay_key = occupied_entry.remove().delay_key; - self.scheduled_request_queue.remove(&delay_key); - download_removed = true; - } - } - - if download_removed && !self.is_needed(hash) { - self.providers.remove(hash) - } - } - - /// Handle a [`Message::PeersHave`]. - fn handle_nodes_have(&mut self, hash: Hash, nodes: Vec) { - // check if this still needed - if self.is_needed(hash) { - self.providers.add_nodes(hash, &nodes); - } - } - - /// Checks if this hash is needed. - fn is_needed(&self, hash: Hash) -> bool { - let as_blob = DownloadKind::Blob { hash }; - let as_hash_seq = DownloadKind::HashSeq { hash }; - self.current_requests.contains_key(&as_blob) - || self.scheduled_requests.contains_key(&as_blob) - || self.current_requests.contains_key(&as_hash_seq) - || self.scheduled_requests.contains_key(&as_hash_seq) - } - - /// Check if this hash is currently being downloaded. - fn is_current_request(&self, hash: Hash) -> bool { - let as_blob = DownloadKind::Blob { hash }; - let as_hash_seq = DownloadKind::HashSeq { hash }; - self.current_requests.contains_key(&as_blob) - || self.current_requests.contains_key(&as_hash_seq) - } - - /// Remove a hash from the scheduled queue. - fn unschedule(&mut self, hash: Hash) -> Option<(DownloadKind, PendingRequestInfo)> { - let as_blob = DownloadKind::Blob { hash }; - let as_hash_seq = DownloadKind::HashSeq { hash }; - let info = match self.scheduled_requests.remove(&as_blob) { - Some(req) => Some(req), - None => self.scheduled_requests.remove(&as_hash_seq), - }; - if let Some(info) = info { - let kind = self.scheduled_request_queue.remove(&info.delay_key); - let kind = kind.into_inner(); - Some((kind, info)) - } else { - None + self.remove_kind_from_provider_map(&kind); } } @@ -784,11 +714,10 @@ impl, D: Dialer> Service { fn on_connection_ready(&mut self, node: NodeId, result: anyhow::Result) { match result { Ok(connection) => { - trace!(%node, "connected to node"); + trace!(node=%node.fmt_short(), "connected to node"); let drop_key = self.goodbye_nodes_queue.insert(node, IDLE_PEER_TIMEOUT); self.nodes .insert(node, ConnectionInfo::new_idle(connection, drop_key)); - self.on_node_ready(node); } Err(err) => { debug!(%node, %err, "connection to node failed") @@ -796,198 +725,291 @@ impl, D: Dialer> Service { } } - /// Called after the connection to a node is established, and after finishing a download. - /// - /// Starts the next provider hash download, if there is one. - fn on_node_ready(&mut self, node: NodeId) { - // Get the next provider hash for this node. - let Some(hash) = self.providers.get_next_provider_hash_for_node(&node) else { - return; - }; - - if self.is_current_request(hash) { - return; - } - - let Some(conn) = self.get_node_connection_for_download(&node) else { - return; - }; - - let Some((kind, info)) = self.unschedule(hash) else { - debug_assert!( - false, - "invalid state: expected {hash:?} to be scheduled, but it wasn't" - ); - return; - }; - - let PendingRequestInfo { - intents, - remaining_retries, - .. - } = info; - - self.start_download(kind, node, conn, remaining_retries, intents); - } - - fn on_download_completed(&mut self, kind: DownloadKind, result: Result<(), FailureAction>) { + async fn on_download_completed(&mut self, kind: DownloadKind, result: InternalDownloadResult) { // first remove the request - let info = self - .current_requests + let active_request_info = self + .active_requests .remove(&kind) .expect("request was active"); - // update the active requests for this node - let ActiveRequestInfo { - intents, - node, - mut remaining_retries, - .. - } = info; + // get general request info + let request_info = self.requests.remove(&kind).expect("request was active"); + let ActiveRequestInfo { node, temp_tag, .. } = active_request_info; + + // get node info let node_info = self .nodes .get_mut(&node) .expect("node exists in the mapping"); - node_info.state = match &node_info.state { - PeerState::Busy { active_requests } => { - match NonZeroUsize::new(active_requests.get() - 1) { - Some(active_requests) => PeerState::Busy { active_requests }, - None => { - // last request of the node was this one - let drop_key = self.goodbye_nodes_queue.insert(node, IDLE_PEER_TIMEOUT); - PeerState::Idle { drop_key } - } - } - } - PeerState::Idle { .. } => unreachable!("node was busy"), - }; - let hash = *kind.hash(); - - let node_ready = match result { + let (keep_node, _retry_node) = match &result { Ok(_) => { - debug!(%node, ?kind, "download completed"); - for sender in intents.into_values() { - let _ = sender.send(Ok(())); - } - true + debug!(%kind, node=%node.fmt_short(), "transfer finished"); + (true, false) + } + Err(FailureAction::Cancelled) => { + debug!(%kind, node=%node.fmt_short(), "download cancelled"); + (true, false) } Err(FailureAction::AbortRequest(reason)) => { - debug!(%node, ?kind, %reason, "aborting request"); - for sender in intents.into_values() { - let _ = sender.send(Err(anyhow::anyhow!("request aborted"))); - } - true + debug!(%kind, node=%node.fmt_short(), %reason, "aborting request"); + (true, false) } Err(FailureAction::DropPeer(reason)) => { - debug!(%node, ?kind, %reason, "node will be dropped"); - if let Some(_connection) = node_info.conn.take() { - // TODO(@divma): this will fail open streams, do we want this? - // connection.close(..) - } - false + debug!(%kind, node=%node.fmt_short(), %reason, "node will be dropped"); + (false, false) } Err(FailureAction::RetryLater(reason)) => { - // check if the download can be retried - if remaining_retries > 0 { - debug!(%node, ?kind, %reason, "download attempt failed"); - remaining_retries -= 1; - let next_node = self.get_best_candidate(kind.hash()); - self.schedule_request(kind, remaining_retries, next_node, intents); - } else { - warn!(%node, ?kind, %reason, "download failed"); - for sender in intents.into_values() { - let _ = sender.send(Err(anyhow::anyhow!("download ran out of attempts"))); - } - } - false + debug!(%kind, node=%node.fmt_short(), %reason, "download failed but retry later"); + // TODO: How do we want to actually do retries? + // Right now they are skipped (same as abort request) + (true, true) } }; - if !self.is_needed(hash) { - self.providers.remove(hash) + if keep_node { + // TODO: Handle retries somehow. + // if retry_node { ..} + self.providers.remove_hash_from_node(&kind.hash(), &node); + // update node busy/idle state + node_info.state = match &node_info.state { + ConnectedState::Busy { active_requests } => { + match NonZeroUsize::new(active_requests.get() - 1) { + Some(active_requests) => ConnectedState::Busy { active_requests }, + None => { + // last request of the node was this one, switch to idle + let drop_key = self.goodbye_nodes_queue.insert(node, IDLE_PEER_TIMEOUT); + ConnectedState::Idle { drop_key } + } + } + } + ConnectedState::Idle { .. } => unreachable!("node was busy"), + }; + } else { + // this drops the connection, thus disconnects + self.nodes.remove(&node); + self.providers.remove_node(&node); } - if node_ready { - self.on_node_ready(node); + + let finalize = result.is_ok() || !self.providers.has_candidates(&kind.hash()); + + if finalize { + let result = result.map_err(|_| DownloadError::DownloadFailed); + if result.is_ok() { + request_info.tags.apply(&self.db, kind.0).await.ok(); + } + drop(temp_tag); + self.finalize_download(kind, request_info.intents, result); + } else { + // reinsert the download at the front of the queue to try from the next node + self.requests.insert(kind, request_info); + self.queue.insert(kind); + self.queue.to_front(&kind); } } - /// A scheduled request is ready to be processed. + /// Finalize a download. /// - /// The node that was initially selected is used if possible. Otherwise we try to get a new - /// node - fn on_scheduled_request_ready(&mut self, kind: DownloadKind, info: PendingRequestInfo) { - let PendingRequestInfo { - intents, - mut remaining_retries, - next_node, - .. - } = info; - - // first try with the node that was initially assigned - if let Some((node_id, conn)) = next_node.and_then(|node_id| { - self.get_node_connection_for_download(&node_id) - .map(|conn| (node_id, conn)) - }) { - return self.start_download(kind, node_id, conn, remaining_retries, intents); + /// This triggers the intent return channels, and removes the download from the progress tracker + /// and provider map. + fn finalize_download( + &mut self, + kind: DownloadKind, + intents: HashMap, + result: ExternalDownloadResult, + ) { + self.progress_tracker.remove(&kind); + self.remove_kind_from_provider_map(&kind); + let result = result.map_err(|_| DownloadError::DownloadFailed); + for (_id, handlers) in intents.into_iter() { + handlers.on_finish.send(result.clone()).ok(); } + } - // we either didn't have a node or the node is busy or dialing. In any case try to get - // another node - let next_node = match self.get_best_candidate(kind.hash()) { - None => None, - Some(node_id) => { - // optimistically check if the node could do the request right away - match self.get_node_connection_for_download(&node_id) { - Some(conn) => { - return self.start_download(kind, node_id, conn, remaining_retries, intents) - } - None => Some(node_id), + /// Start the next downloads, or dial nodes, if limits permit and the queue is non-empty. + /// + /// This is called after all actions. If there is nothing to do, it will return cheaply. + /// Otherwise, we will check the next hash in the queue, and: + /// * start the transfer if we are connected to a provider and limits are ok + /// * or, connect to a provider, if there is one we are not dialing yet and limits are ok + /// * or, disconnect an idle node if it would allow us to connect to a provider, + /// * or, if our limits are reached, do nothing for now + /// + /// The download requests will only be popped from the queue once we either start the transfer + /// from a connected node [`NextStep::StartTransfer`], or if we abort the download on + /// [`NextStep::OutOfProviders`]. In all other cases, the request is kept at the top of the + /// queue, so the next call to [`Self::process_head`] will evaluate the situation again - and + /// so forth, until either [`NextStep::StartTransfer`] or [`NextStep::OutOfProviders`] is + /// reached. + fn process_head(&mut self) { + // start as many queued downloads as allowed by the request limits. + loop { + // if queue empty: break. + let Some(kind) = self.queue.front().cloned() else { + break; + }; + + let next_step = self.next_step(&kind); + trace!(%kind, ?next_step, "check queue head"); + + match next_step { + // We are waiting either for dialing to finish, or for a full node to finish a + // transfer, so nothing to do for us at the moment. + NextStep::Wait => break, + NextStep::StartTransfer(node) => { + let _ = self.queue.pop_front(); + debug!(%kind, node=%node.fmt_short(), "start transfer"); + self.start_download(kind, node); + } + NextStep::Dial(node) => { + debug!(%kind, node=%node.fmt_short(), "dial node"); + self.dialer.queue_dial(node); + } + NextStep::DialQueuedDisconnect(node, key) => { + let expired = self.goodbye_nodes_queue.remove(&key); + let expired_node = expired.into_inner(); + debug!(node=%expired_node.fmt_short(), "disconnect idle node to make room for next connection"); + let info = self.nodes.remove(&expired_node); + debug_assert!( + matches!( + info, + Some(ConnectionInfo { + state: ConnectedState::Idle { .. }, + .. + }) + ), + "node picked from goodbye queue to be idle" + ); + debug!(%kind, node=%node.fmt_short(), "dial node"); + self.dialer.queue_dial(node); + } + NextStep::OutOfProviders => { + debug!(%kind, "abort download: out of providers"); + let _ = self.queue.pop_front(); + let info = self.requests.remove(&kind).expect("queued downloads exist"); + self.finalize_download(kind, info.intents, Err(DownloadError::NoProviders)); } } + } + } + + /// Calculate the next step needed to proceed the download for `kind`. + /// + /// This is called once `kind` has reached the head of the queue, see [`Self::process_head`]. + /// It can be called repeatedly, and does nothing on itself, only calculate what *should* be + /// done next. + /// + /// See [`NextStep`] for details on the potential next steps returned from this method. + fn next_step(&self, kind: &DownloadKind) -> NextStep { + // If the total requests capacity is reached, we have to wait until an active request + // completes. + if self + .concurrency_limits + .at_requests_capacity(self.active_requests.len()) + { + return NextStep::Wait; }; - // we tried to get a node to perform this request but didn't get one, so now this attempt - // is failed - if remaining_retries > 0 { - remaining_retries -= 1; - self.schedule_request(kind, remaining_retries, next_node, intents); - } else { - // check if this hash is needed in some form, otherwise remove it from providers - let hash = *kind.hash(); - if !self.is_needed(hash) { - self.providers.remove(hash) + let mut candidates = self.providers.get_candidates(&kind.hash()).peekable(); + // If we have no provider candidates for this download, there's nothing else we can do. + if candidates.peek().is_none() { + return NextStep::OutOfProviders; + } + + // Track if there is provider node to which we are connected and which is not at its request capacity. + // If there are more than one, take the one with the least amount of running transfers. + let mut best_connected: Option<(NodeId, usize)> = None; + // Track if there is a disconnected provider node to which we can potentially connect. + let mut next_to_dial = None; + // Track the number of provider nodes that are currently being dialed. + let mut currently_dialing = 0; + // Track if we have at least one provider node which is currently at its request capacity. + // If this is the case, we will never return [`NextStep::OutOfProviders`] but [`NextStep::Wait`] + // instead, because we can still try that node once it has finished its work. + let mut has_exhausted_provider = false; + + for node in candidates { + match self.node_state(node) { + NodeState::Connected(info) => { + let active_requests = info.active_requests(); + if self + .concurrency_limits + .node_at_request_capacity(active_requests) + { + has_exhausted_provider = true; + } else { + best_connected = Some(match best_connected.take() { + Some(old) if old.1 <= active_requests => old, + _ => (*node, active_requests), + }); + } + } + NodeState::Dialing => { + currently_dialing += 1; + } + NodeState::Disconnected => { + if next_to_dial.is_none() { + next_to_dial = Some(node); + } + } } - // request can't be retried - for sender in intents.into_values() { - let _ = sender.send(Err(anyhow::anyhow!("download ran out of attempts"))); + } + + let has_dialing = currently_dialing > 0; + + if let Some((node, _active_requests)) = best_connected { + // If we have a connected provider node with free slots, use it! + NextStep::StartTransfer(node) + } else if let Some(node) = next_to_dial { + let at_dial_capacity = has_dialing + && self + .concurrency_limits + .at_dials_per_hash_capacity(currently_dialing); + let at_connections_capacity = self.at_connections_capacity(); + + if !at_connections_capacity && !at_dial_capacity { + NextStep::Dial(*node) + } else if at_connections_capacity + && !at_dial_capacity + && !self.goodbye_nodes_queue.is_empty() + { + let key = self.goodbye_nodes_queue.peek().expect("just checked"); + NextStep::DialQueuedDisconnect(*node, key) + } else { + NextStep::Wait } - debug!(?kind, "download ran out of attempts") + } else if has_exhausted_provider || has_dialing { + NextStep::Wait + } else { + NextStep::OutOfProviders } } /// Start downloading from the given node. - fn start_download( - &mut self, - kind: DownloadKind, - node: NodeId, - conn: D::Connection, - remaining_retries: u8, - intents: HashMap>, - ) { - debug!(%node, ?kind, "starting download"); + /// + /// Panics if hash is not in self.requests or node is not in self.nodes. + fn start_download(&mut self, kind: DownloadKind, node: NodeId) { + let node_info = self.nodes.get_mut(&node).expect("node exists"); + let request_info = self.requests.get(&kind).expect("hash exists"); + + // create a progress sender and subscribe all intents to the progress sender + let subscribers = request_info + .intents + .values() + .flat_map(|state| state.on_progress.clone()); + let progress_sender = self.progress_tracker.track(kind, subscribers); + + // create the active request state let cancellation = CancellationToken::new(); - let info = ActiveRequestInfo { - intents, - remaining_retries, - cancellation, + let temp_tag = self.db.temp_tag(kind.0); + let state = ActiveRequestInfo { + cancellation: cancellation.clone(), node, + temp_tag, }; - let cancellation = info.cancellation.clone(); - self.current_requests.insert(kind.clone(), info); - - let get = self.getter.get(kind.clone(), conn); + let conn = node_info.conn.clone(); + let get_fut = self.getter.get(kind, conn, progress_sender); let fut = async move { // NOTE: it's an open question if we should do timeouts at this point. Considerations from @Frando: // > at this stage we do not know the size of the download, so the timeout would have @@ -996,67 +1018,36 @@ impl, D: Dialer> Service { // > time, while faster nodes could be readily available. // As a conclusion, timeouts should be added only after downloads are known to be bounded let res = tokio::select! { - _ = cancellation.cancelled() => Err(FailureAction::AbortRequest(anyhow::anyhow!("cancelled"))), - res = get => res + _ = cancellation.cancelled() => Err(FailureAction::Cancelled), + res = get_fut => res }; + trace!("transfer finished"); - (kind, res.map(|_stats| ())) + (kind, res) + } + .instrument(error_span!("transfer", %kind, node=%node.fmt_short())); + node_info.state = match &node_info.state { + ConnectedState::Busy { active_requests } => ConnectedState::Busy { + active_requests: active_requests.saturating_add(1), + }, + ConnectedState::Idle { drop_key } => { + self.goodbye_nodes_queue.remove(drop_key); + ConnectedState::Busy { + active_requests: NonZeroUsize::new(1).expect("clearly non zero"), + } + } }; - + self.active_requests.insert(kind, state); self.in_progress_downloads.spawn_local(fut); } - /// Schedule a request for later processing. - fn schedule_request( - &mut self, - kind: DownloadKind, - remaining_retries: u8, - next_node: Option, - intents: HashMap>, - ) { - // this is simply INITIAL_REQUEST_DELAY * attempt_num where attempt_num (as an ordinal - // number) is maxed at INITIAL_RETRY_COUNT - let delay = INITIAL_REQUEST_DELAY - * (INITIAL_RETRY_COUNT.saturating_sub(remaining_retries) as u32 + 1); - - let delay_key = self.scheduled_request_queue.insert(kind.clone(), delay); - - let info = PendingRequestInfo { - intents, - remaining_retries, - delay_key, - next_node, - }; - debug!(?kind, ?info, "request scheduled"); - self.scheduled_requests.insert(kind, info); - } - - /// Gets the [`Dialer::Connection`] for a node if it's connected and has capacity for another - /// request. In this case, the count of active requests for the node is incremented. - fn get_node_connection_for_download(&mut self, node: &NodeId) -> Option { - let info = self.nodes.get_mut(node)?; - let connection = info.conn.as_ref()?; - // check if the node can be sent another request - match &mut info.state { - PeerState::Busy { active_requests } => { - if !self - .concurrency_limits - .node_at_request_capacity(active_requests.get()) - { - *active_requests = active_requests.saturating_add(1); - Some(connection.clone()) - } else { - None - } - } - PeerState::Idle { drop_key } => { - // node is no longer idle - self.goodbye_nodes_queue.remove(drop_key); - info.state = PeerState::Busy { - active_requests: NonZeroUsize::new(1).expect("clearly non zero"), - }; - Some(connection.clone()) - } + fn node_state<'a>(&'a self, node: &NodeId) -> NodeState<'a, D::Connection> { + if let Some(info) = self.nodes.get(node) { + NodeState::Connected(info) + } else if self.dialer.is_pending(node) { + NodeState::Dialing + } else { + NodeState::Disconnected } } @@ -1068,15 +1059,19 @@ impl, D: Dialer> Service { /// Get the total number of connected and dialing nodes. fn connections_count(&self) -> usize { - let connected_nodes = self - .nodes - .values() - .filter(|info| info.conn.is_some()) - .count(); + let connected_nodes = self.nodes.values().count(); let dialing_nodes = self.dialer.pending_count(); connected_nodes + dialing_nodes } + /// Remove a `kind` from the [`ProviderMap`], but only if [`Self::queue`] does not contain the + /// hash at all, even with the other [`BlobFormat`]. + fn remove_kind_from_provider_map(&mut self, kind: &DownloadKind) { + if !self.queue.contains(kind) && !self.queue.contains(&kind.with_format_switched()) { + self.providers.remove_hash(&kind.hash()); + } + } + #[allow(clippy::unused_async)] async fn shutdown(self) { debug!("shutting down"); @@ -1084,86 +1079,111 @@ impl, D: Dialer> Service { } } -/// Map of potential providers for a hash. -#[derive(Default, Debug)] -pub struct ProviderMap { - /// Candidates to download a hash. - candidates: HashMap>, - /// Ordered list of provider hashes per node. +/// The next step needed to continue a download. +/// +/// See [`Service::next_step`] for details. +#[derive(Debug)] +enum NextStep { + /// Provider connection is ready, initiate the transfer. + StartTransfer(NodeId), + /// Start to dial `NodeId`. /// - /// I.e. blobs we assume the node can provide. - provider_hashes_by_node: HashMap>, + /// This means: We have no non-exhausted connection to a provider node, but a free connection slot + /// and a provider node we are not yet connected to. + Dial(NodeId), + /// Start to dial `NodeId`, but first disconnect the idle node behind [`delay_queue::Key`] in + /// [`Service::goodbye_nodes_queue`] to free up a connection slot. + DialQueuedDisconnect(NodeId, delay_queue::Key), + /// All resource limits are exhausted, do nothing for now and wait until a slot frees up. + Wait, + /// We have tried all available providers. There is nothing else to do. + OutOfProviders, } -struct ProviderIter<'a> { - inner: Option>, -} - -impl<'a> Iterator for ProviderIter<'a> { - type Item = (&'a NodeId, &'a Role); - - fn next(&mut self) -> Option { - self.inner.as_mut().and_then(|iter| iter.next()) - } +/// Map of potential providers for a hash. +#[derive(Default, Debug)] +struct ProviderMap { + hash_node: HashMap>, + node_hash: HashMap>, } impl ProviderMap { /// Get candidates to download this hash. - fn get_candidates(&self, hash: &Hash) -> impl Iterator { - let inner = self.candidates.get(hash).map(|nodes| nodes.iter()); - ProviderIter { inner } + pub fn get_candidates(&self, hash: &Hash) -> impl Iterator { + self.hash_node + .get(hash) + .map(|nodes| nodes.iter()) + .into_iter() + .flatten() + } + + /// Whether we have any candidates to download this hash. + pub fn has_candidates(&self, hash: &Hash) -> bool { + self.hash_node + .get(hash) + .map(|nodes| !nodes.is_empty()) + .unwrap_or(false) } /// Register nodes for a hash. Should only be done for hashes we care to download. - fn add_nodes(&mut self, hash: Hash, nodes: &[NodeInfo]) { - let entry = self.candidates.entry(hash).or_default(); + fn add_hash_with_nodes(&mut self, hash: Hash, nodes: impl Iterator) { + let hash_entry = self.hash_node.entry(hash).or_default(); for node in nodes { - entry - .entry(node.node_id) - .and_modify(|role| *role = (*role).max(node.role)) - .or_insert(node.role); - if let Role::Provider = node.role { - self.provider_hashes_by_node - .entry(node.node_id) - .or_default() - .push_back(hash); - } + hash_entry.insert(node); + let node_entry = self.node_hash.entry(node).or_default(); + node_entry.insert(hash); } } - /// Get the next provider hash for a node. - /// - /// I.e. get the next hash that was added with [`Role::Provider`] for this node. - fn get_next_provider_hash_for_node(&mut self, node: &NodeId) -> Option { - let hash = self - .provider_hashes_by_node - .get(node) - .and_then(|hashes| hashes.front()) - .copied(); - if let Some(hash) = hash { - self.move_hash_to_back(node, hash); + /// Register nodes for a hash, but only if the hash is already in our queue. + fn add_nodes_if_hash_exists(&mut self, hash: Hash, nodes: impl Iterator) { + if let Some(hash_entry) = self.hash_node.get_mut(&hash) { + for node in nodes { + hash_entry.insert(node); + let node_entry = self.node_hash.entry(node).or_default(); + node_entry.insert(hash); + } } - hash } /// Signal the registry that this hash is no longer of interest. - fn remove(&mut self, hash: Hash) { - if let Some(nodes) = self.candidates.remove(&hash) { - for node in nodes.keys() { - if let Some(hashes) = self.provider_hashes_by_node.get_mut(node) { - hashes.retain(|h| *h != hash); + fn remove_hash(&mut self, hash: &Hash) { + if let Some(nodes) = self.hash_node.remove(hash) { + for node in nodes { + if let Some(hashes) = self.node_hash.get_mut(&node) { + hashes.remove(hash); + if hashes.is_empty() { + self.node_hash.remove(&node); + } } } } } - /// Move a hash to the back of the provider queue for a node. - fn move_hash_to_back(&mut self, node: &NodeId, hash: Hash) { - let hashes = self.provider_hashes_by_node.get_mut(node); - if let Some(hashes) = hashes { - debug_assert_eq!(hashes.front(), Some(&hash)); - if !hashes.is_empty() { - hashes.rotate_left(1); + fn remove_node(&mut self, node: &NodeId) { + if let Some(hashes) = self.node_hash.remove(node) { + for hash in hashes { + if let Some(nodes) = self.hash_node.get_mut(&hash) { + nodes.remove(node); + if nodes.is_empty() { + self.hash_node.remove(&hash); + } + } + } + } + } + + fn remove_hash_from_node(&mut self, hash: &Hash, node: &NodeId) { + if let Some(nodes) = self.hash_node.get_mut(hash) { + nodes.remove(node); + if nodes.is_empty() { + self.remove_hash(hash); + } + } + if let Some(hashes) = self.node_hash.get_mut(node) { + hashes.remove(hash); + if hashes.is_empty() { + self.remove_node(node); } } } diff --git a/iroh-bytes/src/downloader/get.rs b/iroh-bytes/src/downloader/get.rs index 334064bdeea..2fb39c2900f 100644 --- a/iroh-bytes/src/downloader/get.rs +++ b/iroh-bytes/src/downloader/get.rs @@ -3,7 +3,6 @@ use crate::{ get::{db::get_to_db, error::GetError}, store::Store, - util::progress::IgnoreProgressSender, }; use futures::FutureExt; #[cfg(feature = "metrics")] @@ -12,7 +11,7 @@ use iroh_metrics::{inc, inc_by}; #[cfg(feature = "metrics")] use crate::metrics::Metrics; -use super::{DownloadKind, FailureAction, GetFut, Getter}; +use super::{progress::BroadcastProgressSender, DownloadKind, FailureAction, GetFut, Getter}; impl From for FailureAction { fn from(e: GetError) -> Self { @@ -36,9 +35,13 @@ pub(crate) struct IoGetter { impl Getter for IoGetter { type Connection = quinn::Connection; - fn get(&mut self, kind: DownloadKind, conn: Self::Connection) -> GetFut { + fn get( + &mut self, + kind: DownloadKind, + conn: Self::Connection, + progress_sender: BroadcastProgressSender, + ) -> GetFut { let store = self.store.clone(); - let progress_sender = IgnoreProgressSender::default(); let fut = async move { let get_conn = || async move { Ok(conn) }; let res = get_to_db(&store, get_conn, &kind.hash_and_format(), progress_sender).await; diff --git a/iroh-bytes/src/downloader/invariants.rs b/iroh-bytes/src/downloader/invariants.rs index b24a6a125ec..f948179346b 100644 --- a/iroh-bytes/src/downloader/invariants.rs +++ b/iroh-bytes/src/downloader/invariants.rs @@ -5,12 +5,12 @@ use super::*; /// invariants for the service. -impl, D: Dialer> Service { +impl, D: Dialer, S: Store> Service { /// Checks the various invariants the service must maintain #[track_caller] pub(in crate::downloader) fn check_invariants(&self) { self.check_active_request_count(); - self.check_scheduled_requests_consistency(); + self.check_queued_requests_consistency(); self.check_idle_peer_consistency(); self.check_concurrency_limits(); self.check_provider_map_prunning(); @@ -21,8 +21,9 @@ impl, D: Dialer> Service { fn check_concurrency_limits(&self) { let ConcurrencyLimits { max_concurrent_requests, - max_concurrent_requests_per_node: max_concurrent_requests_per_peer, + max_concurrent_requests_per_node, max_open_connections, + max_concurrent_dials_per_hash, } = &self.concurrency_limits; // check the total number of active requests to ensure it stays within the limit @@ -32,16 +33,39 @@ impl, D: Dialer> Service { ); // check that the open and dialing peers don't exceed the connection capacity + tracing::trace!( + "limits: conns: {}/{} | reqs: {}/{}", + self.connections_count(), + max_open_connections, + self.in_progress_downloads.len(), + max_concurrent_requests + ); assert!( self.connections_count() <= *max_open_connections, "max_open_connections exceeded" ); // check the active requests per peer don't exceed the limit - for (peer, info) in self.nodes.iter() { + for (node, info) in self.nodes.iter() { assert!( - info.active_requests() <= *max_concurrent_requests_per_peer, - "max_concurrent_requests_per_peer exceeded for {peer}" + info.active_requests() <= *max_concurrent_requests_per_node, + "max_concurrent_requests_per_node exceeded for {node}" + ) + } + + // check that we do not dial more nodes than allowed for the next pending hashes + if let Some(kind) = self.queue.front() { + let hash = kind.hash(); + let nodes = self.providers.get_candidates(&hash); + let mut dialing = 0; + for node in nodes { + if self.dialer.is_pending(node) { + dialing += 1; + } + } + assert!( + dialing <= *max_concurrent_dials_per_hash, + "max_concurrent_dials_per_hash exceeded for {hash}" ) } } @@ -54,13 +78,13 @@ impl, D: Dialer> Service { // number of requests assert_eq!( self.in_progress_downloads.len(), - self.current_requests.len(), - "current_requests and in_progress_downloads are out of sync" + self.active_requests.len(), + "active_requests and in_progress_downloads are out of sync" ); // check that the count of requests per peer matches the number of requests that have that // peer as active let mut real_count: HashMap = HashMap::with_capacity(self.nodes.len()); - for req_info in self.current_requests.values() { + for req_info in self.active_requests.values() { // nothing like some classic word count *real_count.entry(req_info.node).or_default() += 1; } @@ -73,14 +97,22 @@ impl, D: Dialer> Service { } } - /// Checks that the scheduled requests match the queue that handles their delays. + /// Checks that the queued requests all appear in the provider map and request map. #[track_caller] - fn check_scheduled_requests_consistency(&self) { - assert_eq!( - self.scheduled_requests.len(), - self.scheduled_request_queue.len(), - "scheduled_request_queue and scheduled_requests are out of sync" - ); + fn check_queued_requests_consistency(&self) { + for entry in &self.queue { + assert!( + self.providers + .get_candidates(&entry.hash()) + .next() + .is_some(), + "all queued requests have providers" + ); + assert!( + self.requests.get(entry).is_some(), + "all queued requests have request info" + ); + } } /// Check that peers queued to be disconnected are consistent with peers considered idle. @@ -101,11 +133,16 @@ impl, D: Dialer> Service { /// Check that every hash in the provider map is needed. #[track_caller] fn check_provider_map_prunning(&self) { - for hash in self.providers.candidates.keys() { + for hash in self.providers.hash_node.keys() { + let as_raw = DownloadKind(HashAndFormat::raw(*hash)); + let as_hash_seq = DownloadKind(HashAndFormat::hash_seq(*hash)); assert!( - self.is_needed(*hash), - "provider map contains {hash:?} which should have been prunned" - ); + self.queue.contains(&as_raw) + || self.queue.contains(&as_hash_seq) + || self.active_requests.contains_key(&as_raw) + || self.active_requests.contains_key(&as_hash_seq), + "all hashes in the provider map are in the queue or active" + ) } } } diff --git a/iroh-bytes/src/downloader/progress.rs b/iroh-bytes/src/downloader/progress.rs new file mode 100644 index 00000000000..47ec74154ca --- /dev/null +++ b/iroh-bytes/src/downloader/progress.rs @@ -0,0 +1,195 @@ +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, +}; + +use anyhow::anyhow; +use parking_lot::Mutex; + +use crate::{ + get::{db::DownloadProgress, progress::TransferState}, + util::progress::{FlumeProgressSender, IdGenerator, ProgressSendError, ProgressSender}, +}; + +use super::DownloadKind; + +/// The channel that can be used to subscribe to progress updates. +pub type ProgressSubscriber = FlumeProgressSender; + +/// Track the progress of downloads. +/// +/// This struct allows to create [`ProgressSender`] structs to be passed to +/// [`crate::get::db::get_to_db`]. Each progress sender can be subscribed to by any number of +/// [`ProgressSubscriber`] channel senders, which will receive each progress update (if they have +/// capacity). Additionally, the [`ProgressTracker`] maintains a [`TransferState`] for each +/// transfer, applying each progress update to update this state. When subscribing to an already +/// running transfer, the subscriber will receive a [`DownloadProgress::InitialState`] message +/// containing the state at the time of the subscription, and then receive all further progress +/// events directly. +#[derive(Debug, Default)] +pub struct ProgressTracker { + /// Map of shared state for each tracked download. + running: HashMap, + /// Shared [`IdGenerator`] for all progress senders created by the tracker. + id_gen: Arc, +} + +impl ProgressTracker { + pub fn new() -> Self { + Self::default() + } + + /// Track a new download with a list of initial subscribers. + /// + /// Note that this should only be called for *new* downloads. If a download for the `kind` is + /// already tracked in this [`ProgressTracker`], calling `track` will replace all existing + /// state and subscribers (equal to calling [`Self::remove`] first). + pub fn track( + &mut self, + kind: DownloadKind, + subscribers: impl IntoIterator, + ) -> BroadcastProgressSender { + let inner = Inner { + subscribers: subscribers.into_iter().collect(), + state: TransferState::new(kind.hash()), + }; + let shared = Arc::new(Mutex::new(inner)); + self.running.insert(kind, Arc::clone(&shared)); + let id_gen = Arc::clone(&self.id_gen); + BroadcastProgressSender { shared, id_gen } + } + + /// Subscribe to a tracked download. + /// + /// Will return an error if `kind` is not yet tracked. + pub async fn subscribe( + &mut self, + kind: DownloadKind, + sender: ProgressSubscriber, + ) -> anyhow::Result<()> { + let initial_msg = self + .running + .get_mut(&kind) + .ok_or_else(|| anyhow!("state for download {kind:?} not found"))? + .lock() + .subscribe(sender.clone()); + sender.send(initial_msg).await?; + Ok(()) + } + + /// Unsubscribe `sender` from `kind`. + pub fn unsubscribe(&mut self, kind: &DownloadKind, sender: &ProgressSubscriber) { + if let Some(shared) = self.running.get_mut(kind) { + shared.lock().unsubscribe(sender) + } + } + + /// Remove all state for a download. + pub fn remove(&mut self, kind: &DownloadKind) { + self.running.remove(kind); + } +} + +type Shared = Arc>; + +#[derive(Debug)] +struct Inner { + subscribers: Vec, + state: TransferState, +} + +impl Inner { + fn subscribe(&mut self, subscriber: ProgressSubscriber) -> DownloadProgress { + let msg = DownloadProgress::InitialState(self.state.clone()); + self.subscribers.push(subscriber); + msg + } + + fn unsubscribe(&mut self, sender: &ProgressSubscriber) { + self.subscribers.retain(|s| !s.same_channel(sender)); + } + + fn on_progress(&mut self, progress: DownloadProgress) { + self.state.on_progress(progress); + } +} + +#[derive(Debug, Clone)] +pub struct BroadcastProgressSender { + shared: Shared, + id_gen: Arc, +} + +impl IdGenerator for BroadcastProgressSender { + fn new_id(&self) -> u64 { + self.id_gen.fetch_add(1, Ordering::SeqCst) + } +} + +impl ProgressSender for BroadcastProgressSender { + type Msg = DownloadProgress; + + async fn send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> { + // making sure that the lock is not held across an await point. + let futs = { + let mut inner = self.shared.lock(); + inner.on_progress(msg.clone()); + let futs = inner + .subscribers + .iter_mut() + .map(|sender| { + let sender = sender.clone(); + let msg = msg.clone(); + async move { + match sender.send(msg).await { + Ok(()) => None, + Err(ProgressSendError::ReceiverDropped) => Some(sender), + } + } + }) + .collect::>(); + drop(inner); + futs + }; + + let failed_senders = futures::future::join_all(futs).await; + // remove senders where the receiver is dropped + if failed_senders.iter().any(|s| s.is_some()) { + let mut inner = self.shared.lock(); + for sender in failed_senders.into_iter().flatten() { + inner.unsubscribe(&sender); + } + drop(inner); + } + Ok(()) + } + + fn try_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> { + let mut inner = self.shared.lock(); + inner.on_progress(msg.clone()); + // remove senders where the receiver is dropped + inner + .subscribers + .retain_mut(|sender| match sender.try_send(msg.clone()) { + Err(ProgressSendError::ReceiverDropped) => false, + Ok(()) => true, + }); + Ok(()) + } + + fn blocking_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> { + let mut inner = self.shared.lock(); + inner.on_progress(msg.clone()); + // remove senders where the receiver is dropped + inner + .subscribers + .retain_mut(|sender| match sender.blocking_send(msg.clone()) { + Err(ProgressSendError::ReceiverDropped) => false, + Ok(()) => true, + }); + Ok(()) + } +} diff --git a/iroh-bytes/src/downloader/test.rs b/iroh-bytes/src/downloader/test.rs index 6f7029ad2fa..fcdc91140d2 100644 --- a/iroh-bytes/src/downloader/test.rs +++ b/iroh-bytes/src/downloader/test.rs @@ -1,8 +1,14 @@ #![cfg(test)] +use futures::FutureExt; use std::time::Duration; use iroh_net::key::SecretKey; +use crate::{ + get::{db::BlobId, progress::TransferState}, + util::progress::{FlumeProgressSender, IdGenerator, ProgressSender}, +}; + use super::*; mod dialer; @@ -15,12 +21,13 @@ impl Downloader { concurrency_limits: ConcurrencyLimits, ) -> Self { let (msg_tx, msg_rx) = mpsc::channel(super::SERVICE_CHANNEL_CAPACITY); + let db = crate::store::mem::Store::default(); LocalPoolHandle::new(1).spawn_pinned(move || async move { // we want to see the logs of the service let _guard = iroh_test::logging::setup(); - let service = Service::new(getter, dialer, concurrency_limits, msg_rx); + let service = Service::new(db, getter, dialer, concurrency_limits, msg_rx); service.run().await }); @@ -34,21 +41,18 @@ impl Downloader { /// Tests that receiving a download request and performing it doesn't explode. #[tokio::test] async fn smoke_test() { + let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); let concurrency_limits = ConcurrencyLimits::default(); - let mut downloader = - Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send a request and make sure the peer is requested the corresponding download let peer = SecretKey::generate().public(); - let kind = DownloadKind::Blob { - hash: Hash::new([0u8; 32]), - }; - let handle = downloader - .queue(kind.clone(), vec![(peer, Role::Candidate).into()]) - .await; + let kind: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into(); + let req = DownloadRequest::new(kind, vec![peer]); + let handle = downloader.queue(req).await; // wait for the download result to be reported handle.await.expect("should report success"); // verify that the peer was dialed @@ -60,24 +64,21 @@ async fn smoke_test() { /// Tests that multiple intents produce a single request. #[tokio::test] async fn deduplication() { + let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); // make request take some time to ensure the intents are received before completion getter.set_request_duration(Duration::from_secs(1)); let concurrency_limits = ConcurrencyLimits::default(); - let mut downloader = - Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); let peer = SecretKey::generate().public(); - let kind = DownloadKind::Blob { - hash: Hash::new([0u8; 32]), - }; + let kind: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into(); let mut handles = Vec::with_capacity(10); for _ in 0..10 { - let h = downloader - .queue(kind.clone(), vec![(peer, Role::Candidate).into()]) - .await; + let req = DownloadRequest::new(kind, vec![peer]); + let h = downloader.queue(req).await; handles.push(h); } assert!( @@ -94,37 +95,27 @@ async fn deduplication() { /// Tests that the request is cancelled only when all intents are cancelled. #[tokio::test] async fn cancellation() { + let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); // make request take some time to ensure cancellations are received on time getter.set_request_duration(Duration::from_millis(500)); let concurrency_limits = ConcurrencyLimits::default(); - let mut downloader = - Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); let peer = SecretKey::generate().public(); - let kind_1 = DownloadKind::Blob { - hash: Hash::new([0u8; 32]), - }; - let handle_a = downloader - .queue(kind_1.clone(), vec![(peer, Role::Candidate).into()]) - .await; - let handle_b = downloader - .queue(kind_1.clone(), vec![(peer, Role::Candidate).into()]) - .await; + let kind_1: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into(); + let req = DownloadRequest::new(kind_1, vec![peer]); + let handle_a = downloader.queue(req.clone()).await; + let handle_b = downloader.queue(req).await; downloader.cancel(handle_a).await; // create a request with two intents and cancel them both - let kind_2 = DownloadKind::Blob { - hash: Hash::new([1u8; 32]), - }; - let handle_c = downloader - .queue(kind_2.clone(), vec![(peer, Role::Candidate).into()]) - .await; - let handle_d = downloader - .queue(kind_2.clone(), vec![(peer, Role::Candidate).into()]) - .await; + let kind_2 = HashAndFormat::raw(Hash::new([1u8; 32])); + let req = DownloadRequest::new(kind_2, vec![peer]); + let handle_c = downloader.queue(req.clone()).await; + let handle_d = downloader.queue(req).await; downloader.cancel(handle_c).await; downloader.cancel(handle_d).await; @@ -138,7 +129,8 @@ async fn cancellation() { /// maximum number of concurrent requests is not exceed. /// NOTE: This is internally tested by [`Service::check_invariants`]. #[tokio::test] -async fn max_concurrent_requests() { +async fn max_concurrent_requests_total() { + let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); // make request take some time to ensure concurreny limits are hit @@ -149,20 +141,16 @@ async fn max_concurrent_requests() { ..Default::default() }; - let mut downloader = - Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send the downloads let peer = SecretKey::generate().public(); let mut handles = Vec::with_capacity(5); let mut expected_history = Vec::with_capacity(5); for i in 0..5 { - let kind = DownloadKind::Blob { - hash: Hash::new([i; 32]), - }; - let h = downloader - .queue(kind.clone(), vec![(peer, Role::Candidate).into()]) - .await; + let kind: DownloadKind = HashAndFormat::raw(Hash::new([i; 32])).into(); + let req = DownloadRequest::new(kind, vec![peer]); + let h = downloader.queue(req).await; expected_history.push((kind, peer)); handles.push(h); } @@ -184,6 +172,7 @@ async fn max_concurrent_requests() { /// NOTE: This is internally tested by [`Service::check_invariants`]. #[tokio::test] async fn max_concurrent_requests_per_peer() { + let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); // make request take some time to ensure concurreny limits are hit @@ -195,60 +184,164 @@ async fn max_concurrent_requests_per_peer() { ..Default::default() }; - let mut downloader = - Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send the downloads let peer = SecretKey::generate().public(); let mut handles = Vec::with_capacity(5); for i in 0..5 { - let kind = DownloadKind::Blob { - hash: Hash::new([i; 32]), - }; - let h = downloader - .queue(kind.clone(), vec![(peer, Role::Candidate).into()]) - .await; + let kind = HashAndFormat::raw(Hash::new([i; 32])); + let req = DownloadRequest::new(kind, vec![peer]); + let h = downloader.queue(req).await; handles.push(h); } futures::future::join_all(handles).await; } -/// Tests that providers are preferred over candidates. +/// Tests concurrent progress reporting for multiple intents. +/// +/// This first registers two intents for a download, and then proceeds until the `Found` event is +/// emitted, and verifies that both intents received the event. +/// It then registers a third intent mid-download, and makes sure it receives a correct ìnitial +/// state. The download then finishes, and we make sure that all events are emitted properly, and +/// the progress state of the handles converges. #[tokio::test] -async fn peer_role_provider() { +async fn concurrent_progress() { + let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); - dialer.set_dial_duration(Duration::from_millis(100)); let getter = getter::TestingGetter::default(); - let concurrency_limits = ConcurrencyLimits::default(); - let mut downloader = - Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (start_tx, start_rx) = oneshot::channel(); + let start_rx = start_rx.shared(); + + let (done_tx, done_rx) = oneshot::channel(); + let done_rx = done_rx.shared(); + + getter.set_handler(Arc::new(move |hash, _peer, progress, _duration| { + let start_rx = start_rx.clone(); + let done_rx = done_rx.clone(); + async move { + let hash = hash.hash(); + start_rx.await.unwrap(); + let id = progress.new_id(); + progress + .send(DownloadProgress::Found { + id, + child: BlobId::Root, + hash, + size: 100, + }) + .await + .unwrap(); + done_rx.await.unwrap(); + progress.send(DownloadProgress::Done { id }).await.unwrap(); + Ok(Stats::default()) + } + .boxed() + })); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); + + let peer = SecretKey::generate().public(); + let hash = Hash::new([0u8; 32]); + let kind_1 = HashAndFormat::raw(hash); + + let (prog_a_tx, prog_a_rx) = flume::bounded(64); + let prog_a_tx = FlumeProgressSender::new(prog_a_tx); + let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_a_tx); + let handle_a = downloader.queue(req).await; + + let (prog_b_tx, prog_b_rx) = flume::bounded(64); + let prog_b_tx = FlumeProgressSender::new(prog_b_tx); + let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_b_tx); + let handle_b = downloader.queue(req).await; + + start_tx.send(()).unwrap(); + + let mut state_a = TransferState::new(hash); + let mut state_b = TransferState::new(hash); + let mut state_c = TransferState::new(hash); + + let prog1_a = prog_a_rx.recv_async().await.unwrap(); + let prog1_b = prog_b_rx.recv_async().await.unwrap(); + assert!(matches!(prog1_a, DownloadProgress::Found { hash, size: 100, ..} if hash == hash)); + assert!(matches!(prog1_b, DownloadProgress::Found { hash, size: 100, ..} if hash == hash)); + + state_a.on_progress(prog1_a); + state_b.on_progress(prog1_b); + assert_eq!(state_a, state_b); + + let (prog_c_tx, prog_c_rx) = flume::bounded(64); + let prog_c_tx = FlumeProgressSender::new(prog_c_tx); + let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_c_tx); + let handle_c = downloader.queue(req).await; + + let prog1_c = prog_c_rx.recv_async().await.unwrap(); + assert!(matches!(&prog1_c, DownloadProgress::InitialState(state) if state == &state_a)); + state_c.on_progress(prog1_c); + + done_tx.send(()).unwrap(); + + let (res_a, res_b, res_c) = futures::future::join3(handle_a, handle_b, handle_c).await; + res_a.unwrap(); + res_b.unwrap(); + res_c.unwrap(); - let peer_candidate1 = SecretKey::from_bytes(&[0u8; 32]).public(); - let peer_candidate2 = SecretKey::from_bytes(&[1u8; 32]).public(); - let peer_provider = SecretKey::from_bytes(&[2u8; 32]).public(); - let kind = DownloadKind::Blob { - hash: Hash::new([0u8; 32]), + let prog_a: Vec<_> = prog_a_rx.into_stream().collect().await; + let prog_b: Vec<_> = prog_b_rx.into_stream().collect().await; + let prog_c: Vec<_> = prog_c_rx.into_stream().collect().await; + + assert_eq!(prog_a.len(), 1); + assert_eq!(prog_b.len(), 1); + assert_eq!(prog_c.len(), 1); + + assert!(matches!(prog_a[0], DownloadProgress::Done { .. })); + assert!(matches!(prog_b[0], DownloadProgress::Done { .. })); + assert!(matches!(prog_c[0], DownloadProgress::Done { .. })); + + for p in prog_a { + state_a.on_progress(p); + } + for p in prog_b { + state_b.on_progress(p); + } + for p in prog_c { + state_c.on_progress(p); + } + assert_eq!(state_a, state_b); + assert_eq!(state_a, state_c); +} + +#[tokio::test] +async fn long_queue() { + let _guard = iroh_test::logging::setup(); + let dialer = dialer::TestingDialer::default(); + let getter = getter::TestingGetter::default(); + let concurrency_limits = ConcurrencyLimits { + max_open_connections: 2, + max_concurrent_requests_per_node: 2, + max_concurrent_requests: 4, // all requests can be performed at the same time + ..Default::default() }; - let handle = downloader - .queue( - kind.clone(), - vec![ - (peer_candidate1, Role::Candidate).into(), - (peer_provider, Role::Provider).into(), - (peer_candidate2, Role::Candidate).into(), - ], - ) - .await; - let now = std::time::Instant::now(); - assert!(handle.await.is_ok(), "download succeeded"); - // this is, I think, currently the best way to test that no delay was performed. It should be - // safe enough to assume that test runtime is not longer than the delay of 500ms. - assert!( - now.elapsed() < INITIAL_REQUEST_DELAY, - "no initial delay was added to fetching from a provider" - ); - getter.assert_history(&[(kind, peer_provider)]); - dialer.assert_history(&[peer_provider]); + + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + // send the downloads + let nodes = [ + SecretKey::generate().public(), + SecretKey::generate().public(), + SecretKey::generate().public(), + ]; + let mut handles = vec![]; + for i in 0..100usize { + let kind = HashAndFormat::raw(Hash::new(i.to_be_bytes())); + let peer = nodes[i % 3]; + let req = DownloadRequest::new(kind, vec![peer]); + let h = downloader.queue(req).await; + handles.push(h); + } + + let res = futures::future::join_all(handles).await; + for res in res { + res.expect("all downloads to succeed"); + } } diff --git a/iroh-bytes/src/downloader/test/dialer.rs b/iroh-bytes/src/downloader/test/dialer.rs index a112464b5df..05a1f036830 100644 --- a/iroh-bytes/src/downloader/test/dialer.rs +++ b/iroh-bytes/src/downloader/test/dialer.rs @@ -32,7 +32,7 @@ impl Default for TestingDialerInner { dialing: HashSet::default(), dial_futs: delay_queue::DelayQueue::default(), dial_history: Vec::default(), - dial_duration: Duration::ZERO, + dial_duration: Duration::from_millis(10), dial_outcome: Box::new(|_| true), } } @@ -72,6 +72,7 @@ impl futures::Stream for TestingDialer { let result = report_ok .then_some(node) .ok_or_else(|| anyhow::anyhow!("dialing test set to fail")); + inner.dialing.remove(&node); Poll::Ready(Some((node, result))) } _ => Poll::Pending, @@ -84,9 +85,4 @@ impl TestingDialer { pub(super) fn assert_history(&self, history: &[NodeId]) { assert_eq!(self.0.read().dial_history, history) } - - pub(super) fn set_dial_duration(&self, duration: Duration) { - let mut inner = self.0.write(); - inner.dial_duration = duration; - } } diff --git a/iroh-bytes/src/downloader/test/getter.rs b/iroh-bytes/src/downloader/test/getter.rs index b8a7f44a356..1581d84af61 100644 --- a/iroh-bytes/src/downloader/test/getter.rs +++ b/iroh-bytes/src/downloader/test/getter.rs @@ -1,5 +1,8 @@ //! Implementation of [`super::Getter`] used for testing. +use std::{sync::Arc, time::Duration}; + +use futures::future::BoxFuture; use parking_lot::RwLock; use super::*; @@ -7,12 +10,26 @@ use super::*; #[derive(Default, Clone)] pub(super) struct TestingGetter(Arc>); +pub(super) type RequestHandlerFn = Arc< + dyn Fn( + DownloadKind, + NodeId, + BroadcastProgressSender, + Duration, + ) -> BoxFuture<'static, InternalDownloadResult> + + Send + + Sync + + 'static, +>; + #[derive(Default)] struct TestingGetterInner { /// How long requests take. request_duration: Duration, /// History of requests performed by the [`Getter`] and if they were successful. request_history: Vec<(DownloadKind, NodeId)>, + /// Set a handler function which actually handles the requests. + request_handler: Option, } impl Getter for TestingGetter { @@ -20,19 +37,32 @@ impl Getter for TestingGetter { // request being sent to type Connection = NodeId; - fn get(&mut self, kind: DownloadKind, peer: NodeId) -> GetFut { + fn get( + &mut self, + kind: DownloadKind, + peer: NodeId, + progress_sender: BroadcastProgressSender, + ) -> GetFut { let mut inner = self.0.write(); inner.request_history.push((kind, peer)); let request_duration = inner.request_duration; + let handler = inner.request_handler.clone(); async move { - tokio::time::sleep(request_duration).await; - Ok(Stats::default()) + if let Some(f) = handler { + f(kind, peer, progress_sender, request_duration).await + } else { + tokio::time::sleep(request_duration).await; + Ok(Stats::default()) + } } .boxed_local() } } impl TestingGetter { + pub(super) fn set_handler(&self, handler: RequestHandlerFn) { + self.0.write().request_handler = Some(handler); + } pub(super) fn set_request_duration(&self, request_duration: Duration) { self.0.write().request_duration = request_duration; } diff --git a/iroh-bytes/src/get.rs b/iroh-bytes/src/get.rs index 47cd0179324..9975007d39b 100644 --- a/iroh-bytes/src/get.rs +++ b/iroh-bytes/src/get.rs @@ -30,6 +30,7 @@ use crate::IROH_BLOCK_SIZE; pub mod db; pub mod error; +pub mod progress; pub mod request; /// Stats about the transfer. diff --git a/iroh-bytes/src/get/db.rs b/iroh-bytes/src/get/db.rs index f8dbd39ca1a..b52aa833113 100644 --- a/iroh-bytes/src/get/db.rs +++ b/iroh-bytes/src/get/db.rs @@ -9,6 +9,7 @@ use crate::protocol::RangeSpec; use crate::store::BaoBlobSize; use crate::store::FallibleProgressBatchWriter; use std::io; +use std::num::NonZeroU64; use crate::hashseq::parse_hash_seq; use crate::store::BaoBatchWriter; @@ -18,6 +19,7 @@ use crate::{ self, error::GetError, fsm::{AtBlobHeader, AtEndBlob, ConnectedNext, EndBlobNext}, + progress::TransferState, Stats, }, protocol::{GetRequest, RangeSpecSeq}, @@ -74,7 +76,7 @@ async fn get_blob< tracing::info!("already got entire blob"); progress .send(DownloadProgress::FoundLocal { - child: 0, + child: BlobId::Root, hash: *hash, size: entry.size(), valid_ranges: RangeSpec::all(), @@ -90,7 +92,7 @@ async fn get_blob< .unwrap_or_else(ChunkRanges::all); progress .send(DownloadProgress::FoundLocal { - child: 0, + child: BlobId::Root, hash: *hash, size: entry.size(), valid_ranges: RangeSpec::new(&valid_ranges), @@ -186,7 +188,7 @@ async fn get_blob_inner( id, hash, size, - child: child_offset, + child: BlobId::from_offset(child_offset), }) .await?; let sender2 = sender.clone(); @@ -237,7 +239,7 @@ async fn get_blob_inner_partial( id, hash, size, - child: child_offset, + child: BlobId::from_offset(child_offset), }) .await?; let sender2 = sender.clone(); @@ -316,7 +318,7 @@ async fn get_hash_seq< // send info that we have the hashseq itself entirely sender .send(DownloadProgress::FoundLocal { - child: 0, + child: BlobId::Root, hash: *root_hash, size: entry.size(), valid_ranges: RangeSpec::all(), @@ -343,7 +345,7 @@ async fn get_hash_seq< if let Some(size) = info.size() { sender .send(DownloadProgress::FoundLocal { - child: (i as u64) + 1, + child: BlobId::from_offset((i as u64) + 1), hash: children[i], size, valid_ranges: RangeSpec::new(&info.valid_ranges()), @@ -521,12 +523,15 @@ impl BlobInfo { } /// Progress updates for the get operation. +// TODO: Move to super::progress #[derive(Debug, Clone, Serialize, Deserialize)] pub enum DownloadProgress { + /// Initial state if subscribing to a running or queued transfer. + InitialState(TransferState), /// Data was found locally. FoundLocal { /// child offset - child: u64, + child: BlobId, /// The hash of the entry. hash: Hash, /// The size of the entry in bytes. @@ -538,10 +543,13 @@ pub enum DownloadProgress { Connected, /// An item was found with hash `hash`, from now on referred to via `id`. Found { - /// A new unique id for this entry. + /// A new unique progress id for this entry. id: u64, - /// child offset - child: u64, + /// Identifier for this blob within this download. + /// + /// Will always be [`BlobId::Root`] unless a hashseq is downloaded, in which case this + /// allows to identify the children by their offset in the hashseq. + child: BlobId, /// The hash of the entry. hash: Hash, /// The size of the entry in bytes. @@ -575,3 +583,29 @@ pub enum DownloadProgress { /// This will be the last message in the stream. Abort(RpcError), } + +/// The id of a blob in a transfer +#[derive( + Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, std::hash::Hash, Serialize, Deserialize, +)] +pub enum BlobId { + /// The root blob (child id 0) + Root, + /// A child blob (child id > 0) + Child(NonZeroU64), +} + +impl BlobId { + fn from_offset(id: u64) -> Self { + NonZeroU64::new(id).map(Self::Child).unwrap_or(Self::Root) + } +} + +impl From for u64 { + fn from(value: BlobId) -> Self { + match value { + BlobId::Root => 0, + BlobId::Child(id) => id.into(), + } + } +} diff --git a/iroh-bytes/src/get/progress.rs b/iroh-bytes/src/get/progress.rs new file mode 100644 index 00000000000..a865a40c938 --- /dev/null +++ b/iroh-bytes/src/get/progress.rs @@ -0,0 +1,182 @@ +//! Types for get progress state management. + +use std::{collections::HashMap, num::NonZeroU64}; + +use serde::{Deserialize, Serialize}; +use tracing::warn; + +use crate::{protocol::RangeSpec, store::BaoBlobSize, Hash}; + +use super::db::{BlobId, DownloadProgress}; + +/// The identifier for progress events. +pub type ProgressId = u64; + +/// Accumulated progress state of a transfer. +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub struct TransferState { + /// The root blob of this transfer (may be a hash seq), + pub root: BlobState, + /// Whether we are connected to a node + pub connected: bool, + /// Children if the root blob is a hash seq, empty for raw blobs + pub children: HashMap, + /// Child being transferred at the moment. + pub current: Option, + /// Progress ids for individual blobs. + pub progress_id_to_blob: HashMap, +} + +impl TransferState { + /// Create a new, empty transfer state. + pub fn new(root_hash: Hash) -> Self { + Self { + root: BlobState::new(root_hash), + connected: false, + children: Default::default(), + current: None, + progress_id_to_blob: Default::default(), + } + } +} + +/// State of a single blob in transfer +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub struct BlobState { + /// The hash of this blob. + pub hash: Hash, + /// The size of this blob. Only known if the blob is partially present locally, or after having + /// received the size from the remote. + pub size: Option, + /// The current state of the blob transfer. + pub progress: BlobProgress, + /// Ranges already available locally at the time of starting the transfer. + pub local_ranges: Option, + /// Number of children (only applies to hashseqs, None for raw blobs). + pub child_count: Option, +} + +/// Progress state for a single blob +#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub enum BlobProgress { + /// Download is pending + #[default] + Pending, + /// Download is in progress + Progressing(u64), + /// Download has finished + Done, +} + +impl BlobState { + /// Create a new [`BlobState`]. + pub fn new(hash: Hash) -> Self { + Self { + hash, + size: None, + local_ranges: None, + child_count: None, + progress: BlobProgress::default(), + } + } +} + +impl TransferState { + /// Get state of the root blob of this transfer. + pub fn root(&self) -> &BlobState { + &self.root + } + + /// Get a blob state by its [`BlobId`] in this transfer. + pub fn get_blob(&self, blob_id: &BlobId) -> Option<&BlobState> { + match blob_id { + BlobId::Root => Some(&self.root), + BlobId::Child(id) => self.children.get(id), + } + } + + /// Get the blob state currently being transferred. + pub fn get_current(&self) -> Option<&BlobState> { + self.current.as_ref().and_then(|id| self.get_blob(id)) + } + + fn get_or_insert_blob(&mut self, blob_id: BlobId, hash: Hash) -> &mut BlobState { + match blob_id { + BlobId::Root => &mut self.root, + BlobId::Child(id) => self + .children + .entry(id) + .or_insert_with(|| BlobState::new(hash)), + } + } + fn get_blob_mut(&mut self, blob_id: &BlobId) -> Option<&mut BlobState> { + match blob_id { + BlobId::Root => Some(&mut self.root), + BlobId::Child(id) => self.children.get_mut(id), + } + } + + fn get_by_progress_id(&mut self, progress_id: ProgressId) -> Option<&mut BlobState> { + let blob_id = *self.progress_id_to_blob.get(&progress_id)?; + self.get_blob_mut(&blob_id) + } + + /// Update the state with a new [`DownloadProgress`] event for this transfer. + pub fn on_progress(&mut self, event: DownloadProgress) { + match event { + DownloadProgress::InitialState(s) => { + *self = s; + } + DownloadProgress::FoundLocal { + child, + hash, + size, + valid_ranges, + } => { + let blob = self.get_or_insert_blob(child, hash); + blob.size = Some(size); + blob.local_ranges = Some(valid_ranges); + } + DownloadProgress::Connected => self.connected = true, + DownloadProgress::Found { + id: progress_id, + child: blob_id, + hash, + size, + } => { + let blob = self.get_or_insert_blob(blob_id, hash); + if blob.size.is_none() { + blob.size = Some(BaoBlobSize::Verified(size)); + } + blob.progress = BlobProgress::Progressing(0); + self.progress_id_to_blob.insert(progress_id, blob_id); + self.current = Some(blob_id); + } + DownloadProgress::FoundHashSeq { hash, children } => { + if hash == self.root.hash { + self.root.child_count = Some(children); + } else { + // I think it is an invariant of the protocol that `FoundHashSeq` is only + // triggered for the root hash. + warn!("Received `FoundHashSeq` event for a hash which is not the download's root hash.") + } + } + DownloadProgress::Progress { id, offset } => { + if let Some(blob) = self.get_by_progress_id(id) { + blob.progress = BlobProgress::Progressing(offset); + } else { + warn!(%id, "Received `Progress` event for unknown progress id.") + } + } + DownloadProgress::Done { id } => { + if let Some(blob) = self.get_by_progress_id(id) { + blob.progress = BlobProgress::Done; + self.progress_id_to_blob.remove(&id); + } else { + warn!(%id, "Received `Done` event for unknown progress id.") + } + } + _ => {} + } + } +} diff --git a/iroh-bytes/src/store/fs/tables.rs b/iroh-bytes/src/store/fs/tables.rs index d458f2918b3..2eaf22c52a1 100644 --- a/iroh-bytes/src/store/fs/tables.rs +++ b/iroh-bytes/src/store/fs/tables.rs @@ -149,12 +149,15 @@ impl DeleteSet { BaoFilePart::Sizes => options.owned_sizes_path(hash), }; if let Err(cause) = std::fs::remove_file(&path) { - tracing::warn!( - "failed to delete {:?} {}: {}", - to_delete, - path.display(), - cause - ); + // Ignore NotFound errors, if the file is already gone that's fine. + if cause.kind() != std::io::ErrorKind::NotFound { + tracing::warn!( + "failed to delete {:?} {}: {}", + to_delete, + path.display(), + cause + ); + } } } self.0.clear(); diff --git a/iroh-bytes/src/store/traits.rs b/iroh-bytes/src/store/traits.rs index 6c5dd5aaca0..f8a88b07843 100644 --- a/iroh-bytes/src/store/traits.rs +++ b/iroh-bytes/src/store/traits.rs @@ -45,7 +45,7 @@ pub enum EntryStatus { } /// The size of a bao file -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Eq, PartialEq)] pub enum BaoBlobSize { /// A remote side told us the size, but we have insufficient data to verify it. Unverified(u64), diff --git a/iroh-bytes/src/util.rs b/iroh-bytes/src/util.rs index 57812cb58df..b540b885628 100644 --- a/iroh-bytes/src/util.rs +++ b/iroh-bytes/src/util.rs @@ -6,7 +6,7 @@ use range_collections::range_set::RangeSetRange; use serde::{Deserialize, Serialize}; use std::{borrow::Borrow, fmt, sync::Arc, time::SystemTime}; -use crate::{BlobFormat, Hash, HashAndFormat, IROH_BLOCK_SIZE}; +use crate::{store::Store, BlobFormat, Hash, HashAndFormat, IROH_BLOCK_SIZE}; pub mod io; mod mem_or_file; @@ -121,6 +121,64 @@ impl Tag { } } +/// A set of merged [`SetTagOption`]s for a blob. +#[derive(Debug, Default)] +pub struct TagSet { + auto: bool, + named: Vec, +} + +impl TagSet { + /// Insert a new tag into the set. + pub fn insert(&mut self, tag: SetTagOption) { + match tag { + SetTagOption::Auto => self.auto = true, + SetTagOption::Named(tag) => { + if !self.named.iter().any(|t| t == &tag) { + self.named.push(tag) + } + } + } + } + + /// Convert the [`TagSet`] into a list of [`SetTagOption`]. + pub fn into_tags(self) -> impl Iterator { + self.auto + .then_some(SetTagOption::Auto) + .into_iter() + .chain(self.named.into_iter().map(SetTagOption::Named)) + } + + /// Apply the tags in the [`TagSet`] to the database. + pub async fn apply( + self, + db: &D, + hash_and_format: HashAndFormat, + ) -> std::io::Result<()> { + let tags = self.into_tags(); + for tag in tags { + match tag { + SetTagOption::Named(tag) => { + db.set_tag(tag, Some(hash_and_format)).await?; + } + SetTagOption::Auto => { + db.create_tag(hash_and_format).await?; + } + } + } + Ok(()) + } +} + +/// Option for commands that allow setting a tag +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum SetTagOption { + /// A tag will be automatically generated + Auto, + /// The tag is explicitly named + Named(Tag), +} + /// A trait for things that can track liveness of blobs and collections. /// /// This trait works together with [TempTag] to keep track of the liveness of a diff --git a/iroh-bytes/src/util/progress.rs b/iroh-bytes/src/util/progress.rs index d80ec3533ed..308f4aab7af 100644 --- a/iroh-bytes/src/util/progress.rs +++ b/iroh-bytes/src/util/progress.rs @@ -471,13 +471,18 @@ impl Clone for FlumeProgressSender { } impl FlumeProgressSender { - /// Create a new progress sender from a tokio mpsc sender. + /// Create a new progress sender from a flume sender. pub fn new(sender: flume::Sender) -> Self { Self { sender, id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)), } } + + /// Returns true if `other` sends on the same `flume` channel as `self`. + pub fn same_channel(&self, other: &FlumeProgressSender) -> bool { + self.sender.same_channel(&other.sender) + } } impl IdGenerator for FlumeProgressSender { diff --git a/iroh-cli/src/commands/blob.rs b/iroh-cli/src/commands/blob.rs index 621d6c6b980..f5576170b83 100644 --- a/iroh-cli/src/commands/blob.rs +++ b/iroh-cli/src/commands/blob.rs @@ -14,7 +14,7 @@ use indicatif::{ ProgressStyle, }; use iroh::bytes::{ - get::{db::DownloadProgress, Stats}, + get::{db::DownloadProgress, progress::BlobProgress, Stats}, provider::AddProgress, store::{ConsistencyCheckProgress, ExportFormat, ExportMode, ReportLevel, ValidateProgress}, BlobFormat, Hash, HashAndFormat, Tag, @@ -24,7 +24,7 @@ use iroh::{ client::{BlobStatus, Iroh, ShareTicketOptions}, rpc_protocol::{ BlobDownloadRequest, BlobListCollectionsResponse, BlobListIncompleteResponse, - BlobListResponse, ProviderService, SetTagOption, WrapOption, + BlobListResponse, DownloadMode, ProviderService, SetTagOption, WrapOption, }, ticket::BlobTicket, }; @@ -81,6 +81,12 @@ pub enum BlobCommands { /// Tag to tag the data with. #[clap(long)] tag: Option, + /// If set, will queue the download in the download queue. + /// + /// Use this if you are doing many downloads in parallel and want to limit the number of + /// downloads running concurrently. + #[clap(long)] + queued: bool, }, /// Export a blob from the internal blob store to the local filesystem. Export { @@ -183,6 +189,7 @@ impl BlobCommands { out, stable, tag, + queued, } => { let (node_addr, hash, format) = match ticket { TicketOrHash::Ticket(ticket) => { @@ -241,13 +248,19 @@ impl BlobCommands { None => SetTagOption::Auto, }; + let mode = match queued { + true => DownloadMode::Queued, + false => DownloadMode::Direct, + }; + let mut stream = iroh .blobs .download(BlobDownloadRequest { hash, format, - peer: node_addr, + nodes: vec![node_addr], tag, + mode, }) .await?; @@ -277,6 +290,7 @@ impl BlobCommands { }; tracing::info!("exporting to {} -> {}", path.display(), absolute.display()); let stream = iroh.blobs.export(hash, absolute, format, mode).await?; + // TODO: report export progress stream.await?; } @@ -1009,6 +1023,36 @@ pub async fn show_download_progress( let mut seq = false; while let Some(x) = stream.next().await { match x? { + DownloadProgress::InitialState(state) => { + if state.connected { + op.set_message(format!("{} Requesting ...\n", style("[2/3]").bold().dim())); + } + if let Some(count) = state.root.child_count { + op.set_message(format!( + "{} Downloading {} blob(s)\n", + style("[3/3]").bold().dim(), + count + 1, + )); + op.set_length(count + 1); + op.reset(); + op.set_position(state.current.map(u64::from).unwrap_or(0)); + seq = true; + } + if let Some(blob) = state.get_current() { + if let Some(size) = blob.size { + ip.set_length(size.value()); + ip.reset(); + match blob.progress { + BlobProgress::Pending => {} + BlobProgress::Progressing(offset) => ip.set_position(offset), + BlobProgress::Done => ip.finish_and_clear(), + } + if !seq { + op.finish_and_clear(); + } + } + } + } DownloadProgress::FoundLocal { .. } => {} DownloadProgress::Connected => { op.set_message(format!("{} Requesting ...\n", style("[2/3]").bold().dim())); @@ -1025,7 +1069,7 @@ pub async fn show_download_progress( } DownloadProgress::Found { size, child, .. } => { if seq { - op.set_position(child); + op.set_position(child.into()); } else { op.finish_and_clear(); } diff --git a/iroh-dns-server/Cargo.toml b/iroh-dns-server/Cargo.toml index ad4bcd67b2f..589bb6370f5 100644 --- a/iroh-dns-server/Cargo.toml +++ b/iroh-dns-server/Cargo.toml @@ -52,3 +52,4 @@ z32 = "1.1.1" [dev-dependencies] hickory-resolver = "0.24.0" iroh-net = { version = "0.14.0", path = "../iroh-net" } +iroh-test = { path = "../iroh-test" } diff --git a/iroh-dns-server/src/dns.rs b/iroh-dns-server/src/dns.rs index 2faca9df6f2..42f9a5f58f3 100644 --- a/iroh-dns-server/src/dns.rs +++ b/iroh-dns-server/src/dns.rs @@ -33,6 +33,7 @@ use tokio::{ net::{TcpListener, UdpSocket}, sync::broadcast, }; +use tracing::{debug, info}; use crate::{metrics::Metrics, store::ZoneStore}; @@ -90,7 +91,7 @@ impl DnsServer { server.register_socket(socket); server.register_listener(TcpListener::bind(bind_addr).await?, TCP_TIMEOUT); - tracing::info!("DNS server listening on {}", bind_addr); + info!("DNS server listening on {}", bind_addr); Ok(Self { server, @@ -151,14 +152,9 @@ impl DnsHandler { /// Handle a DNS request pub async fn answer_request(&self, request: Request) -> Result { - tracing::info!(?request, "Got DNS request"); - let (tx, mut rx) = broadcast::channel(1); let response_handle = Handle(tx); - self.handle_request(&request, response_handle).await; - - tracing::debug!("Done handling request, trying to resolve response"); Ok(rx.recv().await?) } } @@ -176,6 +172,7 @@ impl RequestHandler for DnsHandler { hickory_server::server::Protocol::Https => inc!(Metrics, dns_requests_https), _ => {} } + debug!(protocol=%request.protocol(), query=%request.query(), "incoming DNS request"); let res = self.catalog.handle_request(request, response_handle).await; match &res.response_code() { diff --git a/iroh-dns-server/src/dns/node_authority.rs b/iroh-dns-server/src/dns/node_authority.rs index 67b498fc552..bc64165a6b5 100644 --- a/iroh-dns-server/src/dns/node_authority.rs +++ b/iroh-dns-server/src/dns/node_authority.rs @@ -1,6 +1,6 @@ use std::{fmt, sync::Arc}; -use anyhow::{bail, ensure, Result}; +use anyhow::{bail, ensure, Context, Result}; use async_trait::async_trait; use hickory_proto::{ op::ResponseCode, @@ -87,22 +87,22 @@ impl Authority for NodeAuthority { record_type: RecordType, lookup_options: LookupOptions, ) -> Result { + debug!(name=%name, "lookup in node authority"); match record_type { RecordType::SOA | RecordType::NS => { self.static_authority .lookup(name, record_type, lookup_options) .await } - _ => match split_and_parse_pkarr(name, &self.origins) { + _ => match parse_name_as_pkarr_with_origin(name, &self.origins) { Err(err) => { - trace!(%name, ?err, "name is not a pkarr zone"); - debug!("resolve static: name {name}"); + debug!(%name, failed_with=%err, "not a pkarr name, resolve in static authority"); self.static_authority .lookup(name, record_type, lookup_options) .await } Ok((name, pubkey, origin)) => { - debug!(%origin, "resolve pkarr: {name} {pubkey}"); + debug!(%origin, %pubkey, %name, "resolve in pkarr zones"); match self .zones .resolve(&pubkey, &name, record_type) @@ -110,6 +110,7 @@ impl Authority for NodeAuthority { .map_err(err_refused)? { Some(pkarr_set) => { + debug!(%origin, %pubkey, %name, "found {} records in pkarr zone", pkarr_set.records_without_rrsigs().count()); let new_origin = Name::parse(&pubkey.to_z32(), Some(&origin)) .map_err(err_refused)?; let record_set = @@ -131,7 +132,7 @@ impl Authority for NodeAuthority { request_info: RequestInfo<'_>, lookup_options: LookupOptions, ) -> Result { - debug!("searching NodeAuthority for: {}", request_info.query); + debug!("search in node authority for {}", request_info.query); let lookup_name = request_info.query.name(); let record_type: RecordType = request_info.query.query_type(); match record_type { @@ -154,7 +155,7 @@ impl Authority for NodeAuthority { } } -fn split_and_parse_pkarr( +fn parse_name_as_pkarr_with_origin( name: impl Into, allowed_origins: &[Name], ) -> Result<(Name, PublicKeyBytes, Name)> { @@ -166,18 +167,19 @@ fn split_and_parse_pkarr( continue; } if name.num_labels() < origin.num_labels() + 1 { - bail!("invalid name"); + bail!("not a valid pkarr name: missing pubkey"); } trace!("parse {origin}"); let labels = name.iter().rev(); let mut labels_without_origin = labels.skip(origin.num_labels() as usize); let pkey_label = labels_without_origin.next().expect("length checked above"); let pkey_str = std::str::from_utf8(pkey_label)?; - let pkey = PublicKeyBytes::from_z32(pkey_str)?; - let remaining_name = Name::from_labels(labels_without_origin)?; + let pkey = + PublicKeyBytes::from_z32(pkey_str).context("not a valid pkarr name: invalid pubkey")?; + let remaining_name = Name::from_labels(labels_without_origin.rev())?; return Ok((remaining_name, pkey, origin.clone())); } - bail!("name does not match any origin"); + bail!("name does not match any allowed origin"); } fn err_refused(e: impl fmt::Debug) -> LookupError { diff --git a/iroh-dns-server/src/lib.rs b/iroh-dns-server/src/lib.rs index 2374b86e027..d09b103336c 100644 --- a/iroh-dns-server/src/lib.rs +++ b/iroh-dns-server/src/lib.rs @@ -13,7 +13,7 @@ mod util; #[cfg(test)] mod tests { - use std::net::SocketAddr; + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use anyhow::Result; use hickory_resolver::{ @@ -28,13 +28,126 @@ mod tests { }, key::SecretKey, }; + use pkarr::SignedPacket; use url::Url; use crate::server::Server; + #[tokio::test] + async fn pkarr_publish_dns_resolve() -> Result<()> { + iroh_test::logging::setup_multithreaded(); + let (server, nameserver, http_url) = Server::spawn_for_tests().await?; + let pkarr_relay_url = { + let mut url = http_url.clone(); + url.set_path("/pkarr"); + url + }; + let signed_packet = { + use pkarr::dns; + let keypair = pkarr::Keypair::random(); + let mut packet = dns::Packet::new_reply(0); + // record at root + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("").unwrap(), + dns::CLASS::IN, + 30, + dns::rdata::RData::TXT("hi0".try_into()?), + )); + // record at level one + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("_hello").unwrap(), + dns::CLASS::IN, + 30, + dns::rdata::RData::TXT("hi1".try_into()?), + )); + // record at level two + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("_hello.world").unwrap(), + dns::CLASS::IN, + 30, + dns::rdata::RData::TXT("hi2".try_into()?), + )); + // multiple records for same name + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("multiple").unwrap(), + dns::CLASS::IN, + 30, + dns::rdata::RData::TXT("hi3".try_into()?), + )); + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("multiple").unwrap(), + dns::CLASS::IN, + 30, + dns::rdata::RData::TXT("hi4".try_into()?), + )); + // record of type A + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("").unwrap(), + dns::CLASS::IN, + 30, + dns::rdata::RData::A(Ipv4Addr::LOCALHOST.into()), + )); + // record of type AAAA + packet.answers.push(dns::ResourceRecord::new( + dns::Name::new("foo.bar.baz").unwrap(), + dns::CLASS::IN, + 30, + dns::rdata::RData::AAAA(Ipv6Addr::LOCALHOST.into()), + )); + SignedPacket::from_packet(&keypair, &packet)? + }; + let pkarr_client = pkarr::PkarrClient::builder().build(); + pkarr_client + .relay_put(&pkarr_relay_url, &signed_packet) + .await?; + + use hickory_proto::rr::Name; + let pubkey = signed_packet.public_key().to_z32(); + let resolver = test_resolver(nameserver); + + // resolve root record + let name = Name::from_utf8(format!("{pubkey}."))?; + let res = resolver.txt_lookup(name).await?; + let records = res.iter().map(|t| t.to_string()).collect::>(); + assert_eq!(records, vec!["hi0".to_string()]); + + // resolve level one record + let name = Name::from_utf8(format!("_hello.{pubkey}."))?; + let res = resolver.txt_lookup(name).await?; + let records = res.iter().map(|t| t.to_string()).collect::>(); + assert_eq!(records, vec!["hi1".to_string()]); + + // resolve level two record + let name = Name::from_utf8(format!("_hello.world.{pubkey}."))?; + let res = resolver.txt_lookup(name).await?; + let records = res.iter().map(|t| t.to_string()).collect::>(); + assert_eq!(records, vec!["hi2".to_string()]); + + // resolve multiple records for same name + let name = Name::from_utf8(format!("multiple.{pubkey}."))?; + let res = resolver.txt_lookup(name).await?; + let records = res.iter().map(|t| t.to_string()).collect::>(); + assert_eq!(records, vec!["hi3".to_string(), "hi4".to_string()]); + + // resolve A record + let name = Name::from_utf8(format!("{pubkey}."))?; + let res = resolver.ipv4_lookup(name).await?; + let records = res.iter().map(|t| t.0).collect::>(); + assert_eq!(records, vec![Ipv4Addr::LOCALHOST]); + + // resolve AAAA record + let name = Name::from_utf8(format!("foo.bar.baz.{pubkey}."))?; + let res = resolver.ipv6_lookup(name).await?; + let records = res.iter().map(|t| t.0).collect::>(); + assert_eq!(records, vec![Ipv6Addr::LOCALHOST]); + + server.shutdown().await?; + Ok(()) + } + #[tokio::test] async fn integration_smoke() -> Result<()> { - tracing_subscriber::fmt::init(); + iroh_test::logging::setup_multithreaded(); let (server, nameserver, http_url) = Server::spawn_for_tests().await?; let pkarr_relay = { diff --git a/iroh/examples/collection-fetch.rs b/iroh/examples/collection-fetch.rs index e21ac30139e..53fac68f52f 100644 --- a/iroh/examples/collection-fetch.rs +++ b/iroh/examples/collection-fetch.rs @@ -4,7 +4,7 @@ //! This is using an in memory database and a random node id. //! Run the `collection-provide` example, which will give you instructions on how to run this example. use anyhow::{bail, ensure, Context, Result}; -use iroh::rpc_protocol::BlobDownloadRequest; +use iroh::rpc_protocol::{BlobDownloadRequest, DownloadMode}; use iroh_bytes::BlobFormat; use std::env; use std::str::FromStr; @@ -60,15 +60,18 @@ async fn main() -> Result<()> { // When interacting with the iroh API, you will most likely be using blobs and collections. format: ticket.format(), - // The `peer` field is a `NodeAddr`, which combines all of the known address information we have for the remote node. + // The `nodes` field is a list of `NodeAddr`, where each combines all of the known address information we have for the remote node. // This includes the `node_id` (or `PublicKey` of the node), any direct UDP addresses we know about for that node, as well as the relay url of that node. The relay url is the url of the relay server that that node is connected to. // If the direct UDP addresses to that node do not work, than we can use the relay node to attempt to holepunch between your current node and the remote node. // If holepunching fails, iroh will use the relay node to proxy a connection to the remote node over HTTPS. // Thankfully, the ticket contains all of this information - peer: ticket.node_addr().clone(), + nodes: vec![ticket.node_addr().clone()], // You can create a special tag name (`SetTagOption::Named`), or create an automatic tag that is derived from the timestamp. tag: iroh::rpc_protocol::SetTagOption::Auto, + + // Whether to use the download queue, or do a direct download. + mode: DownloadMode::Direct, }; // `download` returns a stream of `DownloadProgress` events. You can iterate through these updates to get progress on the state of your download. diff --git a/iroh/examples/hello-world-fetch.rs b/iroh/examples/hello-world-fetch.rs index 7294c670934..fcc75cdbcaf 100644 --- a/iroh/examples/hello-world-fetch.rs +++ b/iroh/examples/hello-world-fetch.rs @@ -4,7 +4,7 @@ //! This is using an in memory database and a random node id. //! Run the `provide` example, which will give you instructions on how to run this example. use anyhow::{bail, ensure, Context, Result}; -use iroh::rpc_protocol::BlobDownloadRequest; +use iroh::rpc_protocol::{BlobDownloadRequest, DownloadMode}; use iroh_bytes::BlobFormat; use std::env; use std::str::FromStr; @@ -60,15 +60,18 @@ async fn main() -> Result<()> { // When interacting with the iroh API, you will most likely be using blobs and collections. format: ticket.format(), - // The `peer` field is a `NodeAddr`, which combines all of the known address information we have for the remote node. + // The `nodes` field is a list of `NodeAddr`, where each combines all of the known address information we have for the remote node. // This includes the `node_id` (or `PublicKey` of the node), any direct UDP addresses we know about for that node, as well as the relay url of that node. The relay url is the url of the relay server that that node is connected to. // If the direct UDP addresses to that node do not work, than we can use the relay node to attempt to holepunch between your current node and the remote node. // If holepunching fails, iroh will use the relay node to proxy a connection to the remote node over HTTPS. // Thankfully, the ticket contains all of this information - peer: ticket.node_addr().clone(), + nodes: vec![ticket.node_addr().clone()], // You can create a special tag name (`SetTagOption::Named`), or create an automatic tag that is derived from the timestamp. tag: iroh::rpc_protocol::SetTagOption::Auto, + + // Whether to use the download queue, or do a direct download. + mode: DownloadMode::Direct, }; // `download` returns a stream of `DownloadProgress` events. You can iterate through these updates to get progress on the state of your download. diff --git a/iroh/src/node.rs b/iroh/src/node.rs index b2e231dec81..314b4e0b121 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -16,6 +16,7 @@ use std::task::Poll; use anyhow::{anyhow, Result}; use futures::future::{BoxFuture, Shared}; use futures::{FutureExt, StreamExt}; +use iroh_bytes::downloader::Downloader; use iroh_bytes::store::Store as BaoStore; use iroh_bytes::BlobFormat; use iroh_bytes::Hash; @@ -108,6 +109,7 @@ struct NodeInner { #[debug("rt")] rt: LocalPoolHandle, pub(crate) sync: SyncEngine, + downloader: Downloader, } /// Events emitted by the [`Node`] informing about the current status. @@ -294,7 +296,8 @@ mod tests { use crate::{ client::BlobAddOutcome, rpc_protocol::{ - BlobAddPathRequest, BlobAddPathResponse, BlobDownloadRequest, SetTagOption, WrapOption, + BlobAddPathRequest, BlobAddPathResponse, BlobDownloadRequest, DownloadMode, + SetTagOption, WrapOption, }, }; @@ -442,7 +445,8 @@ mod tests { hash, tag: SetTagOption::Auto, format: BlobFormat::Raw, - peer: addr, + mode: DownloadMode::Direct, + nodes: vec![addr], }; node2.blobs.download(req).await?.await?; assert_eq!( diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 5ef192e719f..564ca87f017 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -397,7 +397,7 @@ where gossip.clone(), self.docs_store, self.blobs_store.clone(), - downloader, + downloader.clone(), ); let sync_db = sync.sync.clone(); @@ -425,6 +425,7 @@ where gc_task, rt: lp.clone(), sync, + downloader, }); let task = { let gossip = gossip.clone(); diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index 99bf28ce5e0..870620ab02d 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -3,15 +3,17 @@ use std::io; use std::sync::{Arc, Mutex}; use std::time::Duration; -use anyhow::{anyhow, Result}; -use futures::{Future, FutureExt, Stream, StreamExt}; +use anyhow::{anyhow, ensure, Result}; +use futures::{FutureExt, Stream, StreamExt}; use genawaiter::sync::{Co, Gen}; use iroh_base::rpc::RpcResult; +use iroh_bytes::downloader::{DownloadRequest, Downloader}; use iroh_bytes::export::ExportProgress; use iroh_bytes::format::collection::Collection; use iroh_bytes::get::db::DownloadProgress; +use iroh_bytes::get::Stats; use iroh_bytes::store::{ConsistencyCheckProgress, ExportFormat, ImportProgress, MapEntry}; -use iroh_bytes::util::progress::{IdGenerator, ProgressSender}; +use iroh_bytes::util::progress::ProgressSender; use iroh_bytes::BlobFormat; use iroh_bytes::{ hashseq::parse_hash_seq, @@ -21,6 +23,7 @@ use iroh_bytes::{ HashAndFormat, }; use iroh_io::AsyncSliceReader; +use iroh_net::{MagicEndpoint, NodeAddr}; use quic_rpc::{ server::{RpcChannel, RpcServerError}, ServiceEndpoint, @@ -37,10 +40,11 @@ use crate::rpc_protocol::{ BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest, DocExportFileResponse, DocImportFileRequest, DocImportFileResponse, DocImportProgress, DocSetHashRequest, - ListTagsRequest, ListTagsResponse, NodeConnectionInfoRequest, NodeConnectionInfoResponse, - NodeConnectionsRequest, NodeConnectionsResponse, NodeShutdownRequest, NodeStatsRequest, - NodeStatsResponse, NodeStatusRequest, NodeStatusResponse, NodeWatchRequest, NodeWatchResponse, - ProviderRequest, ProviderService, SetTagOption, + DownloadMode, ListTagsRequest, ListTagsResponse, NodeConnectionInfoRequest, + NodeConnectionInfoResponse, NodeConnectionsRequest, NodeConnectionsResponse, + NodeShutdownRequest, NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, + NodeStatusResponse, NodeWatchRequest, NodeWatchResponse, ProviderRequest, ProviderService, + SetTagOption, }; use super::{Event, NodeInner}; @@ -602,38 +606,17 @@ impl Handler { fn blob_download(self, msg: BlobDownloadRequest) -> impl Stream { let (sender, receiver) = flume::bounded(1024); - let progress = FlumeProgressSender::new(sender); - - let BlobDownloadRequest { - hash, - format, - peer, - tag, - } = msg; - let db = self.inner.db.clone(); - let hash_and_format = HashAndFormat { hash, format }; - let temp_pin = self.inner.db.temp_tag(hash_and_format); - let get_conn = { - let progress = progress.clone(); - let ep = self.inner.endpoint.clone(); - move || async move { - let conn = ep.connect(peer, iroh_bytes::protocol::ALPN).await?; - progress.send(DownloadProgress::Connected).await?; - Ok(conn) - } - }; - + let downloader = self.inner.downloader.clone(); + let endpoint = self.inner.endpoint.clone(); + let progress = FlumeProgressSender::new(sender); self.inner.rt.spawn_pinned(move || async move { - if let Err(err) = - download_blob(db, get_conn, hash_and_format, tag, progress.clone()).await - { + if let Err(err) = download(&db, endpoint, &downloader, msg, progress.clone()).await { progress .send(DownloadProgress::Abort(err.into())) .await .ok(); } - drop(temp_pin); }); receiver.into_stream().map(BlobDownloadResponse) @@ -1052,31 +1035,138 @@ impl Handler { } } -async fn download_blob( - db: D, - get_conn: C, +async fn download( + db: &D, + endpoint: MagicEndpoint, + downloader: &Downloader, + req: BlobDownloadRequest, + progress: FlumeProgressSender, +) -> Result<()> +where + D: iroh_bytes::store::Store, +{ + let BlobDownloadRequest { + hash, + format, + nodes, + tag, + mode, + } = req; + let hash_and_format = HashAndFormat { hash, format }; + let stats = match mode { + DownloadMode::Queued => { + download_queued( + endpoint, + downloader, + hash_and_format, + nodes, + tag, + progress.clone(), + ) + .await? + } + DownloadMode::Direct => { + download_direct_from_nodes(db, endpoint, hash_and_format, nodes, tag, progress.clone()) + .await? + } + }; + + progress.send(DownloadProgress::AllDone(stats)).await.ok(); + + Ok(()) +} + +async fn download_queued( + endpoint: MagicEndpoint, + downloader: &Downloader, + hash_and_format: HashAndFormat, + nodes: Vec, + tag: SetTagOption, + progress: FlumeProgressSender, +) -> Result { + let mut node_ids = Vec::with_capacity(nodes.len()); + for node in nodes { + node_ids.push(node.node_id); + endpoint.add_node_addr(node)?; + } + let req = DownloadRequest::new(hash_and_format, node_ids) + .progress_sender(progress) + .tag(tag); + let handle = downloader.queue(req).await; + let stats = handle.await?; + Ok(stats) +} + +async fn download_direct_from_nodes( + db: &D, + endpoint: MagicEndpoint, hash_and_format: HashAndFormat, + nodes: Vec, tag: SetTagOption, - progress: impl ProgressSender + IdGenerator, -) -> Result<()> + progress: FlumeProgressSender, +) -> Result where D: BaoStore, - C: FnOnce() -> F, - F: Future>, { - let stats = - iroh_bytes::get::db::get_to_db(&db, get_conn, &hash_and_format, progress.clone()).await?; + ensure!(!nodes.is_empty(), "No nodes to download from provided."); + let mut last_err = None; + for node in nodes { + let node_id = node.node_id; + match download_direct( + db, + endpoint.clone(), + hash_and_format, + node, + tag.clone(), + progress.clone(), + ) + .await + { + Ok(stats) => return Ok(stats), + Err(err) => { + debug!(?err, node = &node_id.fmt_short(), "Download failed"); + last_err = Some(err) + } + } + } + Err(last_err.unwrap()) +} - match tag { - SetTagOption::Named(tag) => { - db.set_tag(tag, Some(hash_and_format)).await?; +async fn download_direct( + db: &D, + endpoint: MagicEndpoint, + hash_and_format: HashAndFormat, + node: NodeAddr, + tag: SetTagOption, + progress: FlumeProgressSender, +) -> Result +where + D: BaoStore, +{ + let temp_pin = db.temp_tag(hash_and_format); + let get_conn = { + let progress = progress.clone(); + move || async move { + let conn = endpoint.connect(node, iroh_bytes::protocol::ALPN).await?; + progress.send(DownloadProgress::Connected).await?; + Ok(conn) } - SetTagOption::Auto => { - db.create_tag(hash_and_format).await?; + }; + + let res = iroh_bytes::get::db::get_to_db(db, get_conn, &hash_and_format, progress).await; + + if res.is_ok() { + match tag { + SetTagOption::Named(tag) => { + db.set_tag(tag, Some(hash_and_format)).await?; + } + SetTagOption::Auto => { + db.create_tag(hash_and_format).await?; + } } } - progress.send(DownloadProgress::AllDone(stats)).await.ok(); + drop(temp_pin); - Ok(()) + res.map_err(Into::into) } diff --git a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs index 053150f51e3..ce71da856db 100644 --- a/iroh/src/rpc_protocol.rs +++ b/iroh/src/rpc_protocol.rs @@ -39,19 +39,11 @@ pub use iroh_bytes::{provider::AddProgress, store::ValidateProgress}; use crate::sync_engine::LiveEvent; pub use crate::ticket::DocTicket; +pub use iroh_bytes::util::SetTagOption; /// A 32-byte key or token pub type KeyBytes = [u8; 32]; -/// Option for commands that allow setting a tag -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub enum SetTagOption { - /// A tag will be automatically generated - Auto, - /// The tag is explicitly named - Named(Tag), -} - /// A request to the node to provide the data at the given path /// /// Will produce a stream of [`AddProgress`] messages. @@ -104,10 +96,31 @@ pub struct BlobDownloadRequest { /// If the format is [`BlobFormat::HashSeq`], all children are downloaded and shared as /// well. pub format: BlobFormat, - /// This mandatory field specifies the peer to download the data from. - pub peer: NodeAddr, + /// This mandatory field specifies the nodes to download the data from. + /// + /// If set to more than a single node, they will all be tried. If `mode` is set to + /// [`DownloadMode::Direct`], they will be tried sequentially until a download succeeds. + /// If `mode` is set to [`DownloadMode::Queued`], the nodes may be dialed in parallel, + /// if the concurrency limits permit. + pub nodes: Vec, /// Optional tag to tag the data with. pub tag: SetTagOption, + /// Whether to directly start the download or add it to the downlod queue. + pub mode: DownloadMode, +} + +/// Set the mode for whether to directly start the download or add it to the download queue. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum DownloadMode { + /// Start the download right away. + /// + /// No concurrency limits or queuing will be applied. It is up to the user to manage download + /// concurrency. + Direct, + /// Queue the download. + /// + /// The download queue will be processed in-order, while respecting the downloader concurrency limits. + Queued, } impl Msg for BlobDownloadRequest { diff --git a/iroh/src/sync_engine/gossip.rs b/iroh/src/sync_engine/gossip.rs index 8daa8781f81..8912887a9ff 100644 --- a/iroh/src/sync_engine/gossip.rs +++ b/iroh/src/sync_engine/gossip.rs @@ -15,7 +15,7 @@ use tokio::{ use tracing::{debug, error, trace}; use super::live::{Op, ToLiveActor}; -use iroh_bytes::downloader::{Downloader, Role}; +use iroh_bytes::downloader::Downloader; #[derive(strum::Display, Debug)] pub enum ToGossipActor { @@ -179,7 +179,7 @@ impl GossipActor { // Inform the downloader that we now know that this peer has the content // for this hash. self.downloader - .nodes_have(hash, vec![(msg.delivered_from, Role::Provider).into()]) + .nodes_have(hash, vec![msg.delivered_from]) .await; } Op::SyncReport(report) => { diff --git a/iroh/src/sync_engine/live.rs b/iroh/src/sync_engine/live.rs index 510676c2b31..3d0cf74a11b 100644 --- a/iroh/src/sync_engine/live.rs +++ b/iroh/src/sync_engine/live.rs @@ -4,7 +4,8 @@ use std::{collections::HashMap, time::SystemTime}; use anyhow::{Context, Result}; use futures::FutureExt; -use iroh_bytes::downloader::{DownloadKind, Downloader, Role}; +use iroh_bytes::downloader::{DownloadRequest, Downloader}; +use iroh_bytes::HashAndFormat; use iroh_bytes::{store::EntryStatus, Hash}; use iroh_gossip::{net::Gossip, proto::TopicId}; use iroh_net::{key::PublicKey, MagicEndpoint, NodeAddr}; @@ -634,15 +635,13 @@ impl LiveActor { if matches!(entry_status, EntryStatus::NotFound | EntryStatus::Partial) && should_download { - let from = PublicKey::from_bytes(&from)?; - let role = match remote_content_status { - ContentStatus::Complete => Role::Provider, - _ => Role::Candidate, + let mut nodes = vec![]; + if let ContentStatus::Complete = remote_content_status { + let node_id = PublicKey::from_bytes(&from)?; + nodes.push(node_id); }; - let handle = self - .downloader - .queue(DownloadKind::Blob { hash }, vec![(from, role).into()]) - .await; + let req = DownloadRequest::untagged(HashAndFormat::raw(hash), nodes); + let handle = self.downloader.queue(req).await; self.pending_downloads.spawn(async move { // NOTE: this ignores the result for now, simply keeping the option