diff --git a/Cargo.lock b/Cargo.lock
index 66a4963d3c..636df6ffd2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1817,6 +1817,15 @@ dependencies = [
  "allocator-api2",
 ]
 
+[[package]]
+name = "hashlink"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "692eaaf7f7607518dd3cef090f1474b61edc5301d8012f09579920df68b725ee"
+dependencies = [
+ "hashbrown 0.14.3",
+]
+
 [[package]]
 name = "hdrhistogram"
 version = "7.5.4"
@@ -2406,9 +2415,11 @@ dependencies = [
  "futures",
  "futures-buffered",
  "genawaiter",
+ "hashlink",
  "hex",
  "http-body 0.4.6",
  "iroh-base",
+ "iroh-bytes",
  "iroh-io",
  "iroh-metrics",
  "iroh-net",
diff --git a/iroh-base/src/hash.rs b/iroh-base/src/hash.rs
index a6e4e82eea..81ab8206f2 100644
--- a/iroh-base/src/hash.rs
+++ b/iroh-base/src/hash.rs
@@ -7,7 +7,7 @@ use bao_tree::blake3;
 use postcard::experimental::max_size::MaxSize;
 use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
 
-use crate::base32::{parse_array_hex_or_base32, HexOrBase32ParseError};
+use crate::base32::{self, parse_array_hex_or_base32, HexOrBase32ParseError};
 
 /// Hash type used throughout.
 #[derive(PartialEq, Eq, Copy, Clone, Hash)]
@@ -54,6 +54,12 @@ impl Hash {
     pub fn to_hex(&self) -> String {
         self.0.to_hex().to_string()
     }
+
+    /// Convert to a base32 string limited to the first 10 bytes for a friendly string
+    /// representation of the hash.
+    pub fn fmt_short(&self) -> String {
+        base32::fmt_short(self.as_bytes())
+    }
 }
 
 impl AsRef<[u8]> for Hash {
@@ -173,7 +179,18 @@ impl MaxSize for Hash {
 
 /// A format identifier
 #[derive(
-    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default, Debug, MaxSize,
+    Clone,
+    Copy,
+    PartialEq,
+    Eq,
+    PartialOrd,
+    Ord,
+    Serialize,
+    Deserialize,
+    Default,
+    Debug,
+    MaxSize,
+    Hash,
 )]
 pub enum BlobFormat {
     /// Raw blob
@@ -205,7 +222,7 @@ impl BlobFormat {
     }
 }
 
 /// A hash and format pair
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, MaxSize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, MaxSize, Hash)]
 pub struct HashAndFormat {
     /// The hash
     pub hash: Hash,
@@ -289,6 +306,11 @@ mod redb_support {
 }
 
 impl HashAndFormat {
+    /// Create a new hash and format pair.
+    pub fn new(hash: Hash, format: BlobFormat) -> Self {
+        Self { hash, format }
+    }
+
     /// Create a new hash and format pair, using the default (raw) format.
     pub fn raw(hash: Hash) -> Self {
         Self {
diff --git a/iroh-bytes/Cargo.toml b/iroh-bytes/Cargo.toml
index 1895903715..de2048644c 100644
--- a/iroh-bytes/Cargo.toml
+++ b/iroh-bytes/Cargo.toml
@@ -25,6 +25,7 @@ flume = "0.11"
 futures = "0.3.25"
 futures-buffered = "0.2.4"
 genawaiter = { version = "0.99.1", features = ["futures03"] }
+hashlink = { version = "0.9.0", optional = true }
 hex = "0.4.3"
 iroh-base = { version = "0.14.0", features = ["redb"], path = "../iroh-base" }
 iroh-io = { version = "0.6.0", features = ["stats"] }
@@ -51,6 +52,7 @@ tracing-futures = "0.2.5"
 
 [dev-dependencies]
 http-body = "0.4.5"
+iroh-bytes = { path = ".", features = ["downloader"] }
 iroh-test = { path = "../iroh-test" }
 proptest = "1.0.0"
 serde_json = "1.0.107"
@@ -63,8 +65,8 @@ tempfile = "3.10.0"
 
 [features]
 default = ["fs-store"]
+downloader = ["iroh-net", "parking_lot", "tokio-util/time", "hashlink"]
 fs-store = ["reflink-copy", "redb", "redb_v1", "tempfile"]
-downloader = ["iroh-net", "parking_lot", "tokio-util/time"]
 metrics = ["iroh-metrics"]
 
 [[example]]
diff --git a/iroh-bytes/src/downloader.rs b/iroh-bytes/src/downloader.rs
index 1963bb306b..4a8e910655 100644
--- a/iroh-bytes/src/downloader.rs
+++ b/iroh-bytes/src/downloader.rs
@@ -2,12 +2,12 @@
 //!
 //! The [`Downloader`] interacts with three main components to this end.
 //! - [`Dialer`]: Used to queue opening connections to nodes we need to perform downloads.
-//! - [`ProviderMap`]: Where the downloader obtains information about nodes that could be
+//! - `ProviderMap`: Where the downloader obtains information about nodes that could be
 //!   used to perform a download.
 //! - [`Store`]: Where data is stored.
 //!
 //! Once a download request is received, the logic is as follows:
-//! 1. The [`ProviderMap`] is queried for nodes. From these nodes some are selected
+//! 1. The `ProviderMap` is queried for nodes. From these nodes some are selected
 //!    prioritizing connected nodes with lower number of active requests. If no useful node is
 //!    connected, or useful connected nodes have no capacity to perform the request, a connection
 //!    attempt is started using the [`Dialer`].
@@ -27,18 +27,20 @@
 //! requests to a single node is also limited.
 
 use std::{
-    collections::{hash_map::Entry, HashMap, VecDeque},
+    collections::{hash_map::Entry, HashMap, HashSet},
+    fmt,
     num::NonZeroUsize,
     sync::{
         atomic::{AtomicU64, Ordering},
         Arc,
     },
+    time::Duration,
 };
 
-use crate::{get::Stats, protocol::RangeSpecSeq, store::Store, Hash, HashAndFormat};
-use bao_tree::ChunkRanges;
 use futures::{future::LocalBoxFuture, FutureExt, StreamExt};
-use iroh_net::{MagicEndpoint, NodeId};
+use hashlink::LinkedHashSet;
+use iroh_base::hash::{BlobFormat, Hash, HashAndFormat};
+use iroh_net::{MagicEndpoint, NodeAddr, NodeId};
 use tokio::{
     sync::{mpsc, oneshot},
     task::JoinSet,
@@ -46,22 +48,28 @@ use tokio_util::{sync::CancellationToken, task::LocalPoolHandle, time::delay_queue};
 use tracing::{debug, error_span, trace, warn, Instrument};
 
+use crate::{
+    get::{db::DownloadProgress, Stats},
+    store::Store,
+    util::{progress::ProgressSender, SetTagOption, TagSet},
+    TempTag,
+};
+
 mod get;
 mod invariants;
+mod progress;
 mod test;
 
-/// Delay added to a request when it's first received.
-const INITIAL_REQUEST_DELAY: std::time::Duration = std::time::Duration::from_millis(500);
-/// Number of retries initially assigned to a request.
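Before digging into the downloader internals, a minimal sketch of the two small helpers added in `iroh-base` above. `Hash::new` is assumed from the existing `iroh-base` API; it is not part of this diff:

```rust
use iroh_base::hash::{BlobFormat, Hash, HashAndFormat};

fn main() {
    // Hash some bytes (Hash::new is assumed from the existing iroh-base API).
    let hash = Hash::new(b"hello world");

    // The new convenience constructor, equivalent to `HashAndFormat { hash, format }`.
    let pair = HashAndFormat::new(hash, BlobFormat::Raw);

    // `fmt_short` base32-encodes only the first 10 bytes of the digest,
    // which keeps log lines readable.
    println!("{} ({:?})", pair.hash.fmt_short(), pair.format);
}
```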
-const INITIAL_RETRY_COUNT: u8 = 4;
+use self::progress::{BroadcastProgressSender, ProgressSubscriber, ProgressTracker};
+
 /// Duration for which we keep nodes connected after they were last useful to us.
-const IDLE_PEER_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
+const IDLE_PEER_TIMEOUT: Duration = Duration::from_secs(10);
 /// Capacity of the channel used to communicate between the [`Downloader`] and the [`Service`].
 const SERVICE_CHANNEL_CAPACITY: usize = 128;
 
-/// Download identifier.
-// Mainly for readability.
-pub type Id = u64;
+/// Identifier for a download intent.
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, derive_more::Display)]
+pub struct IntentId(pub u64);
 
 /// Trait modeling a dialer. This allows for IO-less testing.
 pub trait Dialer:
@@ -80,7 +88,9 @@ pub trait Dialer:
 /// Signals what should be done with the request when it fails.
 #[derive(Debug)]
 pub enum FailureAction {
-    /// An error occurred that prevents the request from being retried at all.
+    /// The request was cancelled by us.
+    AllIntentsDropped,
+    /// An error occurred that prevents the request from being retried at all.
     AbortRequest(anyhow::Error),
     /// An error occurred that suggests the node should not be used in general.
     DropPeer(anyhow::Error),
@@ -89,14 +99,19 @@
 }
 
 /// Future of a get request.
-type GetFut = LocalBoxFuture<'static, Result<Stats, FailureAction>>;
+type GetFut = LocalBoxFuture<'static, InternalDownloadResult>;
 
 /// Trait modelling performing a single request over a connection. This allows for IO-less testing.
 pub trait Getter {
     /// Type of connections the Getter requires to perform a download.
     type Connection;
     /// Return a future that performs the download using the given connection.
-    fn get(&mut self, kind: DownloadKind, conn: Self::Connection) -> GetFut;
+    fn get(
+        &mut self,
+        kind: DownloadKind,
+        conn: Self::Connection,
+        progress_sender: BroadcastProgressSender,
+    ) -> GetFut;
 }
 
 /// Concurrency limits for the [`Downloader`].
@@ -108,6 +123,8 @@ pub struct ConcurrencyLimits {
     pub max_concurrent_requests_per_node: usize,
     /// Maximum number of open connections the service maintains.
     pub max_open_connections: usize,
+    /// Maximum number of nodes to dial concurrently for a single request.
+    pub max_concurrent_dials_per_hash: usize,
 }
 
 impl Default for ConcurrencyLimits {
@@ -117,6 +134,7 @@
             max_concurrent_requests: 50,
             max_concurrent_requests_per_node: 4,
             max_open_connections: 25,
+            max_concurrent_dials_per_hash: 5,
         }
     }
 }
@@ -136,67 +154,151 @@
     fn at_connections_capacity(&self, active_connections: usize) -> bool {
         active_connections >= self.max_open_connections
     }
+
+    /// Checks if the maximum number of concurrent dials per hash has been reached.
+    ///
+    /// Note that this limit is not strictly enforced, and not checked in
+    /// [`Service::check_invariants`]. A certain hash can exceed this limit in a valid way if some
+    /// of its providers are dialed for another hash. However, once the limit is reached,
+    /// no new dials will be initiated for the hash.
+    fn at_dials_per_hash_capacity(&self, concurrent_dials: usize) -> bool {
+        concurrent_dials >= self.max_concurrent_dials_per_hash
+    }
 }
 
-/// Download requests the [`Downloader`] handles.
-#[derive(Debug, PartialEq, Eq, Hash, Clone)]
-pub enum DownloadKind {
-    /// Download a single blob entirely.
-    Blob {
-        /// Blob to be downloaded.
-        hash: Hash,
-    },
-    /// Download a sequence of hashes entirely.
-    HashSeq {
-        /// Hash sequence to be downloaded.
-        hash: Hash,
-    },
+/// Configuration for retry behavior of the [`Downloader`].
+#[derive(Debug)]
+pub struct RetryConfig {
+    /// Maximum number of retry attempts for a node that failed to dial or failed with IO errors.
+    pub max_retries_per_node: u32,
+    /// The initial delay to wait before retrying a node. On subsequent failures, the retry delay
+    /// will be multiplied with the number of failed retries.
+    pub initial_retry_delay: Duration,
 }
 
-impl DownloadKind {
-    /// Get the requested hash.
-    const fn hash(&self) -> &Hash {
-        match self {
-            DownloadKind::Blob { hash } | DownloadKind::HashSeq { hash } => hash,
+impl Default for RetryConfig {
+    fn default() -> Self {
+        Self {
+            max_retries_per_node: 6,
+            initial_retry_delay: Duration::from_millis(500),
         }
     }
+}
 
-    /// Get the requested hash and format.
-    fn hash_and_format(&self) -> HashAndFormat {
-        match self {
-            DownloadKind::Blob { hash } => HashAndFormat::raw(*hash),
-            DownloadKind::HashSeq { hash } => HashAndFormat::hash_seq(*hash),
+/// A download request.
+#[derive(Debug, Clone)]
+pub struct DownloadRequest {
+    kind: DownloadKind,
+    nodes: Vec<NodeAddr>,
+    tag: Option<SetTagOption>,
+    progress: Option<ProgressSubscriber>,
+}
+
+impl DownloadRequest {
+    /// Create a new download request.
+    ///
+    /// The blob will be auto-tagged after the download to prevent it from being garbage collected.
+    pub fn new(
+        resource: impl Into<DownloadKind>,
+        nodes: impl IntoIterator<Item = impl Into<NodeAddr>>,
+    ) -> Self {
+        Self {
+            kind: resource.into(),
+            nodes: nodes.into_iter().map(|n| n.into()).collect(),
+            tag: Some(SetTagOption::Auto),
+            progress: None,
         }
     }
 
-    /// Get the ranges this download is requesting.
-    // NOTE: necessary to extend downloads to support ranges of blobs ranges of collections.
-    #[allow(dead_code)]
-    fn ranges(&self) -> RangeSpecSeq {
-        match self {
-            DownloadKind::Blob { .. } => RangeSpecSeq::from_ranges([ChunkRanges::all()]),
-            DownloadKind::HashSeq { .. } => RangeSpecSeq::all(),
-        }
+    /// Create a new untagged download request.
+    ///
+    /// The blob will not be tagged, so only use this if the blob is already protected from garbage
+    /// collection through other means.
+    pub fn untagged(
+        resource: HashAndFormat,
+        nodes: impl IntoIterator<Item = impl Into<NodeAddr>>,
+    ) -> Self {
+        let mut r = Self::new(resource, nodes);
+        r.tag = None;
+        r
+    }
+
+    /// Set a tag to apply to the blob after download.
+    pub fn tag(mut self, tag: SetTagOption) -> Self {
+        self.tag = Some(tag);
+        self
+    }
+
+    /// Pass a progress sender to receive progress updates.
+    pub fn progress_sender(mut self, sender: ProgressSubscriber) -> Self {
+        self.progress = Some(sender);
+        self
     }
 }
 
-// For readability. In the future we might care about some data reporting on a successful download
-// or kind of failure in the error case.
-type DownloadResult = anyhow::Result<()>;
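A quick sketch of how the builder above composes. The node address is a placeholder, and the `Tag` conversion from `&str` is assumed from `iroh-bytes`'s existing `util` module:

```rust
use iroh_base::hash::{Hash, HashAndFormat};
use iroh_bytes::downloader::DownloadRequest;
use iroh_bytes::util::SetTagOption;
use iroh_net::NodeAddr;

// Hypothetical helper: a tagged request for a hash sequence from one provider.
fn build_request(hash: Hash, provider: NodeAddr) -> DownloadRequest {
    // `HashAndFormat` converts into `DownloadKind` via the derived `From` impl.
    DownloadRequest::new(HashAndFormat::hash_seq(hash), [provider])
        // Replaces the default `SetTagOption::Auto` that `new` sets.
        .tag(SetTagOption::Named("my-download".into()))
}
```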
+/// The kind of resource to download.
+#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy, derive_more::From, derive_more::Into)]
+pub struct DownloadKind(HashAndFormat);
+
+impl DownloadKind {
+    /// Get the hash of this download
+    pub const fn hash(&self) -> Hash {
+        self.0.hash
+    }
+
+    /// Get the format of this download
+    pub const fn format(&self) -> BlobFormat {
+        self.0.format
+    }
+
+    /// Get the [`HashAndFormat`] pair of this download
+    pub const fn hash_and_format(&self) -> HashAndFormat {
+        self.0
+    }
+}
+
+impl fmt::Display for DownloadKind {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}:{:?}", self.0.hash.fmt_short(), self.0.format)
+    }
+}
+
+/// The result of a download request, as returned to the application code.
+type ExternalDownloadResult = Result<Stats, DownloadError>;
+
+/// The result of a download request, as used in this module.
+type InternalDownloadResult = Result<Stats, FailureAction>;
+
+/// Error returned when a download could not be completed.
+#[derive(Debug, Clone, thiserror::Error)]
+pub enum DownloadError {
+    /// Failed to download from any provider
+    #[error("Failed to complete download")]
+    DownloadFailed,
+    /// The download was cancelled by us
+    #[error("Download cancelled by us")]
+    Cancelled,
+    /// No provider nodes found
+    #[error("No provider nodes found")]
+    NoProviders,
+    /// Failed to receive response from service.
+    #[error("Failed to receive response from download service")]
+    ActorClosed,
+}
 
 /// Handle to interact with a download request.
 #[derive(Debug)]
 pub struct DownloadHandle {
     /// Id used to identify the request in the [`Downloader`].
-    id: Id,
+    id: IntentId,
     /// Kind of download.
     kind: DownloadKind,
     /// Receiver to retrieve the return value of this download.
-    receiver: oneshot::Receiver<DownloadResult>,
+    receiver: oneshot::Receiver<ExternalDownloadResult>,
 }
 
 impl std::future::Future for DownloadHandle {
-    type Output = DownloadResult;
+    type Output = ExternalDownloadResult;
 
     fn poll(
         mut self: std::pin::Pin<&mut Self>,
@@ -207,7 +309,7 @@
         // from the middle
         match self.receiver.poll_unpin(cx) {
             Ready(Ok(result)) => Ready(result),
-            Ready(Err(recv_err)) => Ready(Err(anyhow::anyhow!("oneshot error: {recv_err}"))),
+            Ready(Err(_recv_err)) => Ready(Err(DownloadError::ActorClosed)),
             Pending => Pending,
         }
     }
@@ -223,8 +325,22 @@ pub struct Downloader {
 }
 
 impl Downloader {
-    /// Create a new Downloader.
+    /// Create a new Downloader with the default [`ConcurrencyLimits`] and [`RetryConfig`].
     pub fn new<S>(store: S, endpoint: MagicEndpoint, rt: LocalPoolHandle) -> Self
+    where
+        S: Store,
+    {
+        Self::with_config(store, endpoint, rt, Default::default(), Default::default())
+    }
+
+    /// Create a new Downloader with custom [`ConcurrencyLimits`] and [`RetryConfig`].
+    pub fn with_config<S>(
+        store: S,
+        endpoint: MagicEndpoint,
+        rt: LocalPoolHandle,
+        concurrency_limits: ConcurrencyLimits,
+        retry_config: RetryConfig,
+    ) -> Self
     where
         S: Store,
     {
@@ -233,10 +349,18 @@
         let dialer = iroh_net::dialer::Dialer::new(endpoint);
 
         let create_future = move || {
-            let concurrency_limits = ConcurrencyLimits::default();
-            let getter = get::IoGetter { store };
+            let getter = get::IoGetter {
+                store: store.clone(),
+            };
 
-            let service = Service::new(getter, dialer, concurrency_limits, msg_rx);
+            let service = Service::new(
+                store,
+                getter,
+                dialer,
+                concurrency_limits,
+                retry_config,
+                msg_rx,
+            );
 
             service.run().instrument(error_span!("downloader", %me))
         };
@@ -248,20 +372,19 @@
     }
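Taken together, a hedged sketch of the new public surface: constructing a downloader with custom limits and awaiting a queued request. The `store`, `endpoint`, and pool-handle wiring is assumed to exist; error handling is elided:

```rust
use std::time::Duration;

use iroh_base::hash::{Hash, HashAndFormat};
use iroh_bytes::downloader::{
    ConcurrencyLimits, DownloadError, DownloadRequest, Downloader, RetryConfig,
};
use iroh_bytes::store::Store;
use iroh_net::{MagicEndpoint, NodeAddr};
use tokio_util::task::LocalPoolHandle;

// Hypothetical wiring; any `Store` implementation works for `store`.
async fn fetch<S: Store>(
    store: S,
    endpoint: MagicEndpoint,
    rt: LocalPoolHandle,
    hash: Hash,
    provider: NodeAddr,
) -> Result<(), DownloadError> {
    let limits = ConcurrencyLimits {
        max_concurrent_requests: 64,
        max_concurrent_requests_per_node: 4,
        max_open_connections: 32,
        max_concurrent_dials_per_hash: 4,
    };
    let retry = RetryConfig {
        max_retries_per_node: 3,
        initial_retry_delay: Duration::from_millis(250),
    };
    let downloader = Downloader::with_config(store, endpoint, rt, limits, retry);

    // `queue` returns a handle that implements `Future` and resolves to
    // `Result<Stats, DownloadError>` once the service succeeds or gives up.
    let handle = downloader
        .queue(DownloadRequest::new(HashAndFormat::raw(hash), [provider]))
        .await;
    let _stats = handle.await?;
    Ok(())
}
```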
     /// Queue a download.
-    pub async fn queue(&mut self, kind: DownloadKind, nodes: Vec<NodeInfo>) -> DownloadHandle {
-        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
-
+    pub async fn queue(&self, request: DownloadRequest) -> DownloadHandle {
+        let kind = request.kind;
+        let intent_id = IntentId(self.next_id.fetch_add(1, Ordering::SeqCst));
         let (sender, receiver) = oneshot::channel();
         let handle = DownloadHandle {
-            id,
-            kind: kind.clone(),
+            id: intent_id,
+            kind,
             receiver,
         };
         let msg = Message::Queue {
-            kind,
-            id,
-            sender,
-            nodes,
+            on_finish: sender,
+            request,
+            intent_id,
         };
         // if this fails polling the handle will fail as well since the sender side of the oneshot
         // will be dropped
@@ -274,13 +397,13 @@
 
     /// Cancel a download.
     // NOTE: receiving the handle ensures an intent can't be cancelled twice
-    pub async fn cancel(&mut self, handle: DownloadHandle) {
+    pub async fn cancel(&self, handle: DownloadHandle) {
         let DownloadHandle {
             id,
             kind,
             receiver: _,
         } = handle;
-        let msg = Message::Cancel { id, kind };
+        let msg = Message::CancelIntent { id, kind };
         if let Err(send_err) = self.msg_tx.send(msg).await {
             let msg = send_err.0;
             debug!(?msg, "cancel not sent");
@@ -288,8 +411,11 @@
     }
 
     /// Declare that certain nodes can be used to download a hash.
-    pub async fn nodes_have(&mut self, hash: Hash, nodes: Vec<NodeInfo>) {
-        let msg = Message::PeersHave { hash, nodes };
+    ///
+    /// Note that this does not start a download, but only provides new nodes to already queued
+    /// downloads. Use [`Self::queue`] to queue a download.
+    pub async fn nodes_have(&mut self, hash: Hash, nodes: Vec<NodeId>) {
+        let msg = Message::NodesHave { hash, nodes };
         if let Err(send_err) = self.msg_tx.send(msg).await {
             let msg = send_err.0;
             debug!(?msg, "nodes have not been sent")
@@ -297,136 +423,95 @@
         }
     }
 }
 
-/// A node and its role with regard to a hash.
-#[derive(Debug, Clone, Copy)]
-pub struct NodeInfo {
-    node_id: NodeId,
-    role: Role,
-}
-
-impl NodeInfo {
-    /// Create a new [`NodeInfo`] from its parts.
-    pub fn new(node_id: NodeId, role: Role) -> Self {
-        Self { node_id, role }
-    }
-}
-
-impl From<(NodeId, Role)> for NodeInfo {
-    fn from((node_id, role): (NodeId, Role)) -> Self {
-        Self { node_id, role }
-    }
-}
-
-/// The role of a node with regard to a download intent.
-#[derive(Debug, Clone, Copy, Eq, PartialEq)]
-pub enum Role {
-    /// We have information that this node has the requested blob.
-    Provider,
-    /// We do not have information if this node has the requested blob.
-    Candidate,
-}
-
-impl PartialOrd for Role {
-    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
-        Some(self.cmp(other))
-    }
-}
-impl Ord for Role {
-    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
-        match (self, other) {
-            (Role::Provider, Role::Provider) => std::cmp::Ordering::Equal,
-            (Role::Candidate, Role::Candidate) => std::cmp::Ordering::Equal,
-            (Role::Provider, Role::Candidate) => std::cmp::Ordering::Greater,
-            (Role::Candidate, Role::Provider) => std::cmp::Ordering::Less,
-        }
-    }
-}
-
 /// Messages the service can receive.
 #[derive(derive_more::Debug)]
 enum Message {
     /// Queue a download intent.
     Queue {
-        kind: DownloadKind,
-        id: Id,
+        request: DownloadRequest,
         #[debug(skip)]
-        sender: oneshot::Sender<DownloadResult>,
-        nodes: Vec<NodeInfo>,
+        on_finish: oneshot::Sender<ExternalDownloadResult>,
+        intent_id: IntentId,
     },
+    /// Declare that nodes have a certain hash and can be used for downloading.
+    NodesHave { hash: Hash, nodes: Vec<NodeId> },
     /// Cancel an intent. The associated request will be cancelled when the last intent is
     /// cancelled.
-    Cancel { id: Id, kind: DownloadKind },
-    /// Declare that nodes have certains hash and can be used for downloading. This feeds the [`ProviderMap`].
-    PeersHave { hash: Hash, nodes: Vec<NodeInfo> },
+    CancelIntent { id: IntentId, kind: DownloadKind },
+}
+
+#[derive(derive_more::Debug)]
+struct IntentHandlers {
+    #[debug("oneshot::Sender<ExternalDownloadResult>")]
+    on_finish: oneshot::Sender<ExternalDownloadResult>,
+    on_progress: Option<ProgressSubscriber>,
 }
 
-/// Information about a request being processed.
+/// Information about a request.
+#[derive(Debug, Default)]
+struct RequestInfo {
+    /// Registered intents with progress senders and result callbacks.
+    intents: HashMap<IntentId, IntentHandlers>,
+    /// Tags requested for the blob to be created once the download finishes.
+    tags: TagSet,
+}
+
+/// Information about a request in progress.
 #[derive(derive_more::Debug)]
 struct ActiveRequestInfo {
-    /// Ids of intents associated with this request.
-    #[debug("{:?}", intents.keys().collect::<Vec<_>>())]
-    intents: HashMap<Id, oneshot::Sender<DownloadResult>>,
-    /// How many times can this request be retried.
-    remaining_retries: u8,
     /// Token used to cancel the future doing the request.
     #[debug(skip)]
     cancellation: CancellationToken,
     /// Peer doing this request attempt.
     node: NodeId,
+    /// Temporary tag to protect the partial blob from being garbage collected.
+    temp_tag: TempTag,
 }
 
-/// Information about a request that has not started.
-#[derive(derive_more::Debug)]
-struct PendingRequestInfo {
-    /// Ids of intents associated with this request.
-    #[debug("{:?}", intents.keys().collect::<Vec<_>>())]
-    intents: HashMap<Id, oneshot::Sender<DownloadResult>>,
-    /// How many times can this request be retried.
-    remaining_retries: u8,
-    /// Key to manage the delay associated with this scheduled request.
-    #[debug(skip)]
-    delay_key: delay_queue::Key,
-    /// If this attempt was scheduled with a known potential node, this is stored here to
-    /// prevent another query to the [`ProviderMap`].
-    next_node: Option<NodeId>,
+#[derive(Debug, Default)]
+struct RetryState {
+    /// How many times did we retry this node?
+    retry_count: u32,
+    /// Whether the node is currently queued for retry.
+    retry_is_queued: bool,
 }
 
 /// State of the connection to this node.
 #[derive(derive_more::Debug)]
 struct ConnectionInfo<Conn> {
     /// Connection to this node.
-    ///
-    /// If this node was deemed unusable by a request, this will be set to `None`. As a
-    /// consequence, when evaluating nodes for a download, this node will not be considered.
-    /// Since nodes are kept for a longer time that they are strictly necessary, this acts as a
-    /// temporary ban.
     #[debug(skip)]
-    conn: Option<Conn>,
+    conn: Conn,
     /// State of this node.
-    state: PeerState,
+    state: ConnectedState,
 }
 
 impl<Conn> ConnectionInfo<Conn> {
     /// Create a new idle node.
     fn new_idle(connection: Conn, drop_key: delay_queue::Key) -> Self {
         ConnectionInfo {
-            conn: Some(connection),
-            state: PeerState::Idle { drop_key },
+            conn: connection,
+            state: ConnectedState::Idle { drop_key },
         }
     }
 
     /// Count of active requests for the node.
     fn active_requests(&self) -> usize {
         match self.state {
-            PeerState::Busy { active_requests } => active_requests.get(),
-            PeerState::Idle { .. } => 0,
+            ConnectedState::Busy { active_requests } => active_requests.get(),
+            ConnectedState::Idle { .. } => 0,
         }
     }
+
+    /// Returns `true` if the node is currently idle.
+    fn is_idle(&self) -> bool {
+        matches!(self.state, ConnectedState::Idle { .. })
+    }
 }
 
 /// State of a connected node.
 #[derive(derive_more::Debug)]
-enum PeerState {
+enum ConnectedState {
     /// Peer is handling at least one request.
     Busy {
         #[debug("{}", active_requests.get())]
         active_requests: NonZeroUsize,
     },
@@ -439,11 +524,16 @@
     },
 }
 
-/// Type that is returned from a download request.
-type DownloadRes = (DownloadKind, Result<(), FailureAction>);
+#[derive(Debug)]
+enum NodeState<'a, Conn> {
+    Connected(&'a ConnectionInfo<Conn>),
+    Dialing,
+    WaitForRetry,
+    Disconnected,
+}
 
 #[derive(Debug)]
-struct Service<G: Getter, D: Dialer> {
+struct Service<G: Getter<Connection = D::Connection>, D: Dialer, DB: Store> {
     /// The getter performs individual requests.
     getter: G,
     /// Map to query for nodes that we believe have the data we are looking for.
     providers: ProviderMap,
     /// Dialer to manage connections.
     dialer: D,
     /// Limits to concurrent tasks handled by the service.
     concurrency_limits: ConcurrencyLimits,
+    /// Configuration for retry behavior.
+    retry_config: RetryConfig,
     /// Channel to receive messages from the service's handle.
     msg_rx: mpsc::Receiver<Message>,
-    /// Peers available to use and their relevant information.
-    nodes: HashMap<NodeId, ConnectionInfo<D::Connection>>,
-    /// Queue to manage dropping nodes.
+    /// Nodes to which we have an active or idle connection.
+    connected_nodes: HashMap<NodeId, ConnectionInfo<D::Connection>>,
+    /// We track a retry state for nodes which failed to dial or in a transfer.
+    retry_node_state: HashMap<NodeId, RetryState>,
+    /// Delay queue for retrying failed nodes.
+    retry_nodes_queue: delay_queue::DelayQueue<NodeId>,
+    /// Delay queue for dropping idle nodes.
     goodbye_nodes_queue: delay_queue::DelayQueue<NodeId>,
-    /// Requests performed for download intents. Two download requests can produce the same
-    /// request. This map allows deduplication of efforts.
-    current_requests: HashMap<DownloadKind, ActiveRequestInfo>,
-    /// Downloads underway.
-    in_progress_downloads: JoinSet<DownloadRes>,
-    /// Requests scheduled to be downloaded at a later time.
-    scheduled_requests: HashMap<DownloadKind, PendingRequestInfo>,
-    /// Queue of scheduled requests.
-    scheduled_request_queue: delay_queue::DelayQueue<DownloadKind>,
+    /// Queue of pending downloads.
+    queue: Queue,
+    /// Information about pending and active requests.
+    requests: HashMap<DownloadKind, RequestInfo>,
+    /// State of running downloads.
+    active_requests: HashMap<DownloadKind, ActiveRequestInfo>,
+    /// Tasks for currently running downloads.
+    in_progress_downloads: JoinSet<(DownloadKind, InternalDownloadResult)>,
+    /// Progress tracker
+    progress_tracker: ProgressTracker,
+    /// The [`Store`] where tags are saved after a download completes.
+    db: DB,
 }
-
-impl<G: Getter<Connection = D::Connection>, D: Dialer> Service<G, D> {
+impl<G: Getter<Connection = D::Connection>, D: Dialer, DB: Store> Service<G, D, DB> {
     fn new(
+        db: DB,
         getter: G,
         dialer: D,
         concurrency_limits: ConcurrencyLimits,
+        retry_config: RetryConfig,
         msg_rx: mpsc::Receiver<Message>,
     ) -> Self {
         Service {
             getter,
-            providers: ProviderMap::default(),
             dialer,
-            concurrency_limits,
             msg_rx,
-            nodes: HashMap::default(),
+            concurrency_limits,
+            retry_config,
+            connected_nodes: Default::default(),
+            retry_node_state: Default::default(),
+            providers: Default::default(),
+            requests: Default::default(),
+            retry_nodes_queue: delay_queue::DelayQueue::default(),
             goodbye_nodes_queue: delay_queue::DelayQueue::default(),
-            current_requests: HashMap::default(),
+            active_requests: Default::default(),
             in_progress_downloads: Default::default(),
-            scheduled_requests: HashMap::default(),
-            scheduled_request_queue: delay_queue::DelayQueue::default(),
+            progress_tracker: ProgressTracker::new(),
+            queue: Default::default(),
+            db,
         }
     }
 
     /// Main loop for the service.
     async fn run(mut self) {
         loop {
-            // check if we have capacity to dequeue another scheduled request
-            let at_capacity = self
-                .concurrency_limits
-                .at_requests_capacity(self.in_progress_downloads.len());
-
+            trace!("wait for tick");
             tokio::select! {
                 Some((node, conn_result)) = self.dialer.next() => {
-                    trace!("tick: connection ready");
+                    trace!(node=%node.fmt_short(), "tick: connection ready");
                     self.on_connection_ready(node, conn_result);
                 }
                 maybe_msg = self.msg_rx.recv() => {
                     trace!(msg=?maybe_msg, "tick: message received");
                     match maybe_msg {
-                        Some(msg) => self.handle_message(msg),
+                        Some(msg) => self.handle_message(msg).await,
                         None => return self.shutdown().await,
                     }
                 }
-                Some(res) = self.in_progress_downloads.join_next() => {
+                Some(res) = self.in_progress_downloads.join_next(), if !self.in_progress_downloads.is_empty() => {
                     match res {
                         Ok((kind, result)) => {
-                            trace!("tick: download completed");
-                            self.on_download_completed(kind, result);
+                            trace!(%kind, "tick: transfer completed");
+                            self.on_download_completed(kind, result).await;
                         }
-                        Err(e) => {
-                            warn!("download issue: {:?}", e);
+                        Err(err) => {
+                            warn!(?err, "transfer task panicked");
                         }
                     }
                 }
-                Some(expired) = self.scheduled_request_queue.next(), if !at_capacity => {
-                    trace!("tick: scheduled request ready");
-                    let kind = expired.into_inner();
-                    let request_info = self.scheduled_requests.remove(&kind).expect("is registered");
-                    self.on_scheduled_request_ready(kind, request_info);
+                Some(expired) = self.retry_nodes_queue.next() => {
+                    let node = expired.into_inner();
+                    trace!(node=%node.fmt_short(), "tick: retry node");
+                    self.on_retry_wait_elapsed(node);
                 }
                 Some(expired) = self.goodbye_nodes_queue.next() => {
                     let node = expired.into_inner();
-                    self.nodes.remove(&node);
-                    trace!(%node, "tick: goodbye node");
+                    trace!(node=%node.fmt_short(), "tick: goodbye node");
+                    self.disconnect_idle_node(node, "idle expired");
                 }
             }
+
+            self.process_head();
+
             #[cfg(any(test, debug_assertions))]
             self.check_invariants();
         }
     }
 
     /// Handle receiving a [`Message`].
-    fn handle_message(&mut self, msg: Message) {
+    ///
+    // This is called in the actor loop, and only async because subscribing to an existing transfer
+    // sends the initial state.
+    async fn handle_message(&mut self, msg: Message) {
         match msg {
             Message::Queue {
-                kind,
-                id,
-                sender,
-                nodes,
-            } => self.handle_queue_new_download(kind, id, sender, nodes),
-            Message::Cancel { id, kind } => self.handle_cancel_download(id, kind),
-            Message::PeersHave { hash, nodes } => self.handle_nodes_have(hash, nodes),
+                request,
+                on_finish,
+                intent_id,
+            } => {
+                self.handle_queue_new_download(request, intent_id, on_finish)
+                    .await
+            }
+            Message::CancelIntent { id, kind } => self.handle_cancel_download(id, kind).await,
+            Message::NodesHave { hash, nodes } => {
+                let updated = self
+                    .providers
+                    .add_nodes_if_hash_exists(hash, nodes.iter().cloned());
+                if updated {
+                    self.queue.unpark_hash(hash);
+                }
+            }
         }
     }
 
@@ -557,437 +672,468 @@
     ///
     /// If this intent maps to a request that already exists, it will be registered with it. If the
     /// request is new it will be scheduled.
-    fn handle_queue_new_download(
+    async fn handle_queue_new_download(
         &mut self,
-        kind: DownloadKind,
-        id: Id,
-        sender: oneshot::Sender<DownloadResult>,
-        nodes: Vec<NodeInfo>,
+        request: DownloadRequest,
+        intent_id: IntentId,
+        on_finish: oneshot::Sender<ExternalDownloadResult>,
     ) {
-        self.providers.add_nodes(*kind.hash(), &nodes);
-        if let Some(info) = self.current_requests.get_mut(&kind) {
-            // this intent maps to a download that already exists, simply register it
-            info.intents.insert(id, sender);
-            // increasing the retries by one accounts for multiple intents for the same request in
-            // a conservative way
-            info.remaining_retries += 1;
-            return trace!(?kind, ?info, "intent registered with active request");
+        let DownloadRequest {
+            kind,
+            nodes,
+            tag,
+            progress,
+        } = request;
+        debug!(%kind, nodes=?nodes.iter().map(|n| n.node_id.fmt_short()).collect::<Vec<_>>(), "queue intent");
+
+        // store the download intent
+        let intent_handlers = IntentHandlers {
+            on_finish,
+            on_progress: progress,
+        };
+
+        // early exit if no providers.
+        if nodes.is_empty() && self.providers.get_candidates(&kind.hash()).next().is_none() {
+            self.finalize_download(
+                kind,
+                [(intent_id, intent_handlers)].into(),
+                Err(DownloadError::NoProviders),
+            );
+            return;
         }
 
-        let needs_node = self
-            .scheduled_requests
-            .get(&kind)
-            .map(|info| info.next_node.is_none())
-            .unwrap_or(true);
-
-        let next_node = needs_node
-            .then(|| self.get_best_candidate(kind.hash()))
-            .flatten();
-
-        // if we are here this request is not active, check if it needs to be scheduled
-        match self.scheduled_requests.get_mut(&kind) {
-            Some(info) => {
-                info.intents.insert(id, sender);
-                // pre-emptively get a node if we don't already have one
-                match (info.next_node, next_node) {
-                    // We did not yet have next node, but have a node now.
-                    (None, Some(next_node)) => {
-                        info.next_node = Some(next_node);
-                    }
-                    (Some(_old_next_node), Some(_next_node)) => {
-                        unreachable!("invariant: info.next_node must be none because checked above with needs_node")
-                    }
-                    _ => {}
+        // add the nodes to the provider map
+        let updated = self
+            .providers
+            .add_hash_with_nodes(kind.hash(), nodes.iter().map(|n| n.node_id));
+
+        // queue the transfer (if not running) or attach to transfer progress (if already running)
+        if self.active_requests.contains_key(&kind) {
+            // the transfer is already running, so attach the progress sender
+            if let Some(on_progress) = &intent_handlers.on_progress {
+                // this is async because it sends the current state over the progress channel
+                if let Err(err) = self
+                    .progress_tracker
+                    .subscribe(kind, on_progress.clone())
+                    .await
+                {
+                    debug!(?err, %kind, "failed to subscribe progress sender to transfer");
                 }
-
-                // increasing the retries by one accounts for multiple intents for the same request in
-                // a conservative way
-                info.remaining_retries += 1;
-                trace!(?kind, ?info, "intent registered with scheduled request");
             }
-            None => {
-                let intents = HashMap::from([(id, sender)]);
-                self.schedule_request(kind, INITIAL_RETRY_COUNT, next_node, intents)
+        } else {
+            // the transfer is not running.
+            if updated && self.queue.is_parked(&kind) {
+                // the transfer is on hold for pending retries, and we added new nodes, so move back to queue.
+                self.queue.unpark(&kind);
+            } else if !self.queue.contains(&kind) {
+                // the transfer is not yet queued: add to queue.
+                self.queue.insert(kind);
             }
         }
+
+        // store the request info
+        let request_info = self.requests.entry(kind).or_default();
+        request_info.intents.insert(intent_id, intent_handlers);
+        if let Some(tag) = &tag {
+            request_info.tags.insert(tag.clone());
+        }
     }
 
-    /// Gets the best candidate for a download.
+    /// Cancels a download intent.
     ///
-    /// Peers are selected prioritizing those with an open connection and with capacity for another
-    /// request, followed by nodes we are currently dialing with capacity for another request.
-    /// Lastly, nodes not connected and not dialing are considered.
+    /// This removes the intent from the list of intents for the `kind`. If the removed intent was
+    /// the last one for the `kind`, this means that the download is no longer needed. In this
+    /// case, the `kind` will be removed from the list of pending downloads - and, if the download
+    /// was already started, the download task will be cancelled.
     ///
-    /// If the selected candidate is not connected and we have capacity for another connection, a
-    /// dial is queued.
+    /// The method is async because it will send a final abort event on the progress sender.
+    async fn handle_cancel_download(&mut self, intent_id: IntentId, kind: DownloadKind) {
+        let Entry::Occupied(mut occupied_entry) = self.requests.entry(kind) else {
+            warn!(%kind, %intent_id, "cancel download called for unknown download");
+            return;
+        };
-    fn get_best_candidate(&mut self, hash: &Hash) -> Option<NodeId> {
-        /// Model the state of nodes found in the candidates
-        #[derive(PartialEq, Eq, Clone, Copy)]
-        enum ConnState {
-            Dialing,
-            Connected(usize),
-            NotConnected,
-        }
-
-        impl Ord for ConnState {
-            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
-                // define the order of preference between candidates as follows:
-                // - prefer connected nodes to dialing ones
-                // - prefer dialing nodes to not connected ones
-                // - prefer nodes with less active requests when connected
-                use std::cmp::Ordering::*;
-                match (self, other) {
-                    (ConnState::Dialing, ConnState::Dialing) => Equal,
-                    (ConnState::Dialing, ConnState::Connected(_)) => Less,
-                    (ConnState::Dialing, ConnState::NotConnected) => Greater,
-                    (ConnState::NotConnected, ConnState::Dialing) => Less,
-                    (ConnState::NotConnected, ConnState::Connected(_)) => Less,
-                    (ConnState::NotConnected, ConnState::NotConnected) => Equal,
-                    (ConnState::Connected(_), ConnState::Dialing) => Greater,
-                    (ConnState::Connected(_), ConnState::NotConnected) => Greater,
-                    (ConnState::Connected(a), ConnState::Connected(b)) => match a.cmp(b) {
-                        Less => Greater, // less preferable if greater number of requests
-                        Equal => Equal,  // no preference
-                        Greater => Less, // more preferable if less number of requests
-                    },
-                }
-            }
-        }
-
-        impl PartialOrd for ConnState {
-            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
-                Some(self.cmp(other))
-            }
-        }
+        let request_info = occupied_entry.get_mut();
+        if let Some(handlers) = request_info.intents.remove(&intent_id) {
+            handlers.on_finish.send(Err(DownloadError::Cancelled)).ok();
+
+            if let Some(sender) = handlers.on_progress {
+                self.progress_tracker.unsubscribe(&kind, &sender);
+                sender
+                    .send(DownloadProgress::Abort(
+                        anyhow::Error::from(DownloadError::Cancelled).into(),
+                    ))
+                    .await
+                    .ok();
+            }
+        }
 
-        // first collect suitable candidates
-        let mut candidates = self
-            .providers
-            .get_candidates(hash)
-            .filter_map(|(node_id, role)| {
-                let node = NodeInfo::new(*node_id, *role);
-                if let Some(info) = self.nodes.get(node_id) {
-                    info.conn.as_ref()?;
-                    let req_count = info.active_requests();
-                    // filter out nodes at capacity
-                    let has_capacity = !self.concurrency_limits.node_at_request_capacity(req_count);
-                    has_capacity.then_some((node, ConnState::Connected(req_count)))
-                } else if self.dialer.is_pending(node_id) {
-                    Some((node, ConnState::Dialing))
-                } else {
-                    Some((node, ConnState::NotConnected))
-                }
-            })
-            .collect::<Vec<_>>();
-
-        // Sort candidates by:
-        // * Role (Providers > Candidates)
-        // * ConnState (Connected > Dialing > NotConnected)
-        candidates.sort_unstable_by_key(|(NodeInfo { role, .. }, state)| (*role, *state));
-
-        // this is our best node, check if we need to dial it
-        let (node, state) = candidates.pop()?;
-
-        if let ConnState::NotConnected = state {
-            if !self.at_connections_capacity() {
-                // node is not connected, not dialing and concurrency limits allow another connection
-                debug!(node = %node.node_id, "dialing node");
-                self.dialer.queue_dial(node.node_id);
-                Some(node.node_id)
-            } else {
-                trace!(node = %node.node_id, "required node not dialed to maintain concurrency limits");
-                None
+        if request_info.intents.is_empty() {
+            occupied_entry.remove();
+            if let Entry::Occupied(occupied_entry) = self.active_requests.entry(kind) {
+                occupied_entry.remove().cancellation.cancel();
+            } else {
+                self.queue.remove(&kind);
             }
-        } else {
-            Some(node.node_id)
         }
-    }
 
-    /// Cancels the download request.
-    ///
-    /// This removes the registered download intent and, depending on its state, it will either
-    /// remove it from the scheduled requests, or cancel the future.
-    fn handle_cancel_download(&mut self, id: Id, kind: DownloadKind) {
-        let hash = *kind.hash();
-        let mut download_removed = false;
-        if let Entry::Occupied(mut occupied_entry) = self.current_requests.entry(kind.clone()) {
-            // remove the intent from the associated request
-            let intents = &mut occupied_entry.get_mut().intents;
-            intents.remove(&id);
-            // if this was the last intent associated with the request cancel it
-            if intents.is_empty() {
-                download_removed = true;
-                occupied_entry.remove().cancellation.cancel();
-            }
-        } else if let Entry::Occupied(mut occupied_entry) = self.scheduled_requests.entry(kind) {
-            // remove the intent from the associated request
-            let intents = &mut occupied_entry.get_mut().intents;
-            intents.remove(&id);
-            // if this was the last intent associated with the request remove it from the schedule
-            // queue
-            if intents.is_empty() {
-                let delay_key = occupied_entry.remove().delay_key;
-                self.scheduled_request_queue.remove(&delay_key);
-                download_removed = true;
-            }
-        }
-
-        if download_removed && !self.is_needed(hash) {
-            self.providers.remove(hash)
-        }
-    }
-
-    /// Handle a [`Message::PeersHave`].
-    fn handle_nodes_have(&mut self, hash: Hash, nodes: Vec<NodeInfo>) {
-        // check if this still needed
-        if self.is_needed(hash) {
-            self.providers.add_nodes(hash, &nodes);
-        }
-    }
-
-    /// Checks if this hash is needed.
-    fn is_needed(&self, hash: Hash) -> bool {
-        let as_blob = DownloadKind::Blob { hash };
-        let as_hash_seq = DownloadKind::HashSeq { hash };
-        self.current_requests.contains_key(&as_blob)
-            || self.scheduled_requests.contains_key(&as_blob)
-            || self.current_requests.contains_key(&as_hash_seq)
-            || self.scheduled_requests.contains_key(&as_hash_seq)
-    }
-
-    /// Check if this hash is currently being downloaded.
-    fn is_current_request(&self, hash: Hash) -> bool {
-        let as_blob = DownloadKind::Blob { hash };
-        let as_hash_seq = DownloadKind::HashSeq { hash };
-        self.current_requests.contains_key(&as_blob)
-            || self.current_requests.contains_key(&as_hash_seq)
-    }
-
-    /// Remove a hash from the scheduled queue.
-    fn unschedule(&mut self, hash: Hash) -> Option<(DownloadKind, PendingRequestInfo)> {
-        let as_blob = DownloadKind::Blob { hash };
-        let as_hash_seq = DownloadKind::HashSeq { hash };
-        let info = match self.scheduled_requests.remove(&as_blob) {
-            Some(req) => Some(req),
-            None => self.scheduled_requests.remove(&as_hash_seq),
-        };
-        if let Some(info) = info {
-            let kind = self.scheduled_request_queue.remove(&info.delay_key);
-            let kind = kind.into_inner();
-            Some((kind, info))
-        } else {
-            None
+        self.remove_hash_if_not_queued(&kind.hash());
        }
     }
 
     /// Handle receiving a new connection.
     fn on_connection_ready(&mut self, node: NodeId, result: anyhow::Result<D::Connection>) {
+        debug_assert!(
+            !self.connected_nodes.contains_key(&node),
+            "newly connected node is not yet connected"
+        );
         match result {
             Ok(connection) => {
-                trace!(%node, "connected to node");
+                trace!(node=%node.fmt_short(), "connected to node");
                 let drop_key = self.goodbye_nodes_queue.insert(node, IDLE_PEER_TIMEOUT);
-                self.nodes
+                self.connected_nodes
                     .insert(node, ConnectionInfo::new_idle(connection, drop_key));
-                self.on_node_ready(node);
             }
             Err(err) => {
-                debug!(%node, %err, "connection to node failed")
+                debug!(%node, %err, "connection to node failed");
+                self.disconnect_and_retry(node);
             }
         }
     }
 
-    /// Called after the connection to a node is established, and after finishing a download.
-    ///
-    /// Starts the next provider hash download, if there is one.
-    fn on_node_ready(&mut self, node: NodeId) {
-        // Get the next provider hash for this node.
-        let Some(hash) = self.providers.get_next_provider_hash_for_node(&node) else {
-            return;
-        };
-
-        if self.is_current_request(hash) {
-            return;
-        }
-
-        let Some(conn) = self.get_node_connection_for_download(&node) else {
-            return;
-        };
-
-        let Some((kind, info)) = self.unschedule(hash) else {
-            debug_assert!(
-                false,
-                "invalid state: expected {hash:?} to be scheduled, but it wasn't"
-            );
-            return;
-        };
-
-        let PendingRequestInfo {
-            intents,
-            remaining_retries,
-            ..
-        } = info;
-
-        self.start_download(kind, node, conn, remaining_retries, intents);
-    }
-
-    fn on_download_completed(&mut self, kind: DownloadKind, result: Result<(), FailureAction>) {
+    async fn on_download_completed(&mut self, kind: DownloadKind, result: InternalDownloadResult) {
         // first remove the request
-        let info = self
-            .current_requests
+        let active_request_info = self
+            .active_requests
             .remove(&kind)
             .expect("request was active");
 
-        // update the active requests for this node
-        let ActiveRequestInfo {
-            intents,
-            node,
-            mut remaining_retries,
-            ..
-        } = info;
+        // get general request info
+        let request_info = self.requests.remove(&kind).expect("request was active");
+
+        let ActiveRequestInfo { node, temp_tag, .. } = active_request_info;
 
+        // get node info
         let node_info = self
-            .nodes
+            .connected_nodes
             .get_mut(&node)
             .expect("node exists in the mapping");
 
-        node_info.state = match &node_info.state {
-            PeerState::Busy { active_requests } => {
-                match NonZeroUsize::new(active_requests.get() - 1) {
-                    Some(active_requests) => PeerState::Busy { active_requests },
-                    None => {
-                        // last request of the node was this one
-                        let drop_key = self.goodbye_nodes_queue.insert(node, IDLE_PEER_TIMEOUT);
-                        PeerState::Idle { drop_key }
-                    }
-                }
+        // update node busy/idle state
+        node_info.state = match NonZeroUsize::new(node_info.active_requests() - 1) {
+            None => {
+                // last request of the node was this one, switch to idle
+                let drop_key = self.goodbye_nodes_queue.insert(node, IDLE_PEER_TIMEOUT);
+                ConnectedState::Idle { drop_key }
             }
-            PeerState::Idle { .. } => unreachable!("node was busy"),
+            Some(active_requests) => ConnectedState::Busy { active_requests },
         };
 
-        let hash = *kind.hash();
-
-        let node_ready = match result {
+        match &result {
             Ok(_) => {
-                debug!(%node, ?kind, "download completed");
-                for sender in intents.into_values() {
-                    let _ = sender.send(Ok(()));
-                }
-                true
+                debug!(%kind, node=%node.fmt_short(), "download successful");
+                // clear retry state if operation was successful
+                self.retry_node_state.remove(&node);
+            }
+            Err(FailureAction::AllIntentsDropped) => {
+                debug!(%kind, node=%node.fmt_short(), "download cancelled");
             }
             Err(FailureAction::AbortRequest(reason)) => {
-                debug!(%node, ?kind, %reason, "aborting request");
-                for sender in intents.into_values() {
-                    let _ = sender.send(Err(anyhow::anyhow!("request aborted")));
-                }
-                true
+                debug!(%kind, node=%node.fmt_short(), %reason, "download failed: abort request");
+                // do not try to download the hash from this node again
+                self.providers.remove_hash_from_node(&kind.hash(), &node);
             }
             Err(FailureAction::DropPeer(reason)) => {
-                debug!(%node, ?kind, %reason, "node will be dropped");
-                if let Some(_connection) = node_info.conn.take() {
-                    // TODO(@divma): this will fail open streams, do we want this?
-                    // connection.close(..)
+                debug!(%kind, node=%node.fmt_short(), %reason, "download failed: drop node");
+                if node_info.is_idle() {
+                    // remove the node
+                    self.remove_node(node, "explicit drop");
+                } else {
+                    // do not try to download the hash from this node again
+                    self.providers.remove_hash_from_node(&kind.hash(), &node);
                 }
-                false
             }
             Err(FailureAction::RetryLater(reason)) => {
-                // check if the download can be retried
-                if remaining_retries > 0 {
-                    debug!(%node, ?kind, %reason, "download attempt failed");
-                    remaining_retries -= 1;
-                    let next_node = self.get_best_candidate(kind.hash());
-                    self.schedule_request(kind, remaining_retries, next_node, intents);
-                } else {
-                    warn!(%node, ?kind, %reason, "download failed");
-                    for sender in intents.into_values() {
-                        let _ = sender.send(Err(anyhow::anyhow!("download ran out of attempts")));
-                    }
+                debug!(%kind, node=%node.fmt_short(), %reason, "download failed: retry later");
+                if node_info.is_idle() {
+                    self.disconnect_and_retry(node);
                 }
-                false
             }
         };
 
-        if !self.is_needed(hash) {
-            self.providers.remove(hash)
+        // we finalize the download if either the download was successful,
+        // or if it should never proceed because all intents were dropped,
+        // or if we don't have any candidates to proceed with anymore.
+        let finalize = match &result {
+            Ok(_) | Err(FailureAction::AllIntentsDropped) => true,
+            _ => !self.providers.has_candidates(&kind.hash()),
+        };
+
+        if finalize {
+            let result = result.map_err(|_| DownloadError::DownloadFailed);
+            if result.is_ok() {
+                request_info.tags.apply(&self.db, kind.0).await.ok();
+            }
+            drop(temp_tag);
+            self.finalize_download(kind, request_info.intents, result);
+        } else {
+            // reinsert the download at the front of the queue to try from the next node
+            self.requests.insert(kind, request_info);
+            self.queue.insert_front(kind);
+        }
+    }
+
+    /// Finalize a download.
+    ///
+    /// This triggers the intent return channels, and removes the download from the progress tracker
+    /// and provider map.
+    fn finalize_download(
+        &mut self,
+        kind: DownloadKind,
+        intents: HashMap<IntentId, IntentHandlers>,
+        result: ExternalDownloadResult,
+    ) {
+        self.progress_tracker.remove(&kind);
+        self.remove_hash_if_not_queued(&kind.hash());
+        let result = result.map_err(|_| DownloadError::DownloadFailed);
+        for (_id, handlers) in intents.into_iter() {
+            handlers.on_finish.send(result.clone()).ok();
+        }
+    }
+
+    fn on_retry_wait_elapsed(&mut self, node: NodeId) {
+        // check if the node is still needed
+        let Some(hashes) = self.providers.node_hash.get(&node) else {
+            self.retry_node_state.remove(&node);
+            return;
+        };
+        let Some(state) = self.retry_node_state.get_mut(&node) else {
+            warn!(node=%node.fmt_short(), "missing retry state for node ready for retry");
+            return;
+        };
+        state.retry_is_queued = false;
+        for hash in hashes {
+            self.queue.unpark_hash(*hash);
+        }
+    }
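The `Queue` type used above (`front`, `insert_front`, `park_front`, `unpark_hash`, `contains_hash`) is defined below this excerpt. Judging from those call sites, it is an ordered set with a "parked" side pool, which also explains the new `hashlink` dependency; the following is a rough sketch with inferred names, not the confirmed definition:

```rust
use std::collections::HashSet;

use hashlink::LinkedHashSet;

/// Inferred sketch: FIFO order for active candidates, plus a "parked" pool
/// for downloads whose providers are all waiting on a retry timer.
#[derive(Debug, Default)]
struct QueueSketch {
    main: LinkedHashSet<DownloadKind>,
    parked: HashSet<DownloadKind>,
}

impl QueueSketch {
    /// Next download to process, if any.
    fn front(&self) -> Option<&DownloadKind> {
        self.main.front()
    }

    /// Move the head of the queue aside until a provider becomes retryable.
    fn park_front(&mut self) {
        if let Some(kind) = self.main.pop_front() {
            self.parked.insert(kind);
        }
    }

    /// Return a parked download to the back of the main queue.
    fn unpark(&mut self, kind: &DownloadKind) {
        if self.parked.remove(kind) {
            self.main.insert(*kind);
        }
    }
}
```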
+    /// Start the next downloads, or dial nodes, if limits permit and the queue is non-empty.
+    ///
+    /// This is called after all actions. If there is nothing to do, it will return cheaply.
+    /// Otherwise, we will check the next hash in the queue, and:
+    /// * start the transfer if we are connected to a provider and limits are ok
+    /// * or, connect to a provider, if there is one we are not dialing yet and limits are ok
+    /// * or, disconnect an idle node if it would allow us to connect to a provider,
+    /// * or, if all providers are waiting for retry, park the download
+    /// * or, if our limits are reached, do nothing for now
+    ///
+    /// The download requests will only be popped from the queue once we either start the transfer
+    /// from a connected node [`NextStep::StartTransfer`], or if we abort the download on
+    /// [`NextStep::OutOfProviders`]. In all other cases, the request is kept at the top of the
+    /// queue, so the next call to [`Self::process_head`] will evaluate the situation again - and
+    /// so forth, until either [`NextStep::StartTransfer`] or [`NextStep::OutOfProviders`] is
+    /// reached.
+    fn process_head(&mut self) {
+        // start as many queued downloads as allowed by the request limits.
+        loop {
+            let Some(kind) = self.queue.front().cloned() else {
+                break;
+            };
+
+            let next_step = self.next_step(&kind);
+            trace!(%kind, ?next_step, "process_head");
+
+            match next_step {
+                NextStep::Wait => break,
+                NextStep::StartTransfer(node) => {
+                    let _ = self.queue.pop_front();
+                    debug!(%kind, node=%node.fmt_short(), "start transfer");
+                    self.start_download(kind, node);
+                }
+                NextStep::Dial(node) => {
+                    debug!(%kind, node=%node.fmt_short(), "dial node");
+                    self.dialer.queue_dial(node);
+                }
+                NextStep::DialQueuedDisconnect(node, key) => {
+                    let idle_node = self.goodbye_nodes_queue.remove(&key).into_inner();
+                    self.disconnect_idle_node(idle_node, "drop idle for new dial");
+                    debug!(%kind, node=%node.fmt_short(), idle_node=%idle_node.fmt_short(), "dial node, disconnect idle node");
+                    self.dialer.queue_dial(node);
+                }
+                NextStep::Park => {
+                    debug!(%kind, "park download: all providers waiting for retry");
+                    self.queue.park_front();
+                }
+                NextStep::OutOfProviders => {
+                    debug!(%kind, "abort download: out of providers");
+                    let _ = self.queue.pop_front();
+                    let info = self.requests.remove(&kind).expect("queued downloads exist");
+                    self.finalize_download(kind, info.intents, Err(DownloadError::NoProviders));
+                }
+            }
+        }
+    }
+
+    /// Drop the connection to a node and insert it into the retry queue.
+    fn disconnect_and_retry(&mut self, node: NodeId) {
+        self.disconnect_idle_node(node, "queue retry");
+        let retry_state = self.retry_node_state.entry(node).or_default();
+        retry_state.retry_count += 1;
+        if retry_state.retry_count <= self.retry_config.max_retries_per_node {
+            // node can be retried
+            debug!(node=%node.fmt_short(), retry_count=retry_state.retry_count, "queue retry");
+            let timeout = self.retry_config.initial_retry_delay * retry_state.retry_count;
+            self.retry_nodes_queue.insert(node, timeout);
+            retry_state.retry_is_queued = true;
+        } else {
+            // node is dead
+            self.remove_node(node, "retries exceeded");
+        }
+    }
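The retry math in `disconnect_and_retry` above is a simple linear backoff. Spelled out under the default `RetryConfig` (500 ms initial delay, at most 6 retries per node):

```rust
use std::time::Duration;

// delay(n) = initial_retry_delay * n, for the n-th retry of a node.
fn retry_delay(initial: Duration, retry_count: u32) -> Duration {
    initial * retry_count
}

fn main() {
    let initial = Duration::from_millis(500);
    // Retries 1..=6 wait 0.5s, 1.0s, 1.5s, 2.0s, 2.5s and 3.0s;
    // after the sixth failure the node is removed for good.
    for n in 1..=6 {
        println!("retry {n}: {:?}", retry_delay(initial, n));
    }
}
```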
+    /// Calculate the next step needed to proceed the download for `kind`.
+    ///
+    /// This is called once `kind` has reached the head of the queue, see [`Self::process_head`].
+    /// It can be called repeatedly, and has no side effects on its own; it only calculates what
+    /// *should* be done next.
+    ///
+    /// See [`NextStep`] for details on the potential next steps returned from this method.
+    fn next_step(&self, kind: &DownloadKind) -> NextStep {
+        // If the total requests capacity is reached, we have to wait until an active request
+        // completes.
+        if self
+            .concurrency_limits
+            .at_requests_capacity(self.active_requests.len())
+        {
+            return NextStep::Wait;
+        };
+
+        let mut candidates = self.providers.get_candidates(&kind.hash()).peekable();
+        // If we have no provider candidates for this download, there's nothing else we can do.
+        if candidates.peek().is_none() {
+            return NextStep::OutOfProviders;
+        }
+
+        // Track if there is a provider node to which we are connected and which is not at its
+        // request capacity. If there are several, take the one with the fewest running transfers.
+        let mut best_connected: Option<(NodeId, usize)> = None;
+        // Track if there is a disconnected provider node to which we can potentially connect.
+        let mut next_to_dial = None;
+        // Track the number of provider nodes that are currently being dialed.
+        let mut currently_dialing = 0;
+        // Track if we have at least one provider node which is currently at its request capacity.
+        // If this is the case, we will never return [`NextStep::OutOfProviders`] but [`NextStep::Wait`]
+        // instead, because we can still try that node once it has finished its work.
+        let mut has_exhausted_provider = false;
+        // Track if we have at least one provider node that is currently in the retry queue.
+        let mut has_retrying_provider = false;
+
+        for node in candidates {
+            match self.node_state(node) {
+                NodeState::Connected(info) => {
+                    let active_requests = info.active_requests();
+                    if self
+                        .concurrency_limits
+                        .node_at_request_capacity(active_requests)
+                    {
+                        has_exhausted_provider = true;
+                    } else {
+                        best_connected = Some(match best_connected.take() {
+                            Some(old) if old.1 <= active_requests => old,
+                            _ => (*node, active_requests),
+                        });
+                    }
+                }
+                NodeState::Dialing => {
+                    currently_dialing += 1;
+                }
+                NodeState::WaitForRetry => {
+                    has_retrying_provider = true;
+                }
+                NodeState::Disconnected => {
+                    if next_to_dial.is_none() {
+                        next_to_dial = Some(node);
+                    }
+                }
+            }
+        }
+
+        let has_dialing = currently_dialing > 0;
+
+        // If we have a connected provider node with free slots, use it!
+        if let Some((node, _active_requests)) = best_connected {
+            NextStep::StartTransfer(node)
+        }
+        // If we have a node which could be dialed: Check capacity and act accordingly.
+        else if let Some(node) = next_to_dial {
+            // We check if the dial capacity for this hash is exceeded: We only start new dials for
+            // the hash if we are below the limit.
+            //
+            // If other requests trigger dials for providers of this hash, the limit may be
+            // exceeded, but then we just don't start further dials and wait until one completes.
+            let at_dial_capacity = has_dialing
+                && self
+                    .concurrency_limits
+                    .at_dials_per_hash_capacity(currently_dialing);
+            // Check if we reached the global connection limit.
+            let at_connections_capacity = self.at_connections_capacity();
+
+            // All slots are free: We can dial our candidate.
+            if !at_connections_capacity && !at_dial_capacity {
+                NextStep::Dial(*node)
+            }
+            // The hash has free dial capacity, but the global connection capacity is reached.
+            // But if we have idle nodes, we will disconnect the longest idling node, and then dial our
+            // candidate.
+            else if at_connections_capacity
+                && !at_dial_capacity
+                && !self.goodbye_nodes_queue.is_empty()
+            {
+                let key = self.goodbye_nodes_queue.peek().expect("just checked");
+                NextStep::DialQueuedDisconnect(*node, key)
+            }
+            // No dial capacity, and no idling nodes: We have to wait until capacity is freed up.
+            else {
+                NextStep::Wait
+            }
+        }
+        // If we have pending dials to candidates, or connected candidates which are busy
+        // with other work: Wait for one of these to become available.
+        else if has_exhausted_provider || has_dialing {
+            NextStep::Wait
+        }
+        // All providers are in the retry queue: Park this request until they can be tried again.
+        else if has_retrying_provider {
+            NextStep::Park
+        }
+        // We have no candidates left: Nothing more to do.
+        else {
+            NextStep::OutOfProviders
+        }
+    }
 
     /// Start downloading from the given node.
+    ///
+    /// Panics if hash is not in self.requests or node is not in self.connected_nodes.
-    fn start_download(
-        &mut self,
-        kind: DownloadKind,
-        node: NodeId,
-        conn: D::Connection,
-        remaining_retries: u8,
-        intents: HashMap<Id, oneshot::Sender<DownloadResult>>,
-    ) {
-        debug!(%node, ?kind, "starting download");
+    fn start_download(&mut self, kind: DownloadKind, node: NodeId) {
+        let node_info = self.connected_nodes.get_mut(&node).expect("node exists");
+        let request_info = self.requests.get(&kind).expect("hash exists");
+
+        // create a progress sender and subscribe all intents to the progress sender
+        let subscribers = request_info
+            .intents
+            .values()
+            .flat_map(|state| state.on_progress.clone());
+        let progress_sender = self.progress_tracker.track(kind, subscribers);
+
+        // create the active request state
         let cancellation = CancellationToken::new();
-        let info = ActiveRequestInfo {
-            intents,
-            remaining_retries,
-            cancellation,
+        let temp_tag = self.db.temp_tag(kind.0);
+        let state = ActiveRequestInfo {
+            cancellation: cancellation.clone(),
             node,
+            temp_tag,
         };
-        let cancellation = info.cancellation.clone();
-        self.current_requests.insert(kind.clone(), info);
-
-        let get = self.getter.get(kind.clone(), conn);
+        let conn = node_info.conn.clone();
+        let get_fut = self.getter.get(kind, conn, progress_sender);
         let fut = async move {
            // NOTE: it's an open question if we should do timeouts at this point. Considerations from @Frando:
            // > at this stage we do not know the size of the download, so the timeout would have
@@ -996,66 +1142,64 @@
            // > time, while faster nodes could be readily available.
            // As a conclusion, timeouts should be added only after downloads are known to be bounded
            let res = tokio::select! {
{ - _ = cancellation.cancelled() => Err(FailureAction::AbortRequest(anyhow::anyhow!("cancelled"))), - res = get => res + _ = cancellation.cancelled() => Err(FailureAction::AllIntentsDropped), + res = get_fut => res }; + trace!("transfer finished"); - (kind, res.map(|_stats| ())) + (kind, res) + } + .instrument(error_span!("transfer", %kind, node=%node.fmt_short())); + node_info.state = match &node_info.state { + ConnectedState::Busy { active_requests } => ConnectedState::Busy { + active_requests: active_requests.saturating_add(1), + }, + ConnectedState::Idle { drop_key } => { + self.goodbye_nodes_queue.remove(drop_key); + ConnectedState::Busy { + active_requests: NonZeroUsize::new(1).expect("clearly non zero"), + } + } }; - + self.active_requests.insert(kind, state); self.in_progress_downloads.spawn_local(fut); } - /// Schedule a request for later processing. - fn schedule_request( - &mut self, - kind: DownloadKind, - remaining_retries: u8, - next_node: Option, - intents: HashMap>, - ) { - // this is simply INITIAL_REQUEST_DELAY * attempt_num where attempt_num (as an ordinal - // number) is maxed at INITIAL_RETRY_COUNT - let delay = INITIAL_REQUEST_DELAY - * (INITIAL_RETRY_COUNT.saturating_sub(remaining_retries) as u32 + 1); - - let delay_key = self.scheduled_request_queue.insert(kind.clone(), delay); - - let info = PendingRequestInfo { - intents, - remaining_retries, - delay_key, - next_node, - }; - debug!(?kind, ?info, "request scheduled"); - self.scheduled_requests.insert(kind, info); - } - - /// Gets the [`Dialer::Connection`] for a node if it's connected and has capacity for another - /// request. In this case, the count of active requests for the node is incremented. - fn get_node_connection_for_download(&mut self, node: &NodeId) -> Option { - let info = self.nodes.get_mut(node)?; - let connection = info.conn.as_ref()?; - // check if the node can be sent another request - match &mut info.state { - PeerState::Busy { active_requests } => { - if !self - .concurrency_limits - .node_at_request_capacity(active_requests.get()) - { - *active_requests = active_requests.saturating_add(1); - Some(connection.clone()) - } else { - None + fn disconnect_idle_node(&mut self, node: NodeId, reason: &'static str) -> bool { + if let Some(info) = self.connected_nodes.remove(&node) { + match info.state { + ConnectedState::Idle { drop_key } => { + self.goodbye_nodes_queue.try_remove(&drop_key); + true + } + ConnectedState::Busy { .. 
} => { + warn!("expected removed node to be idle, but is busy (removal reason: {reason:?})"); + self.connected_nodes.insert(node, info); + false } } - PeerState::Idle { drop_key } => { - // node is no longer idle - self.goodbye_nodes_queue.remove(drop_key); - info.state = PeerState::Busy { - active_requests: NonZeroUsize::new(1).expect("clearly non zero"), - }; - Some(connection.clone()) + } else { + true + } + } + + fn remove_node(&mut self, node: NodeId, reason: &'static str) { + debug!(node = %node.fmt_short(), %reason, "remove node"); + if self.disconnect_idle_node(node, reason) { + self.providers.remove_node(&node); + self.retry_node_state.remove(&node); + } + } + + fn node_state<'a>(&'a self, node: &NodeId) -> NodeState<'a, D::Connection> { + if let Some(info) = self.connected_nodes.get(node) { + NodeState::Connected(info) + } else if self.dialer.is_pending(node) { + NodeState::Dialing + } else { + match self.retry_node_state.get(node) { + Some(state) if state.retry_is_queued => NodeState::WaitForRetry, + _ => NodeState::Disconnected, } } } @@ -1068,15 +1212,19 @@ impl, D: Dialer> Service { /// Get the total number of connected and dialing nodes. fn connections_count(&self) -> usize { - let connected_nodes = self - .nodes - .values() - .filter(|info| info.conn.is_some()) - .count(); + let connected_nodes = self.connected_nodes.values().count(); let dialing_nodes = self.dialer.pending_count(); connected_nodes + dialing_nodes } + /// Remove a `hash` from the [`ProviderMap`], but only if [`Self::queue`] does not contain the + /// hash at all, even with the other [`BlobFormat`]. + fn remove_hash_if_not_queued(&mut self, hash: &Hash) { + if !self.queue.contains_hash(*hash) { + self.providers.remove_hash(hash); + } + } + #[allow(clippy::unused_async)] async fn shutdown(self) { debug!("shutting down"); @@ -1084,88 +1232,224 @@ impl, D: Dialer> Service { } } -/// Map of potential providers for a hash. -#[derive(Default, Debug)] -pub struct ProviderMap { - /// Candidates to download a hash. - candidates: HashMap>, - /// Ordered list of provider hashes per node. +/// The next step needed to continue a download. +/// +/// See [`Service::next_step`] for details. +#[derive(Debug)] +enum NextStep { + /// Provider connection is ready, initiate the transfer. + StartTransfer(NodeId), + /// Start to dial `NodeId`. /// - /// I.e. blobs we assume the node can provide. - provider_hashes_by_node: HashMap>, -} - -struct ProviderIter<'a> { - inner: Option>, + /// This means: We have no non-exhausted connection to a provider node, but a free connection slot + /// and a provider node we are not yet connected to. + Dial(NodeId), + /// Start to dial `NodeId`, but first disconnect the idle node behind [`delay_queue::Key`] in + /// [`Service::goodbye_nodes_queue`] to free up a connection slot. + DialQueuedDisconnect(NodeId, delay_queue::Key), + /// Resource limits are exhausted, do nothing for now and wait until a slot frees up. + Wait, + /// All providers are currently in a retry timeout. Park the download aside, and move + /// to the next download in the queue. + Park, + /// We have tried all available providers. There is nothing else to do. + OutOfProviders, } -impl<'a> Iterator for ProviderIter<'a> { - type Item = (&'a NodeId, &'a Role); - - fn next(&mut self) -> Option { - self.inner.as_mut().and_then(|iter| iter.next()) - } +/// Map of potential providers for a hash. 
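+///
+/// Kept as a bidirectional mapping: `hash_node` records which nodes may provide a hash, and
+/// `node_hash` records which hashes a node was registered for, so that removals from either
+/// side prune the other. A minimal usage sketch (ignored doc-test, since the type is private;
+/// `hash: Hash` and `node: NodeId` are assumed to be in scope):
+///
+/// ```ignore
+/// let mut providers = ProviderMap::default();
+/// providers.add_hash_with_nodes(hash, [node].into_iter());
+/// assert!(providers.has_candidates(&hash));
+/// providers.remove_node(&node);
+/// assert!(!providers.has_candidates(&hash));
+/// ```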
+#[derive(Default, Debug)] +struct ProviderMap { + hash_node: HashMap>, + node_hash: HashMap>, } impl ProviderMap { /// Get candidates to download this hash. - fn get_candidates(&self, hash: &Hash) -> impl Iterator { - let inner = self.candidates.get(hash).map(|nodes| nodes.iter()); - ProviderIter { inner } + pub fn get_candidates(&self, hash: &Hash) -> impl Iterator { + self.hash_node + .get(hash) + .map(|nodes| nodes.iter()) + .into_iter() + .flatten() + } + + /// Whether we have any candidates to download this hash. + pub fn has_candidates(&self, hash: &Hash) -> bool { + self.hash_node + .get(hash) + .map(|nodes| !nodes.is_empty()) + .unwrap_or(false) } /// Register nodes for a hash. Should only be done for hashes we care to download. - fn add_nodes(&mut self, hash: Hash, nodes: &[NodeInfo]) { - let entry = self.candidates.entry(hash).or_default(); + /// + /// Returns `true` if new providers were added. + fn add_hash_with_nodes(&mut self, hash: Hash, nodes: impl Iterator) -> bool { + let mut updated = false; + let hash_entry = self.hash_node.entry(hash).or_default(); for node in nodes { - entry - .entry(node.node_id) - .and_modify(|role| *role = (*role).max(node.role)) - .or_insert(node.role); - if let Role::Provider = node.role { - self.provider_hashes_by_node - .entry(node.node_id) - .or_default() - .push_back(hash); - } + updated |= hash_entry.insert(node); + let node_entry = self.node_hash.entry(node).or_default(); + node_entry.insert(hash); } + updated } - /// Get the next provider hash for a node. + /// Register nodes for a hash, but only if the hash is already in our queue. /// - /// I.e. get the next hash that was added with [`Role::Provider`] for this node. - fn get_next_provider_hash_for_node(&mut self, node: &NodeId) -> Option { - let hash = self - .provider_hashes_by_node - .get(node) - .and_then(|hashes| hashes.front()) - .copied(); - if let Some(hash) = hash { - self.move_hash_to_back(node, hash); + /// Returns `true` if a new node was added. + fn add_nodes_if_hash_exists( + &mut self, + hash: Hash, + nodes: impl Iterator, + ) -> bool { + let mut updated = false; + if let Some(hash_entry) = self.hash_node.get_mut(&hash) { + for node in nodes { + updated |= hash_entry.insert(node); + let node_entry = self.node_hash.entry(node).or_default(); + node_entry.insert(hash); + } } - hash + updated } /// Signal the registry that this hash is no longer of interest. - fn remove(&mut self, hash: Hash) { - if let Some(nodes) = self.candidates.remove(&hash) { - for node in nodes.keys() { - if let Some(hashes) = self.provider_hashes_by_node.get_mut(node) { - hashes.retain(|h| *h != hash); + fn remove_hash(&mut self, hash: &Hash) { + if let Some(nodes) = self.hash_node.remove(hash) { + for node in nodes { + if let Some(hashes) = self.node_hash.get_mut(&node) { + hashes.remove(hash); + if hashes.is_empty() { + self.node_hash.remove(&node); + } } } } } - /// Move a hash to the back of the provider queue for a node. 
- fn move_hash_to_back(&mut self, node: &NodeId, hash: Hash) { - let hashes = self.provider_hashes_by_node.get_mut(node); - if let Some(hashes) = hashes { - debug_assert_eq!(hashes.front(), Some(&hash)); - if !hashes.is_empty() { - hashes.rotate_left(1); + fn remove_node(&mut self, node: &NodeId) { + if let Some(hashes) = self.node_hash.remove(node) { + for hash in hashes { + if let Some(nodes) = self.hash_node.get_mut(&hash) { + nodes.remove(node); + if nodes.is_empty() { + self.hash_node.remove(&hash); + } + } + } + } + } + + fn remove_hash_from_node(&mut self, hash: &Hash, node: &NodeId) { + if let Some(nodes) = self.hash_node.get_mut(hash) { + nodes.remove(node); + if nodes.is_empty() { + self.remove_hash(hash); } } + if let Some(hashes) = self.node_hash.get_mut(node) { + hashes.remove(hash); + if hashes.is_empty() { + self.remove_node(node); + } + } + } +} + +/// The queue of requested downloads. +/// +/// This manages two datastructures: +/// * The main queue, a FIFO queue where each item can only appear once. +/// New downloads are pushed to the back of the queue, and the next download to process is popped +/// from the front. +/// * The parked set, a hash set. Items can be moved from the main queue into the parked set. +/// Parked items will not be popped unless they are moved back into the main queue. +#[derive(Debug, Default)] +struct Queue { + main: LinkedHashSet, + parked: HashSet, +} + +impl Queue { + /// Peek at the front element of the main queue. + pub fn front(&self) -> Option<&DownloadKind> { + self.main.front() + } + + #[cfg(any(test, debug_assertions))] + pub fn iter_parked(&self) -> impl Iterator { + self.parked.iter() + } + + #[cfg(any(test, debug_assertions))] + pub fn iter(&self) -> impl Iterator { + self.main.iter().chain(self.parked.iter()) + } + + /// Returns `true` if either the main queue or the parked set contain a download. + pub fn contains(&self, kind: &DownloadKind) -> bool { + self.main.contains(kind) || self.parked.contains(kind) + } + + /// Returns `true` if either the main queue or the parked set contain a download for a hash. + pub fn contains_hash(&self, hash: Hash) -> bool { + let as_raw = HashAndFormat::raw(hash).into(); + let as_hash_seq = HashAndFormat::hash_seq(hash).into(); + self.contains(&as_raw) || self.contains(&as_hash_seq) + } + + /// Returns `true` if a download is in the parked set. + pub fn is_parked(&self, kind: &DownloadKind) -> bool { + self.parked.contains(kind) + } + + /// Insert an element at the back of the main queue. + pub fn insert(&mut self, kind: DownloadKind) { + if !self.main.contains(&kind) { + self.main.insert(kind); + } + } + + /// Insert an element at the front of the main queue. + pub fn insert_front(&mut self, kind: DownloadKind) { + if !self.main.contains(&kind) { + self.main.insert(kind); + } + self.main.to_front(&kind); + } + + /// Dequeue the first download of the main queue. + pub fn pop_front(&mut self) -> Option { + self.main.pop_front() + } + + /// Move the front item of the main queue into the parked set. + pub fn park_front(&mut self) { + if let Some(item) = self.pop_front() { + self.parked.insert(item); + } + } + + /// Move a download from the parked set to the front of the main queue. + pub fn unpark(&mut self, kind: &DownloadKind) { + if self.parked.remove(kind) { + self.main.insert(*kind); + self.main.to_front(kind); + } + } + + /// Move any download for a hash from the parked set to the main queue. 
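+ ///
+ /// Both the raw and hash-seq variants of the hash are unparked, since either format may be
+ /// queued for the same hash.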
+ pub fn unpark_hash(&mut self, hash: Hash) { + let as_raw = HashAndFormat::raw(hash).into(); + let as_hash_seq = HashAndFormat::hash_seq(hash).into(); + self.unpark(&as_raw); + self.unpark(&as_hash_seq); + } + + /// Remove a download from both the main queue and the parked set. + pub fn remove(&mut self, kind: &DownloadKind) -> bool { + self.main.remove(kind) || self.parked.remove(kind) } } diff --git a/iroh-bytes/src/downloader/get.rs b/iroh-bytes/src/downloader/get.rs index 334064bdee..2fb39c2900 100644 --- a/iroh-bytes/src/downloader/get.rs +++ b/iroh-bytes/src/downloader/get.rs @@ -3,7 +3,6 @@ use crate::{ get::{db::get_to_db, error::GetError}, store::Store, - util::progress::IgnoreProgressSender, }; use futures::FutureExt; #[cfg(feature = "metrics")] @@ -12,7 +11,7 @@ use iroh_metrics::{inc, inc_by}; #[cfg(feature = "metrics")] use crate::metrics::Metrics; -use super::{DownloadKind, FailureAction, GetFut, Getter}; +use super::{progress::BroadcastProgressSender, DownloadKind, FailureAction, GetFut, Getter}; impl From for FailureAction { fn from(e: GetError) -> Self { @@ -36,9 +35,13 @@ pub(crate) struct IoGetter { impl Getter for IoGetter { type Connection = quinn::Connection; - fn get(&mut self, kind: DownloadKind, conn: Self::Connection) -> GetFut { + fn get( + &mut self, + kind: DownloadKind, + conn: Self::Connection, + progress_sender: BroadcastProgressSender, + ) -> GetFut { let store = self.store.clone(); - let progress_sender = IgnoreProgressSender::default(); let fut = async move { let get_conn = || async move { Ok(conn) }; let res = get_to_db(&store, get_conn, &kind.hash_and_format(), progress_sender).await; diff --git a/iroh-bytes/src/downloader/invariants.rs b/iroh-bytes/src/downloader/invariants.rs index b24a6a125e..f26c333707 100644 --- a/iroh-bytes/src/downloader/invariants.rs +++ b/iroh-bytes/src/downloader/invariants.rs @@ -5,12 +5,12 @@ use super::*; /// invariants for the service. 
-impl, D: Dialer> Service { +impl, D: Dialer, S: Store> Service { /// Checks the various invariants the service must maintain #[track_caller] pub(in crate::downloader) fn check_invariants(&self) { self.check_active_request_count(); - self.check_scheduled_requests_consistency(); + self.check_queued_requests_consistency(); self.check_idle_peer_consistency(); self.check_concurrency_limits(); self.check_provider_map_prunning(); @@ -21,8 +21,9 @@ impl, D: Dialer> Service { fn check_concurrency_limits(&self) { let ConcurrencyLimits { max_concurrent_requests, - max_concurrent_requests_per_node: max_concurrent_requests_per_peer, + max_concurrent_requests_per_node, max_open_connections, + max_concurrent_dials_per_hash, } = &self.concurrency_limits; // check the total number of active requests to ensure it stays within the limit @@ -32,16 +33,39 @@ impl, D: Dialer> Service { ); // check that the open and dialing peers don't exceed the connection capacity + tracing::trace!( + "limits: conns: {}/{} | reqs: {}/{}", + self.connections_count(), + max_open_connections, + self.in_progress_downloads.len(), + max_concurrent_requests + ); assert!( self.connections_count() <= *max_open_connections, "max_open_connections exceeded" ); // check the active requests per peer don't exceed the limit - for (peer, info) in self.nodes.iter() { + for (node, info) in self.connected_nodes.iter() { assert!( - info.active_requests() <= *max_concurrent_requests_per_peer, - "max_concurrent_requests_per_peer exceeded for {peer}" + info.active_requests() <= *max_concurrent_requests_per_node, + "max_concurrent_requests_per_node exceeded for {node}" + ) + } + + // check that we do not dial more nodes than allowed for the next pending hashes + if let Some(kind) = self.queue.front() { + let hash = kind.hash(); + let nodes = self.providers.get_candidates(&hash); + let mut dialing = 0; + for node in nodes { + if self.dialer.is_pending(node) { + dialing += 1; + } + } + assert!( + dialing <= *max_concurrent_dials_per_hash, + "max_concurrent_dials_per_hash exceeded for {hash}" ) } } @@ -54,17 +78,18 @@ impl, D: Dialer> Service { // number of requests assert_eq!( self.in_progress_downloads.len(), - self.current_requests.len(), - "current_requests and in_progress_downloads are out of sync" + self.active_requests.len(), + "active_requests and in_progress_downloads are out of sync" ); // check that the count of requests per peer matches the number of requests that have that // peer as active - let mut real_count: HashMap = HashMap::with_capacity(self.nodes.len()); - for req_info in self.current_requests.values() { + let mut real_count: HashMap = + HashMap::with_capacity(self.connected_nodes.len()); + for req_info in self.active_requests.values() { // nothing like some classic word count *real_count.entry(req_info.node).or_default() += 1; } - for (peer, info) in self.nodes.iter() { + for (peer, info) in self.connected_nodes.iter() { assert_eq!( info.active_requests(), real_count.get(peer).copied().unwrap_or_default(), @@ -73,21 +98,44 @@ impl, D: Dialer> Service { } } - /// Checks that the scheduled requests match the queue that handles their delays. + /// Checks that the queued requests all appear in the provider map and request map. 
#[track_caller]
- fn check_scheduled_requests_consistency(&self) {
- assert_eq!(
- self.scheduled_requests.len(),
- self.scheduled_request_queue.len(),
- "scheduled_request_queue and scheduled_requests are out of sync"
- );
+ fn check_queued_requests_consistency(&self) {
+ // check that all hashes in the queue have candidates
+ for entry in self.queue.iter() {
+ assert!(
+ self.providers
+ .get_candidates(&entry.hash())
+ .next()
+ .is_some(),
+ "all queued requests have providers"
+ );
+ assert!(
+ self.requests.get(entry).is_some(),
+ "all queued requests have request info"
+ );
+ }
+
+ // check that all parked downloads are correctly parked
+ for entry in self.queue.iter_parked() {
+ assert!(
+ matches!(self.next_step(entry), NextStep::Park),
+ "all parked downloads evaluate to the correct next step"
+ );
+ assert!(
+ self.providers
+ .get_candidates(&entry.hash())
+ .all(|node| matches!(self.node_state(node), NodeState::WaitForRetry)),
+ "all parked downloads have only retrying nodes"
+ );
+ }
}

/// Check that peers queued to be disconnected are consistent with peers considered idle.
#[track_caller]
fn check_idle_peer_consistency(&self) {
let idle_peers = self
- .nodes
+ .connected_nodes
.values()
.filter(|info| info.active_requests() == 0)
.count();
@@ -101,11 +149,15 @@ impl, D: Dialer> Service {
/// Check that every hash in the provider map is needed.
#[track_caller]
fn check_provider_map_prunning(&self) {
- for hash in self.providers.candidates.keys() {
+ for hash in self.providers.hash_node.keys() {
+ let as_raw = DownloadKind(HashAndFormat::raw(*hash));
+ let as_hash_seq = DownloadKind(HashAndFormat::hash_seq(*hash));
assert!(
- self.is_needed(*hash),
- "provider map contains {hash:?} which should have been prunned"
- );
+ self.queue.contains_hash(*hash)
+ || self.active_requests.contains_key(&as_raw)
+ || self.active_requests.contains_key(&as_hash_seq),
+ "all hashes in the provider map are in the queue or active"
+ )
}
}
}
diff --git a/iroh-bytes/src/downloader/progress.rs b/iroh-bytes/src/downloader/progress.rs
new file mode 100644
index 0000000000..47ec74154c
--- /dev/null
+++ b/iroh-bytes/src/downloader/progress.rs
@@ -0,0 +1,195 @@
+use std::{
+ collections::HashMap,
+ sync::{
+ atomic::{AtomicU64, Ordering},
+ Arc,
+ },
+};
+
+use anyhow::anyhow;
+use parking_lot::Mutex;
+
+use crate::{
+ get::{db::DownloadProgress, progress::TransferState},
+ util::progress::{FlumeProgressSender, IdGenerator, ProgressSendError, ProgressSender},
+};
+
+use super::DownloadKind;
+
+/// The channel that can be used to subscribe to progress updates.
+pub type ProgressSubscriber = FlumeProgressSender;
+
+/// Track the progress of downloads.
+///
+/// This struct allows creating [`ProgressSender`] structs to be passed to
+/// [`crate::get::db::get_to_db`]. Each progress sender can be subscribed to by any number of
+/// [`ProgressSubscriber`] channel senders, which will receive each progress update (if they have
+/// capacity). Additionally, the [`ProgressTracker`] maintains a [`TransferState`] for each
+/// transfer, applying each progress update to this state. When subscribing to an already
+/// running transfer, the subscriber will receive a [`DownloadProgress::InitialState`] message
+/// containing the state at the time of the subscription, and then receive all further progress
+/// events directly.
+#[derive(Debug, Default)]
+pub struct ProgressTracker {
+ /// Map of shared state for each tracked download.
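+ ///
+ /// Entries are created by [`Self::track`] and shared with all [`BroadcastProgressSender`]s
+ /// handed out for the download; they are removed again via [`Self::remove`].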
+ running: HashMap, + /// Shared [`IdGenerator`] for all progress senders created by the tracker. + id_gen: Arc, +} + +impl ProgressTracker { + pub fn new() -> Self { + Self::default() + } + + /// Track a new download with a list of initial subscribers. + /// + /// Note that this should only be called for *new* downloads. If a download for the `kind` is + /// already tracked in this [`ProgressTracker`], calling `track` will replace all existing + /// state and subscribers (equal to calling [`Self::remove`] first). + pub fn track( + &mut self, + kind: DownloadKind, + subscribers: impl IntoIterator, + ) -> BroadcastProgressSender { + let inner = Inner { + subscribers: subscribers.into_iter().collect(), + state: TransferState::new(kind.hash()), + }; + let shared = Arc::new(Mutex::new(inner)); + self.running.insert(kind, Arc::clone(&shared)); + let id_gen = Arc::clone(&self.id_gen); + BroadcastProgressSender { shared, id_gen } + } + + /// Subscribe to a tracked download. + /// + /// Will return an error if `kind` is not yet tracked. + pub async fn subscribe( + &mut self, + kind: DownloadKind, + sender: ProgressSubscriber, + ) -> anyhow::Result<()> { + let initial_msg = self + .running + .get_mut(&kind) + .ok_or_else(|| anyhow!("state for download {kind:?} not found"))? + .lock() + .subscribe(sender.clone()); + sender.send(initial_msg).await?; + Ok(()) + } + + /// Unsubscribe `sender` from `kind`. + pub fn unsubscribe(&mut self, kind: &DownloadKind, sender: &ProgressSubscriber) { + if let Some(shared) = self.running.get_mut(kind) { + shared.lock().unsubscribe(sender) + } + } + + /// Remove all state for a download. + pub fn remove(&mut self, kind: &DownloadKind) { + self.running.remove(kind); + } +} + +type Shared = Arc>; + +#[derive(Debug)] +struct Inner { + subscribers: Vec, + state: TransferState, +} + +impl Inner { + fn subscribe(&mut self, subscriber: ProgressSubscriber) -> DownloadProgress { + let msg = DownloadProgress::InitialState(self.state.clone()); + self.subscribers.push(subscriber); + msg + } + + fn unsubscribe(&mut self, sender: &ProgressSubscriber) { + self.subscribers.retain(|s| !s.same_channel(sender)); + } + + fn on_progress(&mut self, progress: DownloadProgress) { + self.state.on_progress(progress); + } +} + +#[derive(Debug, Clone)] +pub struct BroadcastProgressSender { + shared: Shared, + id_gen: Arc, +} + +impl IdGenerator for BroadcastProgressSender { + fn new_id(&self) -> u64 { + self.id_gen.fetch_add(1, Ordering::SeqCst) + } +} + +impl ProgressSender for BroadcastProgressSender { + type Msg = DownloadProgress; + + async fn send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> { + // making sure that the lock is not held across an await point. 
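+ // The state lives behind a synchronous `parking_lot` mutex, so awaiting the subscriber
+ // sends while holding it could block the executor thread. Instead, the send futures are
+ // collected under the lock and awaited only after it has been released.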
+ let futs = { + let mut inner = self.shared.lock(); + inner.on_progress(msg.clone()); + let futs = inner + .subscribers + .iter_mut() + .map(|sender| { + let sender = sender.clone(); + let msg = msg.clone(); + async move { + match sender.send(msg).await { + Ok(()) => None, + Err(ProgressSendError::ReceiverDropped) => Some(sender), + } + } + }) + .collect::>(); + drop(inner); + futs + }; + + let failed_senders = futures::future::join_all(futs).await; + // remove senders where the receiver is dropped + if failed_senders.iter().any(|s| s.is_some()) { + let mut inner = self.shared.lock(); + for sender in failed_senders.into_iter().flatten() { + inner.unsubscribe(&sender); + } + drop(inner); + } + Ok(()) + } + + fn try_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> { + let mut inner = self.shared.lock(); + inner.on_progress(msg.clone()); + // remove senders where the receiver is dropped + inner + .subscribers + .retain_mut(|sender| match sender.try_send(msg.clone()) { + Err(ProgressSendError::ReceiverDropped) => false, + Ok(()) => true, + }); + Ok(()) + } + + fn blocking_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> { + let mut inner = self.shared.lock(); + inner.on_progress(msg.clone()); + // remove senders where the receiver is dropped + inner + .subscribers + .retain_mut(|sender| match sender.blocking_send(msg.clone()) { + Err(ProgressSendError::ReceiverDropped) => false, + Ok(()) => true, + }); + Ok(()) + } +} diff --git a/iroh-bytes/src/downloader/test.rs b/iroh-bytes/src/downloader/test.rs index 6f7029ad2f..b18b185ebc 100644 --- a/iroh-bytes/src/downloader/test.rs +++ b/iroh-bytes/src/downloader/test.rs @@ -1,8 +1,18 @@ #![cfg(test)] -use std::time::Duration; +use anyhow::anyhow; +use futures::FutureExt; +use std::{ + sync::atomic::AtomicUsize, + time::{Duration, Instant}, +}; use iroh_net::key::SecretKey; +use crate::{ + get::{db::BlobId, progress::TransferState}, + util::progress::{FlumeProgressSender, IdGenerator, ProgressSender}, +}; + use super::*; mod dialer; @@ -13,14 +23,30 @@ impl Downloader { dialer: dialer::TestingDialer, getter: getter::TestingGetter, concurrency_limits: ConcurrencyLimits, + ) -> Self { + Self::spawn_for_test_with_retry_config( + dialer, + getter, + concurrency_limits, + Default::default(), + ) + } + + fn spawn_for_test_with_retry_config( + dialer: dialer::TestingDialer, + getter: getter::TestingGetter, + concurrency_limits: ConcurrencyLimits, + retry_config: RetryConfig, ) -> Self { let (msg_tx, msg_rx) = mpsc::channel(super::SERVICE_CHANNEL_CAPACITY); + let db = crate::store::mem::Store::default(); LocalPoolHandle::new(1).spawn_pinned(move || async move { // we want to see the logs of the service let _guard = iroh_test::logging::setup(); - let service = Service::new(getter, dialer, concurrency_limits, msg_rx); + let service = + Service::new(db, getter, dialer, concurrency_limits, retry_config, msg_rx); service.run().await }); @@ -34,21 +60,18 @@ impl Downloader { /// Tests that receiving a download request and performing it doesn't explode. 
#[tokio::test] async fn smoke_test() { + let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); let concurrency_limits = ConcurrencyLimits::default(); - let mut downloader = - Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send a request and make sure the peer is requested the corresponding download let peer = SecretKey::generate().public(); - let kind = DownloadKind::Blob { - hash: Hash::new([0u8; 32]), - }; - let handle = downloader - .queue(kind.clone(), vec![(peer, Role::Candidate).into()]) - .await; + let kind: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into(); + let req = DownloadRequest::new(kind, vec![peer]); + let handle = downloader.queue(req).await; // wait for the download result to be reported handle.await.expect("should report success"); // verify that the peer was dialed @@ -60,24 +83,21 @@ async fn smoke_test() { /// Tests that multiple intents produce a single request. #[tokio::test] async fn deduplication() { + let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); // make request take some time to ensure the intents are received before completion getter.set_request_duration(Duration::from_secs(1)); let concurrency_limits = ConcurrencyLimits::default(); - let mut downloader = - Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); let peer = SecretKey::generate().public(); - let kind = DownloadKind::Blob { - hash: Hash::new([0u8; 32]), - }; + let kind: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into(); let mut handles = Vec::with_capacity(10); for _ in 0..10 { - let h = downloader - .queue(kind.clone(), vec![(peer, Role::Candidate).into()]) - .await; + let req = DownloadRequest::new(kind, vec![peer]); + let h = downloader.queue(req).await; handles.push(h); } assert!( @@ -94,37 +114,27 @@ async fn deduplication() { /// Tests that the request is cancelled only when all intents are cancelled. 
#[tokio::test]
async fn cancellation() {
+ let _guard = iroh_test::logging::setup();
let dialer = dialer::TestingDialer::default();
let getter = getter::TestingGetter::default();
// make request take some time to ensure cancellations are received on time
getter.set_request_duration(Duration::from_millis(500));
let concurrency_limits = ConcurrencyLimits::default();

- let mut downloader =
- Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
+ let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);

let peer = SecretKey::generate().public();
- let kind_1 = DownloadKind::Blob {
- hash: Hash::new([0u8; 32]),
- };
- let handle_a = downloader
- .queue(kind_1.clone(), vec![(peer, Role::Candidate).into()])
- .await;
- let handle_b = downloader
- .queue(kind_1.clone(), vec![(peer, Role::Candidate).into()])
- .await;
+ let kind_1: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into();
+ let req = DownloadRequest::new(kind_1, vec![peer]);
+ let handle_a = downloader.queue(req.clone()).await;
+ let handle_b = downloader.queue(req).await;
downloader.cancel(handle_a).await;

// create a request with two intents and cancel them both
- let kind_2 = DownloadKind::Blob {
- hash: Hash::new([1u8; 32]),
- };
- let handle_c = downloader
- .queue(kind_2.clone(), vec![(peer, Role::Candidate).into()])
- .await;
- let handle_d = downloader
- .queue(kind_2.clone(), vec![(peer, Role::Candidate).into()])
- .await;
+ let kind_2 = HashAndFormat::raw(Hash::new([1u8; 32]));
+ let req = DownloadRequest::new(kind_2, vec![peer]);
+ let handle_c = downloader.queue(req.clone()).await;
+ let handle_d = downloader.queue(req).await;
downloader.cancel(handle_c).await;
downloader.cancel(handle_d).await;
@@ -138,7 +148,8 @@
/// maximum number of concurrent requests is not exceeded.
/// NOTE: This is internally tested by [`Service::check_invariants`].
#[tokio::test]
-async fn max_concurrent_requests() {
+async fn max_concurrent_requests_total() {
+ let _guard = iroh_test::logging::setup();
let dialer = dialer::TestingDialer::default();
let getter = getter::TestingGetter::default();
// make request take some time to ensure concurrency limits are hit
@@ -149,20 +160,16 @@
..Default::default()
};

- let mut downloader =
- Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
+ let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
// send the downloads
let peer = SecretKey::generate().public();
let mut handles = Vec::with_capacity(5);
let mut expected_history = Vec::with_capacity(5);
for i in 0..5 {
- let kind = DownloadKind::Blob {
- hash: Hash::new([i; 32]),
- };
- let h = downloader
- .queue(kind.clone(), vec![(peer, Role::Candidate).into()])
- .await;
+ let kind: DownloadKind = HashAndFormat::raw(Hash::new([i; 32])).into();
+ let req = DownloadRequest::new(kind, vec![peer]);
+ let h = downloader.queue(req).await;
expected_history.push((kind, peer));
handles.push(h);
}
@@ -184,6 +191,7 @@
/// NOTE: This is internally tested by [`Service::check_invariants`].
#[tokio::test]
async fn max_concurrent_requests_per_peer() {
+ let _guard = iroh_test::logging::setup();
let dialer = dialer::TestingDialer::default();
let getter = getter::TestingGetter::default();
// make request take some time to ensure concurrency limits are hit
@@ -195,60 +203,302 @@
..Default::default()
};

- let mut downloader =
- Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
+ let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
// send the downloads
let peer = SecretKey::generate().public();
let mut handles = Vec::with_capacity(5);
for i in 0..5 {
- let kind = DownloadKind::Blob {
- hash: Hash::new([i; 32]),
- };
- let h = downloader
- .queue(kind.clone(), vec![(peer, Role::Candidate).into()])
- .await;
+ let kind = HashAndFormat::raw(Hash::new([i; 32]));
+ let req = DownloadRequest::new(kind, vec![peer]);
+ let h = downloader.queue(req).await;
handles.push(h);
}

futures::future::join_all(handles).await;
}

-/// Tests that providers are preferred over candidates.
+/// Tests concurrent progress reporting for multiple intents.
+///
+/// This first registers two intents for a download, and then proceeds until the `Found` event is
+/// emitted, and verifies that both intents received the event.
+/// It then registers a third intent mid-download, and makes sure it receives a correct initial
+/// state. The download then finishes, and we make sure that all events are emitted properly, and
+/// the progress state of the handles converges.
#[tokio::test]
-async fn peer_role_provider() {
+async fn concurrent_progress() {
+ let _guard = iroh_test::logging::setup();
let dialer = dialer::TestingDialer::default();
- dialer.set_dial_duration(Duration::from_millis(100));
let getter = getter::TestingGetter::default();
- let concurrency_limits = ConcurrencyLimits::default();
- let mut downloader =
- Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
+ let (start_tx, start_rx) = oneshot::channel();
+ let start_rx = start_rx.shared();
+
+ let (done_tx, done_rx) = oneshot::channel();
+ let done_rx = done_rx.shared();
+
+ getter.set_handler(Arc::new(move |hash, _peer, progress, _duration| {
+ let start_rx = start_rx.clone();
+ let done_rx = done_rx.clone();
+ async move {
+ let hash = hash.hash();
+ start_rx.await.unwrap();
+ let id = progress.new_id();
+ progress
+ .send(DownloadProgress::Found {
+ id,
+ child: BlobId::Root,
+ hash,
+ size: 100,
+ })
+ .await
+ .unwrap();
+ done_rx.await.unwrap();
+ progress.send(DownloadProgress::Done { id }).await.unwrap();
+ Ok(Stats::default())
+ }
+ .boxed()
+ }));
+ let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default());
+
+ let peer = SecretKey::generate().public();
+ let hash = Hash::new([0u8; 32]);
+ let kind_1 = HashAndFormat::raw(hash);
+
+ let (prog_a_tx, prog_a_rx) = flume::bounded(64);
+ let prog_a_tx = FlumeProgressSender::new(prog_a_tx);
+ let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_a_tx);
+ let handle_a = downloader.queue(req).await;
+
+ let (prog_b_tx, prog_b_rx) = flume::bounded(64);
+ let prog_b_tx = FlumeProgressSender::new(prog_b_tx);
+ let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_b_tx);
+ let handle_b = downloader.queue(req).await;
+
+ start_tx.send(()).unwrap();
+
+ let mut state_a = TransferState::new(hash);
+ let mut state_b = TransferState::new(hash);
+ let mut state_c = TransferState::new(hash);
+
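+ // Both subscribed intents should observe the `Found` event, and their accumulated
+ // transfer states should be identical.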
+ let prog1_a = prog_a_rx.recv_async().await.unwrap();
+ let prog1_b = prog_b_rx.recv_async().await.unwrap();
+ assert!(matches!(prog1_a, DownloadProgress::Found { hash: found, size: 100, ..} if found == hash));
+ assert!(matches!(prog1_b, DownloadProgress::Found { hash: found, size: 100, ..} if found == hash));
+
+ state_a.on_progress(prog1_a);
+ state_b.on_progress(prog1_b);
+ assert_eq!(state_a, state_b);

- let peer_candidate1 = SecretKey::from_bytes(&[0u8; 32]).public();
- let peer_candidate2 = SecretKey::from_bytes(&[1u8; 32]).public();
- let peer_provider = SecretKey::from_bytes(&[2u8; 32]).public();
- let kind = DownloadKind::Blob {
- hash: Hash::new([0u8; 32]),
+ let (prog_c_tx, prog_c_rx) = flume::bounded(64);
+ let prog_c_tx = FlumeProgressSender::new(prog_c_tx);
+ let req = DownloadRequest::new(kind_1, vec![peer]).progress_sender(prog_c_tx);
+ let handle_c = downloader.queue(req).await;
+
+ let prog1_c = prog_c_rx.recv_async().await.unwrap();
+ assert!(matches!(&prog1_c, DownloadProgress::InitialState(state) if state == &state_a));
+ state_c.on_progress(prog1_c);
+
+ done_tx.send(()).unwrap();
+
+ let (res_a, res_b, res_c) = futures::future::join3(handle_a, handle_b, handle_c).await;
+ res_a.unwrap();
+ res_b.unwrap();
+ res_c.unwrap();
+
+ let prog_a: Vec<_> = prog_a_rx.into_stream().collect().await;
+ let prog_b: Vec<_> = prog_b_rx.into_stream().collect().await;
+ let prog_c: Vec<_> = prog_c_rx.into_stream().collect().await;
+
+ assert_eq!(prog_a.len(), 1);
+ assert_eq!(prog_b.len(), 1);
+ assert_eq!(prog_c.len(), 1);
+
+ assert!(matches!(prog_a[0], DownloadProgress::Done { .. }));
+ assert!(matches!(prog_b[0], DownloadProgress::Done { .. }));
+ assert!(matches!(prog_c[0], DownloadProgress::Done { .. }));
+
+ for p in prog_a {
+ state_a.on_progress(p);
+ }
+ for p in prog_b {
+ state_b.on_progress(p);
+ }
+ for p in prog_c {
+ state_c.on_progress(p);
+ }
+ assert_eq!(state_a, state_b);
+ assert_eq!(state_a, state_c);
+}
+
+#[tokio::test]
+async fn long_queue() {
+ let _guard = iroh_test::logging::setup();
+ let dialer = dialer::TestingDialer::default();
+ let getter = getter::TestingGetter::default();
+ let concurrency_limits = ConcurrencyLimits {
+ max_open_connections: 2,
+ max_concurrent_requests_per_node: 2,
+ max_concurrent_requests: 4, // at most 4 of the queued requests can run at the same time
+ ..Default::default()
};
- let handle = downloader
- .queue(
- kind.clone(),
- vec![
- (peer_candidate1, Role::Candidate).into(),
- (peer_provider, Role::Provider).into(),
- (peer_candidate2, Role::Candidate).into(),
- ],
- )
- .await;
- let now = std::time::Instant::now();
- assert!(handle.await.is_ok(), "download succeeded");
- // this is, I think, currently the best way to test that no delay was performed. It should be
- // safe enough to assume that test runtime is not longer than the delay of 500ms.
- assert!( - now.elapsed() < INITIAL_REQUEST_DELAY, - "no initial delay was added to fetching from a provider" + + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + // send the downloads + let nodes = [ + SecretKey::generate().public(), + SecretKey::generate().public(), + SecretKey::generate().public(), + ]; + let mut handles = vec![]; + for i in 0..100usize { + let kind = HashAndFormat::raw(Hash::new(i.to_be_bytes())); + let peer = nodes[i % 3]; + let req = DownloadRequest::new(kind, vec![peer]); + let h = downloader.queue(req).await; + handles.push(h); + } + + let res = futures::future::join_all(handles).await; + for res in res { + res.expect("all downloads to succeed"); + } +} + +/// If a download errors with [`FailureAction::DropPeer`], make sure that the peer is not dropped +/// while other transfers are still running. +#[tokio::test] +async fn fail_while_running() { + let _guard = iroh_test::logging::setup(); + let dialer = dialer::TestingDialer::default(); + let getter = getter::TestingGetter::default(); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); + let blob_fail = HashAndFormat::raw(Hash::new([1u8; 32])); + let blob_success = HashAndFormat::raw(Hash::new([2u8; 32])); + + getter.set_handler(Arc::new(move |kind, _node, _progress_sender, _duration| { + async move { + if kind == blob_fail.into() { + tokio::time::sleep(Duration::from_millis(10)).await; + Err(FailureAction::DropPeer(anyhow!("bad!"))) + } else if kind == blob_success.into() { + tokio::time::sleep(Duration::from_millis(20)).await; + Ok(Default::default()) + } else { + unreachable!("invalid blob") + } + } + .boxed() + })); + + let node = SecretKey::generate().public(); + let req_success = DownloadRequest::new(blob_success, vec![node]); + let req_fail = DownloadRequest::new(blob_fail, vec![node]); + let handle_success = downloader.queue(req_success).await; + let handle_fail = downloader.queue(req_fail).await; + + let res_fail = handle_fail.await; + let res_success = handle_success.await; + + assert!(res_fail.is_err()); + assert!(res_success.is_ok()); +} + +#[tokio::test] +async fn retry_nodes_simple() { + let _guard = iroh_test::logging::setup(); + let dialer = dialer::TestingDialer::default(); + let getter = getter::TestingGetter::default(); + let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); + let node = SecretKey::generate().public(); + let dial_attempts = Arc::new(AtomicUsize::new(0)); + let dial_attempts2 = dial_attempts.clone(); + // fail on first dial, then succeed + dialer.set_dial_outcome(move |_node| dial_attempts2.fetch_add(1, Ordering::SeqCst) != 0); + let kind = HashAndFormat::raw(Hash::EMPTY); + let req = DownloadRequest::new(kind, vec![node]); + let handle = downloader.queue(req).await; + + assert!(handle.await.is_ok()); + assert_eq!(dial_attempts.load(Ordering::SeqCst), 2); + dialer.assert_history(&[node, node]); +} + +#[tokio::test] +async fn retry_nodes_fail() { + let _guard = iroh_test::logging::setup(); + let dialer = dialer::TestingDialer::default(); + let getter = getter::TestingGetter::default(); + let config = RetryConfig { + initial_retry_delay: Duration::from_millis(10), + max_retries_per_node: 3, + }; + + let downloader = Downloader::spawn_for_test_with_retry_config( + dialer.clone(), + getter.clone(), + Default::default(), + config, ); - getter.assert_history(&[(kind, peer_provider)]); - dialer.assert_history(&[peer_provider]); + let node = 
SecretKey::generate().public();
+ // fail always
+ dialer.set_dial_outcome(move |_node| false);
+
+ // queue a download
+ let kind = HashAndFormat::raw(Hash::EMPTY);
+ let req = DownloadRequest::new(kind, vec![node]);
+ let now = Instant::now();
+ let handle = downloader.queue(req).await;
+
+ // assert that the download failed
+ assert!(handle.await.is_err());
+
+ // assert the dial history: we dialed 4 times
+ dialer.assert_history(&[node, node, node, node]);
+
+ // assert that the retry timeouts were upheld
+ let expected_dial_duration = Duration::from_millis(10 * 4);
+ let expected_retry_wait_duration = Duration::from_millis(10 + 2 * 10 + 3 * 10);
+ assert!(now.elapsed() >= expected_dial_duration + expected_retry_wait_duration);
+}
+
+#[tokio::test]
+async fn retry_nodes_jump_queue() {
+ let _guard = iroh_test::logging::setup();
+ let dialer = dialer::TestingDialer::default();
+ let getter = getter::TestingGetter::default();
+ let concurrency_limits = ConcurrencyLimits {
+ max_open_connections: 2,
+ max_concurrent_requests_per_node: 2,
+ max_concurrent_requests: 4, // all requests can be performed at the same time
+ ..Default::default()
+ };
+
+ let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits);
+
+ let good_node = SecretKey::generate().public();
+ let bad_node = SecretKey::generate().public();
+
+ dialer.set_dial_outcome(move |node| node == good_node);
+ let kind1 = HashAndFormat::raw(Hash::new([0u8; 32]));
+ let kind2 = HashAndFormat::raw(Hash::new([2u8; 32]));
+
+ let req1 = DownloadRequest::new(kind1, vec![bad_node]);
+ let h1 = downloader.queue(req1).await;
+
+ let req2 = DownloadRequest::new(kind2, vec![bad_node, good_node]);
+ let h2 = downloader.queue(req2).await;
+
+ // wait for req2 to complete - this tests that the "queue is jumped" and we are not
+ // waiting for req1 to exhaust all of its retries
+ assert!(h2.await.is_ok());
+
+ dialer.assert_history(&[bad_node, good_node]);
+
+ // now we make download1 succeed!
+ dialer.set_dial_outcome(move |_node| true);
+ assert!(h1.await.is_ok());
+
+ // assert history
+ dialer.assert_history(&[bad_node, good_node, bad_node]);
}
diff --git a/iroh-bytes/src/downloader/test/dialer.rs b/iroh-bytes/src/downloader/test/dialer.rs
index a112464b5d..a68ba575fb 100644
--- a/iroh-bytes/src/downloader/test/dialer.rs
+++ b/iroh-bytes/src/downloader/test/dialer.rs
@@ -23,7 +23,7 @@ struct TestingDialerInner {
/// How long does a dial last.
dial_duration: Duration,
/// Fn deciding if a dial is successful.
- dial_outcome: Box bool>, + dial_outcome: Box bool + Send + Sync + 'static>, } impl Default for TestingDialerInner { @@ -32,7 +32,7 @@ impl Default for TestingDialerInner { dialing: HashSet::default(), dial_futs: delay_queue::DelayQueue::default(), dial_history: Vec::default(), - dial_duration: Duration::ZERO, + dial_duration: Duration::from_millis(10), dial_outcome: Box::new(|_| true), } } @@ -68,10 +68,11 @@ impl futures::Stream for TestingDialer { match inner.dial_futs.poll_expired(cx) { Poll::Ready(Some(expired)) => { let node = expired.into_inner(); - let report_ok = (inner.dial_outcome)(&node); + let report_ok = (inner.dial_outcome)(node); let result = report_ok .then_some(node) .ok_or_else(|| anyhow::anyhow!("dialing test set to fail")); + inner.dialing.remove(&node); Poll::Ready(Some((node, result))) } _ => Poll::Pending, @@ -85,8 +86,11 @@ impl TestingDialer { assert_eq!(self.0.read().dial_history, history) } - pub(super) fn set_dial_duration(&self, duration: Duration) { + pub(super) fn set_dial_outcome( + &self, + dial_outcome: impl Fn(NodeId) -> bool + Send + Sync + 'static, + ) { let mut inner = self.0.write(); - inner.dial_duration = duration; + inner.dial_outcome = Box::new(dial_outcome); } } diff --git a/iroh-bytes/src/downloader/test/getter.rs b/iroh-bytes/src/downloader/test/getter.rs index b8a7f44a35..1581d84af6 100644 --- a/iroh-bytes/src/downloader/test/getter.rs +++ b/iroh-bytes/src/downloader/test/getter.rs @@ -1,5 +1,8 @@ //! Implementation of [`super::Getter`] used for testing. +use std::{sync::Arc, time::Duration}; + +use futures::future::BoxFuture; use parking_lot::RwLock; use super::*; @@ -7,12 +10,26 @@ use super::*; #[derive(Default, Clone)] pub(super) struct TestingGetter(Arc>); +pub(super) type RequestHandlerFn = Arc< + dyn Fn( + DownloadKind, + NodeId, + BroadcastProgressSender, + Duration, + ) -> BoxFuture<'static, InternalDownloadResult> + + Send + + Sync + + 'static, +>; + #[derive(Default)] struct TestingGetterInner { /// How long requests take. request_duration: Duration, /// History of requests performed by the [`Getter`] and if they were successful. request_history: Vec<(DownloadKind, NodeId)>, + /// Set a handler function which actually handles the requests. 
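+ /// If unset, requests simply sleep for `request_duration` and then succeed.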
+ request_handler: Option, } impl Getter for TestingGetter { @@ -20,19 +37,32 @@ impl Getter for TestingGetter { // request being sent to type Connection = NodeId; - fn get(&mut self, kind: DownloadKind, peer: NodeId) -> GetFut { + fn get( + &mut self, + kind: DownloadKind, + peer: NodeId, + progress_sender: BroadcastProgressSender, + ) -> GetFut { let mut inner = self.0.write(); inner.request_history.push((kind, peer)); let request_duration = inner.request_duration; + let handler = inner.request_handler.clone(); async move { - tokio::time::sleep(request_duration).await; - Ok(Stats::default()) + if let Some(f) = handler { + f(kind, peer, progress_sender, request_duration).await + } else { + tokio::time::sleep(request_duration).await; + Ok(Stats::default()) + } } .boxed_local() } } impl TestingGetter { + pub(super) fn set_handler(&self, handler: RequestHandlerFn) { + self.0.write().request_handler = Some(handler); + } pub(super) fn set_request_duration(&self, request_duration: Duration) { self.0.write().request_duration = request_duration; } diff --git a/iroh-bytes/src/get.rs b/iroh-bytes/src/get.rs index 47cd017932..9975007d39 100644 --- a/iroh-bytes/src/get.rs +++ b/iroh-bytes/src/get.rs @@ -30,6 +30,7 @@ use crate::IROH_BLOCK_SIZE; pub mod db; pub mod error; +pub mod progress; pub mod request; /// Stats about the transfer. diff --git a/iroh-bytes/src/get/db.rs b/iroh-bytes/src/get/db.rs index 11103d5b5e..69510d336e 100644 --- a/iroh-bytes/src/get/db.rs +++ b/iroh-bytes/src/get/db.rs @@ -9,6 +9,7 @@ use crate::protocol::RangeSpec; use crate::store::BaoBlobSize; use crate::store::FallibleProgressBatchWriter; use std::io; +use std::num::NonZeroU64; use crate::hashseq::parse_hash_seq; use crate::store::BaoBatchWriter; @@ -18,6 +19,7 @@ use crate::{ self, error::GetError, fsm::{AtBlobHeader, AtEndBlob, ConnectedNext, EndBlobNext}, + progress::TransferState, Stats, }, protocol::{GetRequest, RangeSpecSeq}, @@ -74,7 +76,7 @@ async fn get_blob< tracing::info!("already got entire blob"); progress .send(DownloadProgress::FoundLocal { - child: 0, + child: BlobId::Root, hash: *hash, size: entry.size(), valid_ranges: RangeSpec::all(), @@ -90,7 +92,7 @@ async fn get_blob< .unwrap_or_else(ChunkRanges::all); progress .send(DownloadProgress::FoundLocal { - child: 0, + child: BlobId::Root, hash: *hash, size: entry.size(), valid_ranges: RangeSpec::new(&valid_ranges), @@ -186,7 +188,7 @@ async fn get_blob_inner( id, hash, size, - child: child_offset, + child: BlobId::from_offset(child_offset), }) .await?; let sender2 = sender.clone(); @@ -237,7 +239,7 @@ async fn get_blob_inner_partial( id, hash, size, - child: child_offset, + child: BlobId::from_offset(child_offset), }) .await?; let sender2 = sender.clone(); @@ -316,7 +318,7 @@ async fn get_hash_seq< // send info that we have the hashseq itself entirely sender .send(DownloadProgress::FoundLocal { - child: 0, + child: BlobId::Root, hash: *root_hash, size: entry.size(), valid_ranges: RangeSpec::all(), @@ -343,7 +345,7 @@ async fn get_hash_seq< if let Some(size) = info.size() { sender .send(DownloadProgress::FoundLocal { - child: (i as u64) + 1, + child: BlobId::from_offset((i as u64) + 1), hash: children[i], size, valid_ranges: RangeSpec::new(&info.valid_ranges()), @@ -521,12 +523,15 @@ impl BlobInfo { } /// Progress updates for the get operation. +// TODO: Move to super::progress #[derive(Debug, Clone, Serialize, Deserialize)] pub enum DownloadProgress { + /// Initial state if subscribing to a running or queued transfer. 
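+ ///
+ /// This is sent as the first event to a subscriber that attaches to an already tracked
+ /// download, so that it can catch up with the accumulated [`TransferState`] instead of
+ /// replaying every earlier event.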
+ InitialState(TransferState), /// Data was found locally. FoundLocal { /// child offset - child: u64, + child: BlobId, /// The hash of the entry. hash: Hash, /// The size of the entry in bytes. @@ -538,10 +543,13 @@ pub enum DownloadProgress { Connected, /// An item was found with hash `hash`, from now on referred to via `id`. Found { - /// A new unique id for this entry. + /// A new unique progress id for this entry. id: u64, - /// child offset - child: u64, + /// Identifier for this blob within this download. + /// + /// Will always be [`BlobId::Root`] unless a hashseq is downloaded, in which case this + /// allows to identify the children by their offset in the hashseq. + child: BlobId, /// The hash of the entry. hash: Hash, /// The size of the entry in bytes. @@ -575,3 +583,29 @@ pub enum DownloadProgress { /// This will be the last message in the stream. Abort(RpcError), } + +/// The id of a blob in a transfer +#[derive( + Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, std::hash::Hash, Serialize, Deserialize, +)] +pub enum BlobId { + /// The root blob (child id 0) + Root, + /// A child blob (child id > 0) + Child(NonZeroU64), +} + +impl BlobId { + fn from_offset(id: u64) -> Self { + NonZeroU64::new(id).map(Self::Child).unwrap_or(Self::Root) + } +} + +impl From for u64 { + fn from(value: BlobId) -> Self { + match value { + BlobId::Root => 0, + BlobId::Child(id) => id.into(), + } + } +} diff --git a/iroh-bytes/src/get/progress.rs b/iroh-bytes/src/get/progress.rs new file mode 100644 index 0000000000..5fcac04b2f --- /dev/null +++ b/iroh-bytes/src/get/progress.rs @@ -0,0 +1,186 @@ +//! Types for get progress state management. + +use std::{collections::HashMap, num::NonZeroU64}; + +use serde::{Deserialize, Serialize}; +use tracing::warn; + +use crate::{protocol::RangeSpec, store::BaoBlobSize, Hash}; + +use super::db::{BlobId, DownloadProgress}; + +/// The identifier for progress events. +pub type ProgressId = u64; + +/// Accumulated progress state of a transfer. +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub struct TransferState { + /// The root blob of this transfer (may be a hash seq), + pub root: BlobState, + /// Whether we are connected to a node + pub connected: bool, + /// Children if the root blob is a hash seq, empty for raw blobs + pub children: HashMap, + /// Child being transferred at the moment. + pub current: Option, + /// Progress ids for individual blobs. + pub progress_id_to_blob: HashMap, +} + +impl TransferState { + /// Create a new, empty transfer state. + pub fn new(root_hash: Hash) -> Self { + Self { + root: BlobState::new(root_hash), + connected: false, + children: Default::default(), + current: None, + progress_id_to_blob: Default::default(), + } + } +} + +/// State of a single blob in transfer +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub struct BlobState { + /// The hash of this blob. + pub hash: Hash, + /// The size of this blob. Only known if the blob is partially present locally, or after having + /// received the size from the remote. + pub size: Option, + /// The current state of the blob transfer. + pub progress: BlobProgress, + /// Ranges already available locally at the time of starting the transfer. + pub local_ranges: Option, + /// Number of children (only applies to hashseqs, None for raw blobs). 
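+ /// Populated from the [`DownloadProgress::FoundHashSeq`] event.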
+ pub child_count: Option, +} + +/// Progress state for a single blob +#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub enum BlobProgress { + /// Download is pending + #[default] + Pending, + /// Download is in progress + Progressing(u64), + /// Download has finished + Done, +} + +impl BlobState { + /// Create a new [`BlobState`]. + pub fn new(hash: Hash) -> Self { + Self { + hash, + size: None, + local_ranges: None, + child_count: None, + progress: BlobProgress::default(), + } + } +} + +impl TransferState { + /// Get state of the root blob of this transfer. + pub fn root(&self) -> &BlobState { + &self.root + } + + /// Get a blob state by its [`BlobId`] in this transfer. + pub fn get_blob(&self, blob_id: &BlobId) -> Option<&BlobState> { + match blob_id { + BlobId::Root => Some(&self.root), + BlobId::Child(id) => self.children.get(id), + } + } + + /// Get the blob state currently being transferred. + pub fn get_current(&self) -> Option<&BlobState> { + self.current.as_ref().and_then(|id| self.get_blob(id)) + } + + fn get_or_insert_blob(&mut self, blob_id: BlobId, hash: Hash) -> &mut BlobState { + match blob_id { + BlobId::Root => &mut self.root, + BlobId::Child(id) => self + .children + .entry(id) + .or_insert_with(|| BlobState::new(hash)), + } + } + fn get_blob_mut(&mut self, blob_id: &BlobId) -> Option<&mut BlobState> { + match blob_id { + BlobId::Root => Some(&mut self.root), + BlobId::Child(id) => self.children.get_mut(id), + } + } + + fn get_by_progress_id(&mut self, progress_id: ProgressId) -> Option<&mut BlobState> { + let blob_id = *self.progress_id_to_blob.get(&progress_id)?; + self.get_blob_mut(&blob_id) + } + + /// Update the state with a new [`DownloadProgress`] event for this transfer. + pub fn on_progress(&mut self, event: DownloadProgress) { + match event { + DownloadProgress::InitialState(s) => { + *self = s; + } + DownloadProgress::FoundLocal { + child, + hash, + size, + valid_ranges, + } => { + let blob = self.get_or_insert_blob(child, hash); + blob.size = Some(size); + blob.local_ranges = Some(valid_ranges); + } + DownloadProgress::Connected => self.connected = true, + DownloadProgress::Found { + id: progress_id, + child: blob_id, + hash, + size, + } => { + let blob = self.get_or_insert_blob(blob_id, hash); + blob.size = match blob.size { + // If we don't have a verified size for this blob yet: Use the size as reported + // by the remote. + None | Some(BaoBlobSize::Unverified(_)) => Some(BaoBlobSize::Unverified(size)), + // Otherwise, keep the existing verified size. + value @ Some(BaoBlobSize::Verified(_)) => value, + }; + blob.progress = BlobProgress::Progressing(0); + self.progress_id_to_blob.insert(progress_id, blob_id); + self.current = Some(blob_id); + } + DownloadProgress::FoundHashSeq { hash, children } => { + if hash == self.root.hash { + self.root.child_count = Some(children); + } else { + // I think it is an invariant of the protocol that `FoundHashSeq` is only + // triggered for the root hash. 
+ warn!("Received `FoundHashSeq` event for a hash which is not the download's root hash.") + } + } + DownloadProgress::Progress { id, offset } => { + if let Some(blob) = self.get_by_progress_id(id) { + blob.progress = BlobProgress::Progressing(offset); + } else { + warn!(%id, "Received `Progress` event for unknown progress id.") + } + } + DownloadProgress::Done { id } => { + if let Some(blob) = self.get_by_progress_id(id) { + blob.progress = BlobProgress::Done; + self.progress_id_to_blob.remove(&id); + } else { + warn!(%id, "Received `Done` event for unknown progress id.") + } + } + _ => {} + } + } +} diff --git a/iroh-bytes/src/store/traits.rs b/iroh-bytes/src/store/traits.rs index 6c5dd5aaca..f8a88b0784 100644 --- a/iroh-bytes/src/store/traits.rs +++ b/iroh-bytes/src/store/traits.rs @@ -45,7 +45,7 @@ pub enum EntryStatus { } /// The size of a bao file -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Eq, PartialEq)] pub enum BaoBlobSize { /// A remote side told us the size, but we have insufficient data to verify it. Unverified(u64), diff --git a/iroh-bytes/src/util.rs b/iroh-bytes/src/util.rs index 57812cb58d..b540b88562 100644 --- a/iroh-bytes/src/util.rs +++ b/iroh-bytes/src/util.rs @@ -6,7 +6,7 @@ use range_collections::range_set::RangeSetRange; use serde::{Deserialize, Serialize}; use std::{borrow::Borrow, fmt, sync::Arc, time::SystemTime}; -use crate::{BlobFormat, Hash, HashAndFormat, IROH_BLOCK_SIZE}; +use crate::{store::Store, BlobFormat, Hash, HashAndFormat, IROH_BLOCK_SIZE}; pub mod io; mod mem_or_file; @@ -121,6 +121,64 @@ impl Tag { } } +/// A set of merged [`SetTagOption`]s for a blob. +#[derive(Debug, Default)] +pub struct TagSet { + auto: bool, + named: Vec, +} + +impl TagSet { + /// Insert a new tag into the set. + pub fn insert(&mut self, tag: SetTagOption) { + match tag { + SetTagOption::Auto => self.auto = true, + SetTagOption::Named(tag) => { + if !self.named.iter().any(|t| t == &tag) { + self.named.push(tag) + } + } + } + } + + /// Convert the [`TagSet`] into a list of [`SetTagOption`]. + pub fn into_tags(self) -> impl Iterator { + self.auto + .then_some(SetTagOption::Auto) + .into_iter() + .chain(self.named.into_iter().map(SetTagOption::Named)) + } + + /// Apply the tags in the [`TagSet`] to the database. + pub async fn apply( + self, + db: &D, + hash_and_format: HashAndFormat, + ) -> std::io::Result<()> { + let tags = self.into_tags(); + for tag in tags { + match tag { + SetTagOption::Named(tag) => { + db.set_tag(tag, Some(hash_and_format)).await?; + } + SetTagOption::Auto => { + db.create_tag(hash_and_format).await?; + } + } + } + Ok(()) + } +} + +/// Option for commands that allow setting a tag +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum SetTagOption { + /// A tag will be automatically generated + Auto, + /// The tag is explicitly named + Named(Tag), +} + /// A trait for things that can track liveness of blobs and collections. /// /// This trait works together with [TempTag] to keep track of the liveness of a diff --git a/iroh-bytes/src/util/progress.rs b/iroh-bytes/src/util/progress.rs index d80ec3533e..308f4aab7a 100644 --- a/iroh-bytes/src/util/progress.rs +++ b/iroh-bytes/src/util/progress.rs @@ -471,13 +471,18 @@ impl Clone for FlumeProgressSender { } impl FlumeProgressSender { - /// Create a new progress sender from a tokio mpsc sender. + /// Create a new progress sender from a flume sender. 
diff --git a/iroh-bytes/src/util/progress.rs b/iroh-bytes/src/util/progress.rs
index d80ec3533e..308f4aab7a 100644
--- a/iroh-bytes/src/util/progress.rs
+++ b/iroh-bytes/src/util/progress.rs
@@ -471,13 +471,18 @@ impl<T> Clone for FlumeProgressSender<T> {
 }
 
 impl<T> FlumeProgressSender<T> {
-    /// Create a new progress sender from a tokio mpsc sender.
+    /// Create a new progress sender from a flume sender.
     pub fn new(sender: flume::Sender<T>) -> Self {
         Self {
             sender,
             id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)),
         }
     }
+
+    /// Returns true if `other` sends on the same `flume` channel as `self`.
+    pub fn same_channel(&self, other: &FlumeProgressSender<T>) -> bool {
+        self.sender.same_channel(&other.sender)
+    }
 }
 
 impl<T> IdGenerator for FlumeProgressSender<T> {
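Aside, not part of the patch: the new `same_channel` check in action (hypothetical usage; assumes the `flume` crate and that `FlumeProgressSender` is public in `iroh_bytes::util::progress`):

use iroh_bytes::get::db::DownloadProgress;
use iroh_bytes::util::progress::FlumeProgressSender;

fn main() {
    let (tx, _rx) = flume::bounded::<DownloadProgress>(1024);
    let a = FlumeProgressSender::new(tx);
    // Clones send on the same underlying flume channel ...
    let b = a.clone();
    assert!(a.same_channel(&b));
    // ... while a sender built from a fresh channel does not.
    let (tx2, _rx2) = flume::bounded::<DownloadProgress>(1024);
    assert!(!a.same_channel(&FlumeProgressSender::new(tx2)));
}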
diff --git a/iroh-cli/src/commands/blob.rs b/iroh-cli/src/commands/blob.rs
index 6fab9ea4e7..fcd2597965 100644
--- a/iroh-cli/src/commands/blob.rs
+++ b/iroh-cli/src/commands/blob.rs
@@ -14,7 +14,7 @@ use indicatif::{
     ProgressStyle,
 };
 use iroh::bytes::{
-    get::{db::DownloadProgress, Stats},
+    get::{db::DownloadProgress, progress::BlobProgress, Stats},
     provider::AddProgress,
     store::{ConsistencyCheckProgress, ExportFormat, ExportMode, ReportLevel, ValidateProgress},
     BlobFormat, Hash, HashAndFormat, Tag,
@@ -24,7 +24,7 @@ use iroh::{
     client::{BlobStatus, Iroh, ShareTicketOptions},
     rpc_protocol::{
         BlobDownloadRequest, BlobListCollectionsResponse, BlobListIncompleteResponse,
-        BlobListResponse, ProviderService, SetTagOption, WrapOption,
+        BlobListResponse, DownloadMode, ProviderService, SetTagOption, WrapOption,
     },
     ticket::BlobTicket,
 };
@@ -81,6 +81,12 @@ pub enum BlobCommands {
         /// Tag to tag the data with.
         #[clap(long)]
         tag: Option<Tag>,
+        /// If set, the download will be added to the download queue.
+        ///
+        /// Use this if you are doing many downloads in parallel and want to limit the number of
+        /// downloads running concurrently.
+        #[clap(long)]
+        queued: bool,
     },
     /// Export a blob from the internal blob store to the local filesystem.
     Export {
@@ -183,6 +189,7 @@ impl BlobCommands {
                 out,
                 stable,
                 tag,
+                queued,
             } => {
                 let (node_addr, hash, format) = match ticket {
                     TicketOrHash::Ticket(ticket) => {
@@ -241,13 +248,19 @@ impl BlobCommands {
                     None => SetTagOption::Auto,
                 };
 
+                let mode = match queued {
+                    true => DownloadMode::Queued,
+                    false => DownloadMode::Direct,
+                };
+
                 let mut stream = iroh
                     .blobs
                     .download(BlobDownloadRequest {
                         hash,
                         format,
-                        peer: node_addr,
+                        nodes: vec![node_addr],
                         tag,
+                        mode,
                     })
                     .await?;
 
@@ -277,6 +290,7 @@ impl BlobCommands {
                 };
                 tracing::info!("exporting to {} -> {}", path.display(), absolute.display());
                 let stream = iroh.blobs.export(hash, absolute, format, mode).await?;
+                // TODO: report export progress
                 stream.await?;
             }
 
@@ -1009,6 +1023,36 @@ pub async fn show_download_progress(
     let mut seq = false;
     while let Some(x) = stream.next().await {
         match x? {
+            DownloadProgress::InitialState(state) => {
+                if state.connected {
+                    op.set_message(format!("{} Requesting ...\n", style("[2/3]").bold().dim()));
+                }
+                if let Some(count) = state.root.child_count {
+                    op.set_message(format!(
+                        "{} Downloading {} blob(s)\n",
+                        style("[3/3]").bold().dim(),
+                        count + 1,
+                    ));
+                    op.set_length(count + 1);
+                    op.reset();
+                    op.set_position(state.current.map(u64::from).unwrap_or(0));
+                    seq = true;
+                }
+                if let Some(blob) = state.get_current() {
+                    if let Some(size) = blob.size {
+                        ip.set_length(size.value());
+                        ip.reset();
+                        match blob.progress {
+                            BlobProgress::Pending => {}
+                            BlobProgress::Progressing(offset) => ip.set_position(offset),
+                            BlobProgress::Done => ip.finish_and_clear(),
+                        }
+                        if !seq {
+                            op.finish_and_clear();
+                        }
+                    }
+                }
+            }
             DownloadProgress::FoundLocal { .. } => {}
             DownloadProgress::Connected => {
                 op.set_message(format!("{} Requesting ...\n", style("[2/3]").bold().dim()));
@@ -1025,7 +1069,7 @@
             }
             DownloadProgress::Found { size, child, .. } => {
                 if seq {
-                    op.set_position(child);
+                    op.set_position(child.into());
                 } else {
                     op.finish_and_clear();
                 }
diff --git a/iroh/examples/collection-fetch.rs b/iroh/examples/collection-fetch.rs
index e21ac30139..53fac68f52 100644
--- a/iroh/examples/collection-fetch.rs
+++ b/iroh/examples/collection-fetch.rs
@@ -4,7 +4,7 @@
 //! This is using an in memory database and a random node id.
 //! Run the `collection-provide` example, which will give you instructions on how to run this example.
 use anyhow::{bail, ensure, Context, Result};
-use iroh::rpc_protocol::BlobDownloadRequest;
+use iroh::rpc_protocol::{BlobDownloadRequest, DownloadMode};
 use iroh_bytes::BlobFormat;
 use std::env;
 use std::str::FromStr;
@@ -60,15 +60,18 @@ async fn main() -> Result<()> {
         // When interacting with the iroh API, you will most likely be using blobs and collections.
         format: ticket.format(),
 
-        // The `peer` field is a `NodeAddr`, which combines all of the known address information we have for the remote node.
+        // The `nodes` field is a list of `NodeAddr`s, where each combines all of the known address information we have for the remote node.
         // This includes the `node_id` (or `PublicKey` of the node), any direct UDP addresses we know about for that node, as well as the relay url of that node. The relay url is the url of the relay server that that node is connected to.
         // If the direct UDP addresses to that node do not work, then we can use the relay node to attempt to holepunch between your current node and the remote node.
         // If holepunching fails, iroh will use the relay node to proxy a connection to the remote node over HTTPS.
         // Thankfully, the ticket contains all of this information
-        peer: ticket.node_addr().clone(),
+        nodes: vec![ticket.node_addr().clone()],
 
         // You can create a special tag name (`SetTagOption::Named`), or create an automatic tag that is derived from the timestamp.
         tag: iroh::rpc_protocol::SetTagOption::Auto,
+
+        // Whether to use the download queue, or do a direct download.
+        mode: DownloadMode::Direct,
     };
 
     // `download` returns a stream of `DownloadProgress` events. You can iterate through these updates to get progress on the state of your download.
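Aside, not part of the patch: the example above uses `DownloadMode::Direct`; the same request routed through the download queue would look roughly like this (hypothetical helper; `BlobTicket::hash()` is assumed alongside the `format()`/`node_addr()` accessors used above):

use iroh::rpc_protocol::{BlobDownloadRequest, DownloadMode, SetTagOption};
use iroh::ticket::BlobTicket;

/// Build a queued download request from a ticket, so the downloader's
/// concurrency limits apply instead of starting the transfer immediately.
fn queued_request(ticket: &BlobTicket) -> BlobDownloadRequest {
    BlobDownloadRequest {
        hash: ticket.hash(),
        format: ticket.format(),
        nodes: vec![ticket.node_addr().clone()],
        tag: SetTagOption::Auto,
        mode: DownloadMode::Queued,
    }
}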
diff --git a/iroh/examples/hello-world-fetch.rs b/iroh/examples/hello-world-fetch.rs
index 7294c67093..fcc75cdbca 100644
--- a/iroh/examples/hello-world-fetch.rs
+++ b/iroh/examples/hello-world-fetch.rs
@@ -4,7 +4,7 @@
 //! This is using an in memory database and a random node id.
 //! Run the `provide` example, which will give you instructions on how to run this example.
 use anyhow::{bail, ensure, Context, Result};
-use iroh::rpc_protocol::BlobDownloadRequest;
+use iroh::rpc_protocol::{BlobDownloadRequest, DownloadMode};
 use iroh_bytes::BlobFormat;
 use std::env;
 use std::str::FromStr;
@@ -60,15 +60,18 @@ async fn main() -> Result<()> {
         // When interacting with the iroh API, you will most likely be using blobs and collections.
         format: ticket.format(),
 
-        // The `peer` field is a `NodeAddr`, which combines all of the known address information we have for the remote node.
+        // The `nodes` field is a list of `NodeAddr`s, where each combines all of the known address information we have for the remote node.
         // This includes the `node_id` (or `PublicKey` of the node), any direct UDP addresses we know about for that node, as well as the relay url of that node. The relay url is the url of the relay server that that node is connected to.
         // If the direct UDP addresses to that node do not work, then we can use the relay node to attempt to holepunch between your current node and the remote node.
         // If holepunching fails, iroh will use the relay node to proxy a connection to the remote node over HTTPS.
         // Thankfully, the ticket contains all of this information
-        peer: ticket.node_addr().clone(),
+        nodes: vec![ticket.node_addr().clone()],
 
         // You can create a special tag name (`SetTagOption::Named`), or create an automatic tag that is derived from the timestamp.
         tag: iroh::rpc_protocol::SetTagOption::Auto,
+
+        // Whether to use the download queue, or do a direct download.
+        mode: DownloadMode::Direct,
     };
 
     // `download` returns a stream of `DownloadProgress` events. You can iterate through these updates to get progress on the state of your download.
diff --git a/iroh/src/node.rs b/iroh/src/node.rs
index 663c8bf19f..8bd8f01ca6 100644
--- a/iroh/src/node.rs
+++ b/iroh/src/node.rs
@@ -16,6 +16,7 @@ use std::task::Poll;
 use anyhow::{anyhow, Result};
 use futures::future::{BoxFuture, Shared};
 use futures::{FutureExt, StreamExt};
+use iroh_bytes::downloader::Downloader;
 use iroh_bytes::store::Store as BaoStore;
 use iroh_bytes::BlobFormat;
 use iroh_bytes::Hash;
@@ -108,6 +109,7 @@ struct NodeInner {
     #[debug("rt")]
     rt: LocalPoolHandle,
     pub(crate) sync: SyncEngine,
+    downloader: Downloader,
 }
 
 /// Events emitted by the [`Node`] informing about the current status.
@@ -289,7 +291,8 @@ mod tests {
     use crate::{
         client::BlobAddOutcome,
         rpc_protocol::{
-            BlobAddPathRequest, BlobAddPathResponse, BlobDownloadRequest, SetTagOption, WrapOption,
+            BlobAddPathRequest, BlobAddPathResponse, BlobDownloadRequest, DownloadMode,
+            SetTagOption, WrapOption,
         },
     };
 
@@ -437,7 +440,8 @@
             hash,
             tag: SetTagOption::Auto,
             format: BlobFormat::Raw,
-            peer: addr,
+            mode: DownloadMode::Direct,
+            nodes: vec![addr],
         };
         node2.blobs.download(req).await?.await?;
         assert_eq!(
diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs
index 5ef192e719..564ca87f01 100644
--- a/iroh/src/node/builder.rs
+++ b/iroh/src/node/builder.rs
@@ -397,7 +397,7 @@ where
             gossip.clone(),
             self.docs_store,
             self.blobs_store.clone(),
-            downloader,
+            downloader.clone(),
         );
         let sync_db = sync.sync.clone();
 
@@ -425,6 +425,7 @@
             gc_task,
             rt: lp.clone(),
             sync,
+            downloader,
         });
         let task = {
             let gossip = gossip.clone();
diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs
index 99bf28ce5e..870620ab02 100644
--- a/iroh/src/node/rpc.rs
+++ b/iroh/src/node/rpc.rs
@@ -3,15 +3,17 @@ use std::io;
 use std::sync::{Arc, Mutex};
 use std::time::Duration;
 
-use anyhow::{anyhow, Result};
-use futures::{Future, FutureExt, Stream, StreamExt};
+use anyhow::{anyhow, ensure, Result};
+use futures::{FutureExt, Stream, StreamExt};
 use genawaiter::sync::{Co, Gen};
 use iroh_base::rpc::RpcResult;
+use iroh_bytes::downloader::{DownloadRequest, Downloader};
 use iroh_bytes::export::ExportProgress;
 use iroh_bytes::format::collection::Collection;
 use iroh_bytes::get::db::DownloadProgress;
+use iroh_bytes::get::Stats;
 use iroh_bytes::store::{ConsistencyCheckProgress, ExportFormat, ImportProgress, MapEntry};
-use iroh_bytes::util::progress::{IdGenerator, ProgressSender};
+use iroh_bytes::util::progress::ProgressSender;
 use iroh_bytes::BlobFormat;
 use iroh_bytes::{
     hashseq::parse_hash_seq,
@@ -21,6 +23,7 @@ use iroh_bytes::{
     HashAndFormat,
 };
 use iroh_io::AsyncSliceReader;
+use iroh_net::{MagicEndpoint, NodeAddr};
 use quic_rpc::{
     server::{RpcChannel, RpcServerError},
     ServiceEndpoint,
@@ -37,10 +40,11 @@ use crate::rpc_protocol::{
     BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest,
     CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest, DocExportFileResponse,
     DocImportFileRequest, DocImportFileResponse, DocImportProgress, DocSetHashRequest,
-    ListTagsRequest, ListTagsResponse, NodeConnectionInfoRequest, NodeConnectionInfoResponse,
-    NodeConnectionsRequest, NodeConnectionsResponse, NodeShutdownRequest, NodeStatsRequest,
-    NodeStatsResponse, NodeStatusRequest, NodeStatusResponse, NodeWatchRequest, NodeWatchResponse,
-    ProviderRequest, ProviderService, SetTagOption,
+    DownloadMode, ListTagsRequest, ListTagsResponse, NodeConnectionInfoRequest,
+    NodeConnectionInfoResponse, NodeConnectionsRequest, NodeConnectionsResponse,
+    NodeShutdownRequest, NodeStatsRequest, NodeStatsResponse, NodeStatusRequest,
+    NodeStatusResponse, NodeWatchRequest, NodeWatchResponse, ProviderRequest, ProviderService,
+    SetTagOption,
 };
 
 use super::{Event, NodeInner};
@@ -602,38 +606,17 @@ impl Handler {
     fn blob_download(self, msg: BlobDownloadRequest) -> impl Stream<Item = BlobDownloadResponse> {
         let (sender, receiver) = flume::bounded(1024);
-        let progress = FlumeProgressSender::new(sender);
-
-        let BlobDownloadRequest {
-            hash,
-            format,
-            peer,
-            tag,
-        } = msg;
         let db = self.inner.db.clone();
-        let hash_and_format = HashAndFormat { hash, format };
-        let temp_pin = self.inner.db.temp_tag(hash_and_format);
-        let get_conn = {
-            let progress = progress.clone();
-            let ep = self.inner.endpoint.clone();
-            move || async move {
-                let conn = ep.connect(peer, iroh_bytes::protocol::ALPN).await?;
-                progress.send(DownloadProgress::Connected).await?;
-                Ok(conn)
-            }
-        };
-
+        let downloader = self.inner.downloader.clone();
+        let endpoint = self.inner.endpoint.clone();
+        let progress = FlumeProgressSender::new(sender);
         self.inner.rt.spawn_pinned(move || async move {
-            if let Err(err) =
-                download_blob(db, get_conn, hash_and_format, tag, progress.clone()).await
-            {
+            if let Err(err) = download(&db, endpoint, &downloader, msg, progress.clone()).await {
                 progress
                     .send(DownloadProgress::Abort(err.into()))
                     .await
                     .ok();
             }
-            drop(temp_pin);
         });
 
         receiver.into_stream().map(BlobDownloadResponse)
@@ -1052,31 +1035,138 @@ impl Handler {
     }
 }
 
-async fn download_blob<D, C, F>(
-    db: D,
-    get_conn: C,
+async fn download<D>(
+    db: &D,
+    endpoint: MagicEndpoint,
+    downloader: &Downloader,
+    req: BlobDownloadRequest,
+    progress: FlumeProgressSender<DownloadProgress>,
+) -> Result<()>
+where
+    D: iroh_bytes::store::Store,
+{
+    let BlobDownloadRequest {
+        hash,
+        format,
+        nodes,
+        tag,
+        mode,
+    } = req;
+    let hash_and_format = HashAndFormat { hash, format };
+    let stats = match mode {
+        DownloadMode::Queued => {
+            download_queued(
+                endpoint,
+                downloader,
+                hash_and_format,
+                nodes,
+                tag,
+                progress.clone(),
+            )
+            .await?
+        }
+        DownloadMode::Direct => {
+            download_direct_from_nodes(db, endpoint, hash_and_format, nodes, tag, progress.clone())
+                .await?
+        }
+    };
+
+    progress.send(DownloadProgress::AllDone(stats)).await.ok();
+
+    Ok(())
+}
+
+async fn download_queued(
+    endpoint: MagicEndpoint,
+    downloader: &Downloader,
+    hash_and_format: HashAndFormat,
+    nodes: Vec<NodeAddr>,
+    tag: SetTagOption,
+    progress: FlumeProgressSender<DownloadProgress>,
+) -> Result<Stats> {
+    let mut node_ids = Vec::with_capacity(nodes.len());
+    for node in nodes {
+        node_ids.push(node.node_id);
+        endpoint.add_node_addr(node)?;
+    }
+    let req = DownloadRequest::new(hash_and_format, node_ids)
+        .progress_sender(progress)
+        .tag(tag);
+    let handle = downloader.queue(req).await;
+    let stats = handle.await?;
+    Ok(stats)
+}
+
+async fn download_direct_from_nodes<D>(
+    db: &D,
+    endpoint: MagicEndpoint,
     hash_and_format: HashAndFormat,
+    nodes: Vec<NodeAddr>,
     tag: SetTagOption,
-    progress: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
-) -> Result<()>
+    progress: FlumeProgressSender<DownloadProgress>,
+) -> Result<Stats>
 where
     D: BaoStore,
-    C: FnOnce() -> F,
-    F: Future<Output = anyhow::Result<quinn::Connection>>,
 {
-    let stats =
-        iroh_bytes::get::db::get_to_db(&db, get_conn, &hash_and_format, progress.clone()).await?;
+    ensure!(!nodes.is_empty(), "No nodes to download from provided.");
+    let mut last_err = None;
+    for node in nodes {
+        let node_id = node.node_id;
+        match download_direct(
+            db,
+            endpoint.clone(),
+            hash_and_format,
+            node,
+            tag.clone(),
+            progress.clone(),
+        )
+        .await
+        {
+            Ok(stats) => return Ok(stats),
+            Err(err) => {
+                debug!(?err, node = &node_id.fmt_short(), "Download failed");
+                last_err = Some(err)
+            }
+        }
+    }
+    Err(last_err.unwrap())
+}
 
-    match tag {
-        SetTagOption::Named(tag) => {
-            db.set_tag(tag, Some(hash_and_format)).await?;
+async fn download_direct<D>(
+    db: &D,
+    endpoint: MagicEndpoint,
+    hash_and_format: HashAndFormat,
+    node: NodeAddr,
+    tag: SetTagOption,
+    progress: FlumeProgressSender<DownloadProgress>,
+) -> Result<Stats>
+where
+    D: BaoStore,
+{
+    let temp_pin = db.temp_tag(hash_and_format);
+    let get_conn = {
+        let progress = progress.clone();
+        move || async move {
+            let conn = endpoint.connect(node, iroh_bytes::protocol::ALPN).await?;
+            progress.send(DownloadProgress::Connected).await?;
+            Ok(conn)
         }
-        SetTagOption::Auto => {
-            db.create_tag(hash_and_format).await?;
+    };
+
+    let res = iroh_bytes::get::db::get_to_db(db, get_conn, &hash_and_format, progress).await;
+
+    if res.is_ok() {
+        match tag {
+            SetTagOption::Named(tag) => {
+                db.set_tag(tag, Some(hash_and_format)).await?;
+            }
+            SetTagOption::Auto => {
+                db.create_tag(hash_and_format).await?;
+            }
         }
     }
 
-    progress.send(DownloadProgress::AllDone(stats)).await.ok();
+    drop(temp_pin);
 
-    Ok(())
+    res.map_err(Into::into)
 }
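Aside, not part of the patch: in `Direct` mode the nodes are tried strictly in order, so a caller can express a fallback provider simply by passing several nodes. A sketch (assumes `NodeAddr::new` builds an address from a bare node id, and uses the request type changed below):

use iroh::rpc_protocol::{BlobDownloadRequest, DownloadMode, SetTagOption};
use iroh_bytes::{BlobFormat, Hash};
use iroh_net::{NodeAddr, NodeId};

/// Direct download with a fallback: `primary` is tried first,
/// `backup` only if the first attempt fails.
fn with_fallback(hash: Hash, primary: NodeId, backup: NodeId) -> BlobDownloadRequest {
    BlobDownloadRequest {
        hash,
        format: BlobFormat::Raw,
        nodes: vec![NodeAddr::new(primary), NodeAddr::new(backup)],
        tag: SetTagOption::Auto,
        mode: DownloadMode::Direct,
    }
}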
diff --git a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs
index 053150f51e..ce71da856d 100644
--- a/iroh/src/rpc_protocol.rs
+++ b/iroh/src/rpc_protocol.rs
@@ -39,19 +39,11 @@ pub use iroh_bytes::{provider::AddProgress, store::ValidateProgress};
 
 use crate::sync_engine::LiveEvent;
 pub use crate::ticket::DocTicket;
+pub use iroh_bytes::util::SetTagOption;
 
 /// A 32-byte key or token
 pub type KeyBytes = [u8; 32];
 
-/// Option for commands that allow setting a tag
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
-pub enum SetTagOption {
-    /// A tag will be automatically generated
-    Auto,
-    /// The tag is explicitly named
-    Named(Tag),
-}
-
 /// A request to the node to provide the data at the given path
 ///
 /// Will produce a stream of [`AddProgress`] messages.
@@ -104,10 +96,31 @@ pub struct BlobDownloadRequest {
     /// If the format is [`BlobFormat::HashSeq`], all children are downloaded and shared as
     /// well.
     pub format: BlobFormat,
-    /// This mandatory field specifies the peer to download the data from.
-    pub peer: NodeAddr,
+    /// This mandatory field specifies the nodes to download the data from.
+    ///
+    /// If set to more than a single node, they will all be tried. If `mode` is set to
+    /// [`DownloadMode::Direct`], they will be tried sequentially until a download succeeds.
+    /// If `mode` is set to [`DownloadMode::Queued`], the nodes may be dialed in parallel,
+    /// if the concurrency limits permit.
+    pub nodes: Vec<NodeAddr>,
     /// Optional tag to tag the data with.
     pub tag: SetTagOption,
+    /// Whether to directly start the download or add it to the download queue.
+    pub mode: DownloadMode,
+}
+
+/// Set the mode for whether to directly start the download or add it to the download queue.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum DownloadMode {
+    /// Start the download right away.
+    ///
+    /// No concurrency limits or queuing will be applied. It is up to the user to manage download
+    /// concurrency.
+    Direct,
+    /// Queue the download.
+    ///
+    /// The download queue will be processed in-order, while respecting the downloader concurrency limits.
+    Queued,
 }
 
 impl Msg for BlobDownloadRequest {
diff --git a/iroh/src/sync_engine/gossip.rs b/iroh/src/sync_engine/gossip.rs
index 8daa8781f8..8912887a9f 100644
--- a/iroh/src/sync_engine/gossip.rs
+++ b/iroh/src/sync_engine/gossip.rs
@@ -15,7 +15,7 @@ use tokio::{
 use tracing::{debug, error, trace};
 
 use super::live::{Op, ToLiveActor};
-use iroh_bytes::downloader::{Downloader, Role};
+use iroh_bytes::downloader::Downloader;
 
 #[derive(strum::Display, Debug)]
 pub enum ToGossipActor {
@@ -179,7 +179,7 @@ impl GossipActor {
             // Inform the downloader that we now know that this peer has the content
             // for this hash.
             self.downloader
-                .nodes_have(hash, vec![(msg.delivered_from, Role::Provider).into()])
+                .nodes_have(hash, vec![msg.delivered_from])
                 .await;
         }
         Op::SyncReport(report) => {
diff --git a/iroh/src/sync_engine/live.rs b/iroh/src/sync_engine/live.rs
index bad616e6af..ac39a2aa89 100644
--- a/iroh/src/sync_engine/live.rs
+++ b/iroh/src/sync_engine/live.rs
@@ -4,7 +4,8 @@ use std::{collections::HashMap, time::SystemTime};
 
 use anyhow::{Context, Result};
 use futures::FutureExt;
-use iroh_bytes::downloader::{DownloadKind, Downloader, Role};
+use iroh_bytes::downloader::{DownloadRequest, Downloader};
+use iroh_bytes::HashAndFormat;
 use iroh_bytes::{store::EntryStatus, Hash};
 use iroh_gossip::{net::Gossip, proto::TopicId};
 use iroh_net::{key::PublicKey, MagicEndpoint, NodeAddr};
@@ -635,15 +636,13 @@ impl LiveActor {
         if matches!(entry_status, EntryStatus::NotFound | EntryStatus::Partial)
             && should_download
         {
-            let from = PublicKey::from_bytes(&from)?;
-            let role = match remote_content_status {
-                ContentStatus::Complete => Role::Provider,
-                _ => Role::Candidate,
+            let mut nodes = vec![];
+            if let ContentStatus::Complete = remote_content_status {
+                let node_id = PublicKey::from_bytes(&from)?;
+                nodes.push(node_id);
             };
-            let handle = self
-                .downloader
-                .queue(DownloadKind::Blob { hash }, vec![(from, role).into()])
-                .await;
+            let req = DownloadRequest::untagged(HashAndFormat::raw(hash), nodes);
+            let handle = self.downloader.queue(req).await;
             self.pending_downloads.spawn(async move {
                 // NOTE: this ignores the result for now, simply keeping the option
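Aside, not part of the patch: taken together, the queue-based downloader API can be exercised roughly like this (a sketch; assumes the `Stats` result exposes a `bytes_read` counter as in `iroh_bytes::get::Stats`, and that the handle's error converts into `anyhow::Error` as in `download_queued` above):

use iroh_bytes::downloader::{DownloadRequest, Downloader};
use iroh_bytes::{Hash, HashAndFormat};
use iroh_net::NodeId;

/// Queue a raw blob download from a single known provider and wait for it.
async fn fetch_blob(downloader: &Downloader, hash: Hash, provider: NodeId) -> anyhow::Result<()> {
    // `untagged`: no tag is created for the downloaded data.
    let req = DownloadRequest::untagged(HashAndFormat::raw(hash), vec![provider]);
    // `queue` returns a handle that resolves once the download finishes.
    let handle = downloader.queue(req).await;
    let stats = handle.await?;
    println!("downloaded {} bytes", stats.bytes_read);
    Ok(())
}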