diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 074a3bea4a..d027f93ccf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -197,6 +197,8 @@ jobs: steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable + with: + components: clippy - name: Install sccache uses: mozilla-actions/sccache-action@v0.0.4 diff --git a/CHANGELOG.md b/CHANGELOG.md index 689f157eb8..27a6234b3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,40 @@ All notable changes to iroh will be documented in this file. -## [0.17.0](https://github.com/n0-computer/iroh/compare/v0.16.0..0.17.0) - 2024-05-24 +## [0.18.0](https://github.com/n0-computer/iroh/compare/v0.17.0..0.18.0) - 2024-06-07 + +### โ›ฐ๏ธ Features + +- *(iroh-gossip)* Configure the max message size ([#2340](https://github.com/n0-computer/iroh/issues/2340)) - ([7153a38](https://github.com/n0-computer/iroh/commit/7153a38bc52a8cec877c8b874f37a37658b99370)) + +### ๐Ÿ› Bug Fixes + +- *(docs)* Prevent deadlocks with streams returned from docs actor ([#2346](https://github.com/n0-computer/iroh/issues/2346)) - ([98914ee](https://github.com/n0-computer/iroh/commit/98914ee4dcdb78f7477311f933d84f4f2478e168)) +- *(iroh-net)* Fix extra delay ([#2330](https://github.com/n0-computer/iroh/issues/2330)) - ([77f92ef](https://github.com/n0-computer/iroh/commit/77f92efd16e523c41b0e01aa5a7e11e9aae3e795)) +- *(iroh-net)* Return `Poll::Read(Ok(n))` when we have no relay URL or direct addresses in `poll_send` ([#2322](https://github.com/n0-computer/iroh/issues/2322)) - ([b2f0b0e](https://github.com/n0-computer/iroh/commit/b2f0b0eb84ef8f4a9962d540805a148a103d1e2b)) + +### ๐Ÿšœ Refactor + +- *(iroh)* [**breaking**] Replace public fields in iroh client with accessors and use ref-cast to eliminate them entirely ([#2350](https://github.com/n0-computer/iroh/issues/2350)) - ([35ce780](https://github.com/n0-computer/iroh/commit/35ce7805230ac7732a1bf3213be5424a1e019a44)) +- *(iroh)* [**breaking**] Remove tags from downloader ([#2348](https://github.com/n0-computer/iroh/issues/2348)) - ([82aa93f](https://github.com/n0-computer/iroh/commit/82aa93fc5e2f55499ab7d29b18029ae47c519c3a)) +- *(iroh-blobs)* [**breaking**] Make TempTag non-Clone ([#2338](https://github.com/n0-computer/iroh/issues/2338)) - ([d0662c2](https://github.com/n0-computer/iroh/commit/d0662c2d980b9fe28c669f2e6262c446d08bf7bf)) +- *(iroh-blobs)* [**breaking**] Implement some collection related things on the client side ([#2349](https://github.com/n0-computer/iroh/issues/2349)) - ([b047b28](https://github.com/n0-computer/iroh/commit/b047b28ddead8f357cb22c67c6e7ada23db5deb8)) +- Move docs engine into iroh-docs ([#2343](https://github.com/n0-computer/iroh/issues/2343)) - ([3772889](https://github.com/n0-computer/iroh/commit/3772889cd0a8e02731e5dc9c2a1e2f638ab2691a)) + +### ๐Ÿ“š Documentation + +- *(iroh-net)* Update toplevel module documentation ([#2329](https://github.com/n0-computer/iroh/issues/2329)) - ([4dd69f4](https://github.com/n0-computer/iroh/commit/4dd69f44d62e3b671339ce586a2f7e97a47559ff)) +- *(iroh-net)* Update endpoint docs ([#2334](https://github.com/n0-computer/iroh/issues/2334)) - ([8d91b10](https://github.com/n0-computer/iroh/commit/8d91b10e25e5a8363edde3c41a1bce4f9dc7455a)) + +### ๐Ÿงช Testing + +- Disable a flaky tests ([#2332](https://github.com/n0-computer/iroh/issues/2332)) - ([23e8c7b](https://github.com/n0-computer/iroh/commit/23e8c7b3d5cdc83783822e3fa10b09e798d24f22)) + +### โš™๏ธ Miscellaneous Tasks + +- *(ci)* Update clippy 
([#2351](https://github.com/n0-computer/iroh/issues/2351)) - ([7198cd0](https://github.com/n0-computer/iroh/commit/7198cd0f69cd0a178db3b71b7ee58ea5f285b95e)) + +## [0.17.0](https://github.com/n0-computer/iroh/compare/v0.16.0..v0.17.0) - 2024-05-24 ### โ›ฐ๏ธ Features @@ -42,6 +75,7 @@ All notable changes to iroh will be documented in this file. ### โš™๏ธ Miscellaneous Tasks - Minimize use of raw base32 in examples ([#2304](https://github.com/n0-computer/iroh/issues/2304)) - ([1fafc9e](https://github.com/n0-computer/iroh/commit/1fafc9ea8c8eb085f1c51ce8314d5f62f8d1b260)) +- Release - ([5ad15c8](https://github.com/n0-computer/iroh/commit/5ad15c8accc547fc33dd9e66839bd371834a3e35)) ## [0.16.0](https://github.com/n0-computer/iroh/compare/v0.15.0..v0.16.0) - 2024-05-13 diff --git a/Cargo.lock b/Cargo.lock index ba3e0e389d..4827a91fbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1011,16 +1011,15 @@ dependencies = [ [[package]] name = "curve25519-dalek" -version = "4.1.2" +version = "4.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a677b8922c94e01bdbb12126b0bc852f00447528dee1782229af9c720c3f348" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" dependencies = [ "cfg-if", "cpufeatures", "curve25519-dalek-derive", "digest", "fiat-crypto", - "platforms", "rustc_version", "subtle", "zeroize", @@ -2412,7 +2411,7 @@ checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "iroh" -version = "0.17.0" +version = "0.18.0" dependencies = [ "anyhow", "bao-tree", @@ -2444,6 +2443,7 @@ dependencies = [ "quic-rpc", "rand", "rand_chacha", + "ref-cast", "regex", "serde", "serde_json", @@ -2461,7 +2461,7 @@ dependencies = [ [[package]] name = "iroh-base" -version = "0.17.0" +version = "0.18.0" dependencies = [ "aead", "anyhow", @@ -2505,7 +2505,7 @@ dependencies = [ [[package]] name = "iroh-blobs" -version = "0.17.0" +version = "0.18.0" dependencies = [ "anyhow", "bao-tree", @@ -2554,7 +2554,7 @@ dependencies = [ [[package]] name = "iroh-cli" -version = "0.17.0" +version = "0.18.0" dependencies = [ "anyhow", "bao-tree", @@ -2608,7 +2608,7 @@ dependencies = [ [[package]] name = "iroh-dns-server" -version = "0.17.0" +version = "0.18.0" dependencies = [ "anyhow", "async-trait", @@ -2657,7 +2657,7 @@ dependencies = [ [[package]] name = "iroh-docs" -version = "0.17.0" +version = "0.18.0" dependencies = [ "anyhow", "bytes", @@ -2698,7 +2698,7 @@ dependencies = [ [[package]] name = "iroh-gossip" -version = "0.17.0" +version = "0.18.0" dependencies = [ "anyhow", "bytes", @@ -2742,7 +2742,7 @@ dependencies = [ [[package]] name = "iroh-metrics" -version = "0.17.0" +version = "0.18.0" dependencies = [ "anyhow", "erased_set", @@ -2761,7 +2761,7 @@ dependencies = [ [[package]] name = "iroh-net" -version = "0.17.0" +version = "0.18.0" dependencies = [ "aead", "anyhow", @@ -2851,7 +2851,7 @@ dependencies = [ [[package]] name = "iroh-net-bench" -version = "0.17.0" +version = "0.18.0" dependencies = [ "anyhow", "bytes", @@ -2917,7 +2917,7 @@ dependencies = [ [[package]] name = "iroh-test" -version = "0.17.0" +version = "0.18.0" dependencies = [ "anyhow", "tokio", @@ -3748,12 +3748,6 @@ dependencies = [ "spki", ] -[[package]] -name = "platforms" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db23d408679286588f4d4644f965003d056e3dd5abcaaa938116871d7ce2fee7" - [[package]] name = "plotters" version = "0.3.6" diff --git a/deny.toml b/deny.toml index f65fd56cb2..12a7d569e0 100644 --- 
a/deny.toml +++ b/deny.toml @@ -1,5 +1,9 @@ [bans] multiple-versions = "allow" +deny = [ + "openssl", + "native-tls", +] [licenses] allow = [ diff --git a/iroh-base/Cargo.toml b/iroh-base/Cargo.toml index 787e539bcb..bf67dcb454 100644 --- a/iroh-base/Cargo.toml +++ b/iroh-base/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-base" -version = "0.17.0" +version = "0.18.0" edition = "2021" readme = "README.md" description = "base type and utilities for Iroh" diff --git a/iroh-base/src/key.rs b/iroh-base/src/key.rs index a08e52f629..8032d2de04 100644 --- a/iroh-base/src/key.rs +++ b/iroh-base/src/key.rs @@ -94,6 +94,10 @@ pub struct PublicKey([u8; 32]); /// The identifier for a node in the (iroh) network. /// +/// Each node in iroh has a unique identifier created as a cryptographic key. This can be +/// used to globally identify a node. Since it is also a cryptographic key, it is the +/// mechanism by which all traffic is always encrypted for a specific node only. +/// /// This is equivalent to [`PublicKey`]. By convention we will (or should) use `PublicKey` /// as type name when performing cryptographic operations, but use `NodeId` when referencing /// a node. E.g.: diff --git a/iroh-base/src/node_addr.rs b/iroh-base/src/node_addr.rs index b926b3c2fc..a2b7c39f33 100644 --- a/iroh-base/src/node_addr.rs +++ b/iroh-base/src/node_addr.rs @@ -1,3 +1,11 @@ +//! Addressing for iroh nodes. +//! +//! This module contains some common addressing types for iroh. A node is uniquely +//! identified by the [`NodeId`], but that does not make it addressable on the network layer. +//! For this, the addition of a [`RelayUrl`] and/or direct addresses is required. +//! +//! The primary way of addressing a node is by using the [`NodeAddr`]. + use std::{collections::BTreeSet, fmt, net::SocketAddr, ops::Deref, str::FromStr}; use anyhow::Context; @@ -6,17 +14,40 @@ use url::Url; use crate::key::{NodeId, PublicKey}; -/// A peer and it's addressing information. +/// Network-level addressing information for an iroh-net node. +/// +/// This combines a node's identifier with network-level addressing information of how to +/// contact the node. +/// +/// To establish a network connection to a node, both the [`NodeId`] and one or more network +/// paths are needed. The network paths can come from various sources: +/// +/// - A [discovery] service which can provide routing information for a given [`NodeId`]. +/// +/// - A [`RelayUrl`] of the node's [home relay], which allows establishing the connection via +/// the Relay server and is very reliable. +/// +/// - One or more *direct addresses* on which the node might be reachable. Depending on the +/// network location of both nodes it might not be possible to establish a direct +/// connection without the help of a [Relay server]. +/// +/// This structure will always contain the required [`NodeId`] together with optional +/// network-level addressing information. It is a generic addressing type used +/// whenever a connection to other nodes needs to be established. +/// +/// [discovery]: https://docs.rs/iroh_net/*/iroh_net/index.html#node-discovery +/// [home relay]: https://docs.rs/iroh_net/*/iroh_net/relay/index.html +/// [Relay server]: https://docs.rs/iroh_net/*/iroh_net/index.html#relay-servers #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct NodeAddr { - /// The node's public key. - pub node_id: PublicKey, + /// The node's identifier. + pub node_id: NodeId, /// Addressing information to connect to [`Self::node_id`].
pub info: AddrInfo, } impl NodeAddr { - /// Create a new [`NodeAddr`] with empty [`AddrInfo`]. + /// Creates a new [`NodeAddr`] with empty [`AddrInfo`]. pub fn new(node_id: PublicKey) -> Self { NodeAddr { node_id, @@ -24,13 +55,13 @@ impl NodeAddr { } } - /// Add a relay url to the peer's [`AddrInfo`]. + /// Adds a relay url to the node's [`AddrInfo`]. pub fn with_relay_url(mut self, relay_url: RelayUrl) -> Self { self.info.relay_url = Some(relay_url); self } - /// Add the given direct addresses to the peer's [`AddrInfo`]. + /// Adds the given direct addresses to the node's [`AddrInfo`]. pub fn with_direct_addresses( mut self, addresses: impl IntoIterator<Item = SocketAddr>, ) -> Self { @@ -39,17 +70,38 @@ impl NodeAddr { self } - /// Apply the options to `self`. + /// Creates a new [`NodeAddr`] from its parts. + pub fn from_parts( + node_id: PublicKey, + relay_url: Option<RelayUrl>, + direct_addresses: Vec<SocketAddr>, + ) -> Self { + Self { + node_id, + info: AddrInfo { + relay_url, + direct_addresses: direct_addresses.into_iter().collect(), + }, + } + } + + /// Applies the options to `self`. + /// + /// This is used to more tightly control the information stored in a [`NodeAddr`] + /// received from another API. E.g. to ensure a [discovery] service is used, the + /// [`AddrInfoOptions::Id`] option could be used to remove all other addressing details. + /// + /// [discovery]: https://docs.rs/iroh_net/*/iroh_net/index.html#node-discovery pub fn apply_options(&mut self, opts: AddrInfoOptions) { self.info.apply_options(opts); } - /// Get the direct addresses of this peer. + /// Returns the direct addresses of this node. pub fn direct_addresses(&self) -> impl Iterator<Item = &SocketAddr> { self.info.direct_addresses.iter() } - /// Get the relay url of this peer. + /// Returns the relay url of this node. pub fn relay_url(&self) -> Option<&RelayUrl> { self.info.relay_url.as_ref() } @@ -74,22 +126,34 @@ impl From<(PublicKey, Option<RelayUrl>, &[SocketAddr])> for NodeAddr { } } -/// Addressing information to connect to a peer. +/// Network paths to contact an iroh-net node. +/// +/// This contains zero or more network paths to establish a connection to an iroh-net node. +/// Unless a [discovery service] is used, at least one path is required to connect to +/// another node; see [`NodeAddr`] for details. +/// +/// [discovery service]: https://docs.rs/iroh_net/*/iroh_net/index.html#node-discovery #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] pub struct AddrInfo { - /// The peer's home relay url. + /// The node's home relay url. pub relay_url: Option<RelayUrl>, /// Socket addresses where the peer might be reached directly. pub direct_addresses: BTreeSet<SocketAddr>, } impl AddrInfo { - /// Return whether this addressing information is empty. + /// Returns whether this addressing information is empty. pub fn is_empty(&self) -> bool { self.relay_url.is_none() && self.direct_addresses.is_empty() } - /// Apply the options to `self`. + /// Applies the options to `self`. + /// + /// This is used to more tightly control the information stored in an [`AddrInfo`] + /// received from another API. E.g. to ensure a [discovery] service is used, the + /// [`AddrInfoOptions::Id`] option could be used to remove all other addressing details. + /// + /// [discovery]: https://docs.rs/iroh_net/*/iroh_net/index.html#node-discovery pub fn apply_options(&mut self, opts: AddrInfoOptions) { match opts { AddrInfoOptions::Id => { @@ -109,24 +173,7 @@ impl AddrInfo { } } -impl NodeAddr { - /// Create a new [`NodeAddr`] from its parts.
- pub fn from_parts( - node_id: PublicKey, - relay_url: Option, - direct_addresses: Vec, - ) -> Self { - Self { - node_id, - info: AddrInfo { - relay_url, - direct_addresses: direct_addresses.into_iter().collect(), - }, - } - } -} - -/// Options to configure what is included in a `NodeAddr`. +/// Options to configure what is included in a [`NodeAddr`] and [`AddrInfo`]. #[derive( Copy, Clone, @@ -145,11 +192,11 @@ pub enum AddrInfoOptions { /// This usually means that iroh-dns discovery is used to find address information. #[default] Id, - /// Include both the relay URL and the direct addresses. + /// Includes both the relay URL and the direct addresses. RelayAndAddresses, - /// Only include the relay URL. + /// Only includes the relay URL. Relay, - /// Only include the direct addresses. + /// Only includes the direct addresses. Addresses, } @@ -186,7 +233,7 @@ impl From for RelayUrl { } } -/// This is a convenience only to directly parse strings. +/// Support for parsing strings directly. /// /// If you need more control over the error first create a [`Url`] and use [`RelayUrl::from`] /// instead. @@ -205,7 +252,7 @@ impl From for Url { } } -/// Dereference to the wrapped [`Url`]. +/// Dereferences to the wrapped [`Url`]. /// /// Note that [`DerefMut`] is not implemented on purpose, so this type has more flexibility /// to change the inner later. diff --git a/iroh-blobs/Cargo.toml b/iroh-blobs/Cargo.toml index 0d23446e2f..30d66c45f6 100644 --- a/iroh-blobs/Cargo.toml +++ b/iroh-blobs/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-blobs" -version = "0.17.0" +version = "0.18.0" edition = "2021" readme = "README.md" description = "blob and collection transfer support for iroh" @@ -27,10 +27,10 @@ futures-lite = "2.3" genawaiter = { version = "0.99.1", features = ["futures03"] } hashlink = { version = "0.9.0", optional = true } hex = "0.4.3" -iroh-base = { version = "0.17.0", features = ["redb"], path = "../iroh-base" } +iroh-base = { version = "0.18.0", features = ["redb"], path = "../iroh-base" } iroh-io = { version = "0.6.0", features = ["stats"] } -iroh-metrics = { version = "0.17.0", path = "../iroh-metrics", optional = true } -iroh-net = { version = "0.17.0", path = "../iroh-net" } +iroh-metrics = { version = "0.18.0", path = "../iroh-metrics", optional = true } +iroh-net = { version = "0.18.0", path = "../iroh-net" } num_cpus = "1.15.0" parking_lot = { version = "0.12.1", optional = true } postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] } diff --git a/iroh-blobs/src/downloader.rs b/iroh-blobs/src/downloader.rs index 53c70f4fe1..b0f605fe4e 100644 --- a/iroh-blobs/src/downloader.rs +++ b/iroh-blobs/src/downloader.rs @@ -51,8 +51,7 @@ use tracing::{debug, error_span, trace, warn, Instrument}; use crate::{ get::{db::DownloadProgress, Stats}, store::Store, - util::{progress::ProgressSender, SetTagOption, TagSet}, - TempTag, + util::progress::ProgressSender, }; mod get; @@ -188,14 +187,18 @@ impl Default for RetryConfig { pub struct DownloadRequest { kind: DownloadKind, nodes: Vec, - tag: Option, progress: Option, } impl DownloadRequest { /// Create a new download request. /// - /// The blob will be auto-tagged after the download to prevent it from being garbage collected. + /// It is the responsibility of the caller to ensure that the data is tagged either with a + /// temp tag or with a persistent tag to make sure the data is not garbage collected during + /// the download. 
+ /// + /// If this is not done, the download will proceed as normal, but there is no guarantee + /// that the data is still available when the download is complete. pub fn new( resource: impl Into<DownloadKind>, nodes: impl IntoIterator>, ) -> Self { @@ -203,30 +206,10 @@ impl DownloadRequest { Self { kind: resource.into(), nodes: nodes.into_iter().map(|n| n.into()).collect(), - tag: Some(SetTagOption::Auto), progress: None, } } - /// Create a new untagged download request. - /// - /// The blob will not be tagged, so only use this if the blob is already protected from garbage - /// collection through other means. - pub fn untagged( - resource: HashAndFormat, - nodes: impl IntoIterator>, - ) -> Self { - let mut r = Self::new(resource, nodes); - r.tag = None; - r - } - - /// Set a tag to apply to the blob after download. - pub fn tag(mut self, tag: SetTagOption) -> Self { - self.tag = Some(tag); - self - } - /// Pass a progress sender to receive progress updates. pub fn progress_sender(mut self, sender: ProgressSubscriber) -> Self { self.progress = Some(sender); @@ -351,14 +334,7 @@ impl Downloader { store: store.clone(), }; - let service = Service::new( - store, - getter, - dialer, - concurrency_limits, - retry_config, - msg_rx, - ); + let service = Service::new(getter, dialer, concurrency_limits, retry_config, msg_rx); service.run().instrument(error_span!("downloader", %me)) }; @@ -450,8 +426,6 @@ struct IntentHandlers { struct RequestInfo { /// Registered intents with progress senders and result callbacks. intents: HashMap<IntentId, IntentHandlers>, - /// Tags requested for the blob to be created once the download finishes. - tags: TagSet, } /// Information about a request in progress. @@ -462,8 +436,6 @@ struct ActiveRequestInfo { cancellation: CancellationToken, /// Peer doing this request attempt. node: NodeId, - /// Temporary tag to protect the partial blob from being garbage collected. - temp_tag: TempTag, } #[derive(Debug, Default)] @@ -531,7 +503,7 @@ enum NodeState<'a, Conn> { } #[derive(Debug)] -struct Service { +struct Service { /// The getter performs individual requests. getter: G, /// Map to query for nodes that we believe have the data we are looking for. @@ -562,12 +534,9 @@ struct Service { in_progress_downloads: JoinSet<(DownloadKind, InternalDownloadResult)>, /// Progress tracker progress_tracker: ProgressTracker, - /// The [`Store`] where tags are saved after a download completes.
- db: DB, } -impl, D: Dialer> Service { +impl, D: Dialer> Service { fn new( - db: DB, getter: G, dialer: D, concurrency_limits: ConcurrencyLimits, @@ -590,7 +559,6 @@ impl, D: Dialer> Service, D: Dialer> Service { trace!(%kind, "tick: transfer completed"); - self.on_download_completed(kind, result).await; + self.on_download_completed(kind, result); } Err(err) => { warn!(?err, "transfer task panicked"); @@ -679,7 +647,6 @@ impl, D: Dialer> Service>(), "queue intent"); @@ -732,9 +699,6 @@ impl, D: Dialer> Service, D: Dialer> Service, D: Dialer> Service, D: Dialer> Service, D: Dialer> Service, D: Dialer, S: Store> Service { +impl, D: Dialer> Service { /// Checks the various invariants the service must maintain #[track_caller] pub(in crate::downloader) fn check_invariants(&self) { diff --git a/iroh-blobs/src/downloader/test.rs b/iroh-blobs/src/downloader/test.rs index bdf55cc423..9901cdf2e4 100644 --- a/iroh-blobs/src/downloader/test.rs +++ b/iroh-blobs/src/downloader/test.rs @@ -39,14 +39,12 @@ impl Downloader { retry_config: RetryConfig, ) -> Self { let (msg_tx, msg_rx) = mpsc::channel(super::SERVICE_CHANNEL_CAPACITY); - let db = crate::store::mem::Store::default(); LocalPoolHandle::new(1).spawn_pinned(move || async move { // we want to see the logs of the service let _guard = iroh_test::logging::setup(); - let service = - Service::new(db, getter, dialer, concurrency_limits, retry_config, msg_rx); + let service = Service::new(getter, dialer, concurrency_limits, retry_config, msg_rx); service.run().await }); diff --git a/iroh-blobs/src/export.rs b/iroh-blobs/src/export.rs index 75b282fd6c..cdbda28881 100644 --- a/iroh-blobs/src/export.rs +++ b/iroh-blobs/src/export.rs @@ -46,7 +46,7 @@ pub async fn export_collection( progress: impl ProgressSender + IdGenerator, ) -> anyhow::Result<()> { tokio::fs::create_dir_all(&outpath).await?; - let collection = Collection::load(db, &hash).await?; + let collection = Collection::load_db(db, &hash).await?; for (name, hash) in collection.into_iter() { #[allow(clippy::needless_borrow)] let path = outpath.join(pathbuf_from_name(&name)); diff --git a/iroh-blobs/src/format/collection.rs b/iroh-blobs/src/format/collection.rs index ab13572cc1..cdf4448e98 100644 --- a/iroh-blobs/src/format/collection.rs +++ b/iroh-blobs/src/format/collection.rs @@ -1,5 +1,5 @@ //! The collection type used by iroh -use std::collections::BTreeMap; +use std::{collections::BTreeMap, future::Future}; use anyhow::Context; use bao_tree::blake3; @@ -64,6 +64,12 @@ impl IntoIterator for Collection { } } +/// A simple store trait for loading blobs +pub trait SimpleStore { + /// Load a blob from the store + fn load(&self, hash: Hash) -> impl Future> + Send + '_; +} + /// Metadata for a collection /// /// This is the wire format for the metadata blob. @@ -84,7 +90,7 @@ impl Collection { /// /// To persist the collection, write all the blobs to storage, and use the /// hash of the last blob as the collection hash. - pub fn to_blobs(&self) -> impl Iterator { + pub fn to_blobs(&self) -> impl DoubleEndedIterator { let meta = CollectionMeta { header: *Self::HEADER, names: self.names(), @@ -160,11 +166,25 @@ impl Collection { Ok((collection, res, stats)) } + /// Create a new collection from a hash sequence and metadata. 
+ pub async fn load(root: Hash, store: &impl SimpleStore) -> anyhow::Result { + let hs = store.load(root).await?; + let hs = HashSeq::try_from(hs)?; + let meta_hash = hs.iter().next().context("empty hash seq")?; + let meta = store.load(meta_hash).await?; + let meta: CollectionMeta = postcard::from_bytes(&meta)?; + anyhow::ensure!( + meta.names.len() + 1 == hs.len(), + "names and links length mismatch" + ); + Ok(Self::from_parts(hs.into_iter(), meta)) + } + /// Load a collection from a store given a root hash /// /// This assumes that both the links and the metadata of the collection is stored in the store. /// It does not require that all child blobs are stored in the store. - pub async fn load(db: &D, root: &Hash) -> anyhow::Result + pub async fn load_db(db: &D, root: &Hash) -> anyhow::Result where D: crate::store::Map, { diff --git a/iroh-blobs/src/store/fs.rs b/iroh-blobs/src/store/fs.rs index 5febe54457..e9e113a603 100644 --- a/iroh-blobs/src/store/fs.rs +++ b/iroh-blobs/src/store/fs.rs @@ -1486,6 +1486,8 @@ impl Actor { let mut msgs = PeekableFlumeReceiver::new(self.state.msgs.clone()); while let Some(msg) = msgs.recv() { if let ActorMessage::Shutdown { tx } = msg { + // Make sure the database is dropped before we send the reply. + drop(self); if let Some(tx) = tx { tx.send(()).ok(); } diff --git a/iroh-blobs/src/store/traits.rs b/iroh-blobs/src/store/traits.rs index e0ec3e6b39..2a91d1c0f3 100644 --- a/iroh-blobs/src/store/traits.rs +++ b/iroh-blobs/src/store/traits.rs @@ -295,7 +295,7 @@ pub trait ReadableStore: Map { } /// The mutable part of a Bao store. -pub trait Store: ReadableStore + MapMut { +pub trait Store: ReadableStore + MapMut + std::fmt::Debug { /// This trait method imports a file from a local path. /// /// `data` is the path to the file. diff --git a/iroh-blobs/src/util.rs b/iroh-blobs/src/util.rs index 751886492c..be43dfaaff 100644 --- a/iroh-blobs/src/util.rs +++ b/iroh-blobs/src/util.rs @@ -11,7 +11,7 @@ use std::{ time::SystemTime, }; -use crate::{store::Store, BlobFormat, Hash, HashAndFormat, IROH_BLOCK_SIZE}; +use crate::{BlobFormat, Hash, HashAndFormat, IROH_BLOCK_SIZE}; pub mod io; mod mem_or_file; @@ -126,55 +126,6 @@ impl Tag { } } -/// A set of merged [`SetTagOption`]s for a blob. -#[derive(Debug, Default)] -pub struct TagSet { - auto: bool, - named: Vec, -} - -impl TagSet { - /// Insert a new tag into the set. - pub fn insert(&mut self, tag: SetTagOption) { - match tag { - SetTagOption::Auto => self.auto = true, - SetTagOption::Named(tag) => { - if !self.named.iter().any(|t| t == &tag) { - self.named.push(tag) - } - } - } - } - - /// Convert the [`TagSet`] into a list of [`SetTagOption`]. - pub fn into_tags(self) -> impl Iterator { - self.auto - .then_some(SetTagOption::Auto) - .into_iter() - .chain(self.named.into_iter().map(SetTagOption::Named)) - } - - /// Apply the tags in the [`TagSet`] to the database. 
- pub async fn apply( - self, - db: &D, - hash_and_format: HashAndFormat, - ) -> std::io::Result<()> { - let tags = self.into_tags(); - for tag in tags { - match tag { - SetTagOption::Named(tag) => { - db.set_tag(tag, Some(hash_and_format)).await?; - } - SetTagOption::Auto => { - db.create_tag(hash_and_format).await?; - } - } - } - Ok(()) - } -} - /// Option for commands that allow setting a tag #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum SetTagOption { diff --git a/iroh-cli/Cargo.toml b/iroh-cli/Cargo.toml index 2a1f6c9d1e..9b0a30c306 100644 --- a/iroh-cli/Cargo.toml +++ b/iroh-cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-cli" -version = "0.17.0" +version = "0.18.0" edition = "2021" readme = "README.md" description = "Bytes. Distributed." @@ -40,9 +40,9 @@ futures-util = { version = "0.3.30", features = ["futures-sink"] } hex = "0.4.3" human-time = "0.1.6" indicatif = { version = "0.17", features = ["tokio"] } -iroh = { version = "0.17.0", path = "../iroh", features = ["metrics"] } -iroh-gossip = { version = "0.17.0", path = "../iroh-gossip" } -iroh-metrics = { version = "0.17.0", path = "../iroh-metrics" } +iroh = { version = "0.18.0", path = "../iroh", features = ["metrics"] } +iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" } +iroh-metrics = { version = "0.18.0", path = "../iroh-metrics" } parking_lot = "0.12.1" pkarr = { version = "1.1.5", default-features = false } portable-atomic = "1" diff --git a/iroh-cli/src/commands/author.rs b/iroh-cli/src/commands/author.rs index 8da797845c..2ad98a48b6 100644 --- a/iroh-cli/src/commands/author.rs +++ b/iroh-cli/src/commands/author.rs @@ -48,7 +48,7 @@ impl AuthorCommands { println!("Active author is now {}", fmt_short(author.as_bytes())); } Self::List => { - let mut stream = iroh.authors.list().await?; + let mut stream = iroh.authors().list().await?; while let Some(author_id) = stream.try_next().await? { println!("{}", author_id); } @@ -57,7 +57,7 @@ impl AuthorCommands { if switch && !env.is_console() { bail!("The --switch flag is only supported within the Iroh console."); } - let author_id = iroh.authors.default().await?; + let author_id = iroh.authors().default().await?; println!("{}", author_id); if switch { env.set_author(author_id)?; @@ -69,7 +69,7 @@ impl AuthorCommands { bail!("The --switch flag is only supported within the Iroh console."); } - let author_id = iroh.authors.create().await?; + let author_id = iroh.authors().create().await?; println!("{}", author_id); if switch { @@ -78,10 +78,10 @@ impl AuthorCommands { } } Self::Delete { author } => { - iroh.authors.delete(author).await?; + iroh.authors().delete(author).await?; println!("Deleted author {}", fmt_short(author.as_bytes())); } - Self::Export { author } => match iroh.authors.export(author).await? { + Self::Export { author } => match iroh.authors().export(author).await? 
{ Some(author) => { println!("{}", author); } @@ -92,7 +92,7 @@ impl AuthorCommands { Self::Import { author } => match Author::from_str(&author) { Ok(author) => { let id = author.id(); - iroh.authors.import(author).await?; + iroh.authors().import(author).await?; println!("Imported {}", fmt_short(id)); } Err(err) => { diff --git a/iroh-cli/src/commands/blob.rs b/iroh-cli/src/commands/blob.rs index 82ea5bd4e9..9e2c7208c6 100644 --- a/iroh-cli/src/commands/blob.rs +++ b/iroh-cli/src/commands/blob.rs @@ -262,7 +262,7 @@ impl BlobCommands { }; let mut stream = iroh - .blobs + .blobs() .download_with_opts( hash, DownloadOptions { @@ -281,7 +281,7 @@ impl BlobCommands { Some(OutputTarget::Stdout) => { // we asserted above that `OutputTarget::Stdout` is only permitted if getting a // single hash and not a hashseq. - let mut blob_read = iroh.blobs.read(hash).await?; + let mut blob_read = iroh.blobs().read(hash).await?; tokio::io::copy(&mut blob_read, &mut tokio::io::stdout()).await?; } Some(OutputTarget::Path(path)) => { @@ -299,7 +299,7 @@ impl BlobCommands { false => ExportFormat::Blob, }; tracing::info!("exporting to {} -> {}", path.display(), absolute.display()); - let stream = iroh.blobs.export(hash, absolute, format, mode).await?; + let stream = iroh.blobs().export(hash, absolute, format, mode).await?; // TODO: report export progress stream.await?; @@ -320,7 +320,7 @@ impl BlobCommands { !recursive, "Recursive option is not supported when exporting to STDOUT" ); - let mut blob_read = iroh.blobs.read(hash).await?; + let mut blob_read = iroh.blobs().read(hash).await?; tokio::io::copy(&mut blob_read, &mut tokio::io::stdout()).await?; } OutputTarget::Path(path) => { @@ -341,7 +341,7 @@ impl BlobCommands { path.display(), absolute.display() ); - let stream = iroh.blobs.export(hash, absolute, format, mode).await?; + let stream = iroh.blobs().export(hash, absolute, format, mode).await?; // TODO: report export progress stream.await?; } @@ -369,8 +369,8 @@ impl BlobCommands { } else { BlobFormat::Raw }; - let status = iroh.blobs.status(hash).await?; - let ticket = iroh.blobs.share(hash, format, addr_options).await?; + let status = iroh.blobs().status(hash).await?; + let ticket = iroh.blobs().share(hash, format, addr_options).await?; let (blob_status, size) = match (status, format) { (BlobStatus::Complete { size }, BlobFormat::Raw) => ("blob", size), @@ -453,21 +453,21 @@ impl ListCommands { { match self { Self::Blobs => { - let mut response = iroh.blobs.list().await?; + let mut response = iroh.blobs().list().await?; while let Some(item) = response.next().await { let BlobInfo { path, hash, size } = item?; println!("{} {} ({})", path, hash, HumanBytes(size)); } } Self::IncompleteBlobs => { - let mut response = iroh.blobs.list_incomplete().await?; + let mut response = iroh.blobs().list_incomplete().await?; while let Some(item) = response.next().await { let IncompleteBlobInfo { hash, size, .. 
} = item?; println!("{} ({})", hash, HumanBytes(size)); } } Self::Collections => { - let mut response = iroh.blobs.list_collections().await?; + let mut response = iroh.blobs().list_collections()?; while let Some(item) = response.next().await { let CollectionInfo { tag, @@ -513,7 +513,7 @@ impl DeleteCommands { { match self { Self::Blob { hash } => { - let response = iroh.blobs.delete_blob(hash).await; + let response = iroh.blobs().delete_blob(hash).await; if let Err(e) = response { eprintln!("Error: {}", e); } @@ -544,7 +544,7 @@ pub async fn consistency_check(iroh: &Iroh, verbose: u8, repair: bool) -> where C: ServiceConnection, { - let mut response = iroh.blobs.consistency_check(repair).await?; + let mut response = iroh.blobs().consistency_check(repair).await?; let verbosity = get_report_level(verbose); let print = |level: ReportLevel, entry: Option, message: String| { if level < verbosity { @@ -589,7 +589,7 @@ where C: ServiceConnection, { let mut state = ValidateProgressState::new(); - let mut response = iroh.blobs.validate(repair).await?; + let mut response = iroh.blobs().validate(repair).await?; let verbosity = get_report_level(verbose); let print = |level: ReportLevel, entry: Option, message: String| { if level < verbosity { @@ -854,7 +854,7 @@ pub async fn add>( // tell the node to add the data let stream = client - .blobs + .blobs() .add_from_path(absolute, in_place, tag, wrap) .await?; aggregate_add_response(stream).await? @@ -872,7 +872,7 @@ pub async fn add>( // tell the node to add the data let stream = client - .blobs + .blobs() .add_from_path(path_buf, false, tag, wrap) .await?; aggregate_add_response(stream).await? diff --git a/iroh-cli/src/commands/doc.rs b/iroh-cli/src/commands/doc.rs index 7c6465b592..b2a13b3596 100644 --- a/iroh-cli/src/commands/doc.rs +++ b/iroh-cli/src/commands/doc.rs @@ -317,7 +317,7 @@ impl DocCommands { bail!("The --switch flag is only supported within the Iroh console."); } - let doc = iroh.docs.create().await?; + let doc = iroh.docs().create().await?; println!("{}", doc.id()); if switch { @@ -330,7 +330,7 @@ impl DocCommands { bail!("The --switch flag is only supported within the Iroh console."); } - let doc = iroh.docs.import(ticket).await?; + let doc = iroh.docs().import(ticket).await?; println!("{}", doc.id()); if switch { @@ -339,7 +339,7 @@ impl DocCommands { } } Self::List => { - let mut stream = iroh.docs.list().await?; + let mut stream = iroh.docs().list().await?; while let Some((id, kind)) = stream.try_next().await? { println!("{id} {kind}") } @@ -483,7 +483,7 @@ impl DocCommands { } let stream = iroh - .blobs + .blobs() .add_from_path( root.clone(), in_place, @@ -627,7 +627,7 @@ impl DocCommands { .interact() .unwrap_or(false) { - iroh.docs.drop_doc(doc.id()).await?; + iroh.docs().drop_doc(doc.id()).await?; println!("Doc {} has been deleted.", fmt_short(doc.id())); } else { println!("Aborted.") @@ -681,7 +681,7 @@ async fn get_doc( where C: ServiceConnection, { - iroh.docs + iroh.docs() .open(env.doc(id)?) .await? 
.context("Document not found") @@ -975,8 +975,8 @@ mod tests { let node = crate::commands::start::start_node(data_dir.path(), None).await?; let client = node.client(); - let doc = client.docs.create().await.context("doc create")?; - let author = client.authors.create().await.context("author create")?; + let doc = client.docs().create().await.context("doc create")?; + let author = client.authors().create().await.context("author create")?; // set up command, getting iroh node let cli = ConsoleEnv::for_console(data_dir.path().to_owned(), &node) diff --git a/iroh-cli/src/commands/doctor.rs b/iroh-cli/src/commands/doctor.rs index 156bb4dd9d..a28f749cf6 100644 --- a/iroh-cli/src/commands/doctor.rs +++ b/iroh-cli/src/commands/doctor.rs @@ -27,7 +27,7 @@ use iroh::{ }, docs::{Capability, DocTicket}, net::{ - defaults::DEFAULT_RELAY_STUN_PORT, + defaults::DEFAULT_STUN_PORT, discovery::{ dns::DnsDiscovery, pkarr_publish::PkarrPublisher, ConcurrentDiscovery, Discovery, }, @@ -93,7 +93,7 @@ pub enum Commands { #[clap(long)] stun_host: Option, /// The port of the STUN server. - #[clap(long, default_value_t = DEFAULT_RELAY_STUN_PORT)] + #[clap(long, default_value_t = DEFAULT_STUN_PORT)] stun_port: u16, }, /// Wait for incoming requests from iroh doctor connect @@ -631,7 +631,7 @@ async fn passive_side(gui: Gui, connection: Connection) -> anyhow::Result<()> { } fn configure_local_relay_map() -> RelayMap { - let stun_port = DEFAULT_RELAY_STUN_PORT; + let stun_port = DEFAULT_STUN_PORT; let url = "http://localhost:3340".parse().unwrap(); RelayMap::default_from_node(url, stun_port) } @@ -669,7 +669,7 @@ async fn make_endpoint( }; let endpoint = endpoint.bind(0).await?; - tokio::time::timeout(Duration::from_secs(10), endpoint.local_endpoints().next()) + tokio::time::timeout(Duration::from_secs(10), endpoint.direct_addresses().next()) .await .context("wait for relay connection")? 
.context("no endpoints")?; @@ -692,7 +692,7 @@ async fn connect( let conn = endpoint.connect(node_addr, &DR_RELAY_ALPN).await; match conn { Ok(connection) => { - let maybe_stream = endpoint.conn_type_stream(&node_id); + let maybe_stream = endpoint.conn_type_stream(node_id); let gui = Gui::new(endpoint, node_id); if let Ok(stream) = maybe_stream { log_connection_changes(gui.mp.clone(), node_id, stream); @@ -727,7 +727,7 @@ async fn accept( ) -> anyhow::Result<()> { let endpoint = make_endpoint(secret_key.clone(), relay_map, discovery).await?; let endpoints = endpoint - .local_endpoints() + .direct_addresses() .next() .await .context("no endpoints")?; @@ -742,7 +742,7 @@ async fn accept( secret_key.public(), remote_addrs, ); - if let Some(relay_url) = endpoint.my_relay() { + if let Some(relay_url) = endpoint.home_relay() { println!( "\tUsing just the relay url:\niroh doctor connect {} --relay-url {}\n", secret_key.public(), @@ -770,7 +770,7 @@ async fn accept( println!("Accepted connection from {}", remote_peer_id); let t0 = Instant::now(); let gui = Gui::new(endpoint.clone(), remote_peer_id); - if let Ok(stream) = endpoint.conn_type_stream(&remote_peer_id) { + if let Ok(stream) = endpoint.conn_type_stream(remote_peer_id) { log_connection_changes(gui.mp.clone(), remote_peer_id, stream); } let res = active_side(connection, &config, Some(&gui)).await; diff --git a/iroh-cli/src/commands/gossip.rs b/iroh-cli/src/commands/gossip.rs index 858ef5c9c5..67a486ce8a 100644 --- a/iroh-cli/src/commands/gossip.rs +++ b/iroh-cli/src/commands/gossip.rs @@ -62,7 +62,7 @@ impl GossipCommands { subscription_capacity: 1024, }; - let (mut sink, mut stream) = iroh.gossip.subscribe_with_opts(topic, opts).await?; + let (mut sink, mut stream) = iroh.gossip().subscribe_with_opts(topic, opts).await?; let mut input_lines = tokio::io::BufReader::new(tokio::io::stdin()).lines(); loop { tokio::select! 
{ diff --git a/iroh-cli/src/commands/tag.rs b/iroh-cli/src/commands/tag.rs index 3d995d5a52..42c228266b 100644 --- a/iroh-cli/src/commands/tag.rs +++ b/iroh-cli/src/commands/tag.rs @@ -26,7 +26,7 @@ impl TagCommands { { match self { Self::List => { - let mut response = iroh.tags.list().await?; + let mut response = iroh.tags().list().await?; while let Some(res) = response.next().await { let res = res?; println!("{}: {} ({:?})", res.name, res.hash, res.format); @@ -38,7 +38,7 @@ impl TagCommands { } else { Tag::from(tag) }; - iroh.tags.delete(tag).await?; + iroh.tags().delete(tag).await?; } } Ok(()) diff --git a/iroh-cli/src/config.rs b/iroh-cli/src/config.rs index 861c2ec5ad..249b91af10 100644 --- a/iroh-cli/src/config.rs +++ b/iroh-cli/src/config.rs @@ -293,7 +293,7 @@ async fn env_author>( { Ok(author) } else { - iroh.authors.default().await + iroh.authors().default().await } } diff --git a/iroh-dns-server/Cargo.toml b/iroh-dns-server/Cargo.toml index 7b1de803c2..e42a9a8038 100644 --- a/iroh-dns-server/Cargo.toml +++ b/iroh-dns-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-dns-server" -version = "0.17.0" +version = "0.18.0" edition = "2021" description = "A pkarr relay and DNS server" license = "MIT OR Apache-2.0" @@ -24,7 +24,7 @@ governor = "0.6.3" hickory-proto = "0.24.0" hickory-server = { version = "0.24.0", features = ["dns-over-rustls"] } http = "1.0.0" -iroh-metrics = { version = "0.17.0", path = "../iroh-metrics" } +iroh-metrics = { version = "0.18.0", path = "../iroh-metrics" } lru = "0.12.3" parking_lot = "0.12.1" pkarr = { version = "1.1.4", features = [ "async", "relay", "dht"], default-features = false } @@ -52,7 +52,7 @@ z32 = "1.1.1" [dev-dependencies] hickory-resolver = "0.24.0" -iroh-net = { version = "0.17.0", path = "../iroh-net" } +iroh-net = { version = "0.18.0", path = "../iroh-net" } iroh-test = { path = "../iroh-test" } mainline = "<1.5.0" diff --git a/iroh-docs/Cargo.toml b/iroh-docs/Cargo.toml index 005d2e1ea3..28e7c3505e 100644 --- a/iroh-docs/Cargo.toml +++ b/iroh-docs/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-docs" -version = "0.17.0" +version = "0.18.0" edition = "2021" readme = "README.md" description = "Iroh sync" @@ -23,13 +23,13 @@ ed25519-dalek = { version = "2.0.0", features = ["serde", "rand_core"] } flume = "0.11" futures-buffered = "0.2.4" futures-lite = "2.3.0" -futures-util = { version = "0.3.25", optional = true } +futures-util = { version = "0.3.25" } hex = "0.4" -iroh-base = { version = "0.17.0", path = "../iroh-base" } -iroh-blobs = { version = "0.17.0", path = "../iroh-blobs", optional = true, features = ["downloader"] } -iroh-gossip = { version = "0.17.0", path = "../iroh-gossip", optional = true } -iroh-metrics = { version = "0.17.0", path = "../iroh-metrics", optional = true } -iroh-net = { version = "0.17.0", optional = true, path = "../iroh-net" } +iroh-base = { version = "0.18.0", path = "../iroh-base" } +iroh-blobs = { version = "0.18.0", path = "../iroh-blobs", optional = true, features = ["downloader"] } +iroh-gossip = { version = "0.18.0", path = "../iroh-gossip", optional = true } +iroh-metrics = { version = "0.18.0", path = "../iroh-metrics", optional = true } +iroh-net = { version = "0.18.0", optional = true, path = "../iroh-net" } lru = "0.12" num_enum = "0.7" postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] } @@ -42,7 +42,7 @@ serde = { version = "1.0.164", features = ["derive"] } strum = { version = "0.25", features = ["derive"] } tempfile = { 
version = "3.4" } thiserror = "1" -tokio = { version = "1", features = ["sync"] } +tokio = { version = "1", features = ["sync", "rt", "time", "macros"] } tokio-stream = { version = "0.1", optional = true, features = ["sync"]} tokio-util = { version = "0.7", optional = true, features = ["codec", "io-util", "io"] } tracing = "0.1" @@ -57,7 +57,7 @@ test-strategy = "0.3.1" [features] default = ["net", "metrics", "engine"] -net = ["dep:iroh-net", "tokio/io-util", "dep:tokio-stream", "dep:tokio-util", "dep:futures-util"] +net = ["dep:iroh-net", "tokio/io-util", "dep:tokio-stream", "dep:tokio-util"] metrics = ["dep:iroh-metrics"] engine = ["net", "dep:iroh-gossip", "dep:iroh-blobs"] diff --git a/iroh-docs/src/actor.rs b/iroh-docs/src/actor.rs index bbe91181cb..a48e8f55b3 100644 --- a/iroh-docs/src/actor.rs +++ b/iroh-docs/src/actor.rs @@ -10,9 +10,10 @@ use std::{ use anyhow::{anyhow, Context, Result}; use bytes::Bytes; +use futures_util::FutureExt; use iroh_base::hash::Hash; use serde::{Deserialize, Serialize}; -use tokio::sync::oneshot; +use tokio::{sync::oneshot, task::JoinSet}; use tracing::{debug, error, error_span, trace, warn}; use crate::{ @@ -253,6 +254,7 @@ impl SyncHandle { states: Default::default(), action_rx, content_status_callback, + tasks: Default::default(), }; let join_handle = std::thread::Builder::new() .name("sync-actor".to_string()) @@ -570,22 +572,37 @@ struct Actor { states: OpenReplicas, action_rx: flume::Receiver, content_status_callback: Option, + tasks: JoinSet<()>, } impl Actor { - fn run(mut self) -> Result<()> { + fn run(self) -> Result<()> { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_time() + .build()?; + let local_set = tokio::task::LocalSet::new(); + local_set.block_on(&rt, async move { self.run_async().await }) + } + async fn run_async(mut self) -> Result<()> { loop { - let action = match self.action_rx.recv_timeout(MAX_COMMIT_DELAY) { - Ok(action) => action, - Err(flume::RecvTimeoutError::Timeout) => { + let timeout = tokio::time::sleep(MAX_COMMIT_DELAY); + tokio::pin!(timeout); + let action = tokio::select! 
{ + _ = &mut timeout => { if let Err(cause) = self.store.flush() { error!(?cause, "failed to flush store"); } continue; } - Err(flume::RecvTimeoutError::Disconnected) => { - debug!("action channel disconnected"); - break; + action = self.action_rx.recv_async() => { + match action { + Ok(action) => action, + Err(flume::RecvError::Disconnected) => { + debug!("action channel disconnected"); + break; + } + + } } }; trace!(%action, "tick"); @@ -607,6 +624,7 @@ impl Actor { } } } + self.tasks.abort_all(); debug!("shutdown"); Ok(()) } @@ -636,13 +654,21 @@ impl Actor { } Ok(id) }), - Action::ListAuthors { reply } => iter_to_channel( - reply, - self.store + Action::ListAuthors { reply } => { + let iter = self + .store .list_authors() - .map(|a| a.map(|a| a.map(|a| a.id()))), - ), - Action::ListReplicas { reply } => iter_to_channel(reply, self.store.list_namespaces()), + .map(|a| a.map(|a| a.map(|a| a.id()))); + self.tasks + .spawn_local(iter_to_channel_async(reply, iter).map(|_| ())); + Ok(()) + } + Action::ListReplicas { reply } => { + let iter = self.store.list_namespaces(); + self.tasks + .spawn_local(iter_to_channel_async(reply, iter).map(|_| ())); + Ok(()) + } Action::ContentHashes { reply } => { send_reply_with(reply, self, |this| this.store.content_hashes()) } @@ -657,7 +683,9 @@ impl Actor { ) -> Result<(), SendReplyError> { match action { ReplicaAction::Open { reply, opts } => { + tracing::trace!("open in"); let res = self.open(namespace, opts); + tracing::trace!("open out"); send_reply(reply, res) } ReplicaAction::Close { reply } => { @@ -759,7 +787,9 @@ impl Actor { .states .ensure_open(&namespace) .and_then(|_| self.store.get_many(namespace, query)); - iter_to_channel(reply, iter) + self.tasks + .spawn_local(iter_to_channel_async(reply, iter).map(|_| ())); + Ok(()) } ReplicaAction::DropReplica { reply } => send_reply_with(reply, self, |this| { this.close(namespace); @@ -921,15 +951,18 @@ impl OpenReplicas { } } -fn iter_to_channel( +async fn iter_to_channel_async( channel: flume::Sender>, iter: Result>>, ) -> Result<(), SendReplyError> { match iter { - Err(err) => channel.send(Err(err)).map_err(send_reply_error)?, + Err(err) => channel + .send_async(Err(err)) + .await + .map_err(send_reply_error)?, Ok(iter) => { for item in iter { - channel.send(item).map_err(send_reply_error)?; + channel.send_async(item).await.map_err(send_reply_error)?; } } } diff --git a/iroh-docs/src/engine/live.rs b/iroh-docs/src/engine/live.rs index 5c7608722b..e7f77549b0 100644 --- a/iroh-docs/src/engine/live.rs +++ b/iroh-docs/src/engine/live.rs @@ -31,6 +31,9 @@ use crate::{ use super::gossip::{GossipActor, ToGossipActor}; use super::state::{NamespaceStates, Origin, SyncReason}; +/// Name used for logging when new node addresses are added from the docs engine. +const SOURCE_NAME: &str = "docs_engine"; + /// An iroh-docs operation /// /// This is the message that is broadcast over iroh-gossip. 
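The `SOURCE_NAME` constant added above is threaded into iroh-net's address book in the hunk that follows: every address the docs engine learns is recorded together with a static source label, so logs can show where an address came from. A minimal sketch of the pattern, assuming an already-bound `iroh_net::Endpoint` (the helper function is hypothetical; only the `add_node_addr_with_source` call mirrors the diff):

```rust
use iroh_net::{Endpoint, NodeAddr};

// Label recorded alongside addresses added by this subsystem.
const SOURCE_NAME: &str = "docs_engine";

// Hypothetical helper mirroring the behaviour in the hunk below: add a
// learned peer address tagged with its source, treating failure as non-fatal.
fn add_learned_peer(endpoint: &Endpoint, peer: NodeAddr) {
    let peer_id = peer.node_id;
    if let Err(err) = endpoint.add_node_addr_with_source(peer, SOURCE_NAME) {
        // The peer may still be reachable via other paths, so only warn.
        tracing::warn!(peer = %peer_id.fmt_short(), "failed to add known addrs: {err:?}");
    }
}
```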
@@ -437,7 +440,7 @@ impl LiveActor { // add addresses of peers to our endpoint address book for peer in peers.into_iter() { let peer_id = peer.node_id; - if let Err(err) = self.endpoint.add_node_addr(peer) { + if let Err(err) = self.endpoint.add_node_addr_with_source(peer, SOURCE_NAME) { warn!(peer = %peer_id.fmt_short(), "failed to add known addrs: {err:?}"); } } @@ -543,7 +546,7 @@ impl LiveActor { match details .outcome .heads_received - .encode(Some(iroh_gossip::net::MAX_MESSAGE_SIZE)) + .encode(Some(self.gossip.max_message_size())) { Err(err) => warn!(?err, "Failed to encode author heads for sync report"), Ok(heads) => { @@ -738,7 +741,7 @@ impl LiveActor { self.queued_hashes.insert(hash, namespace); self.downloader.nodes_have(hash, vec![node]).await; } else if !only_if_missing || self.missing_hashes.contains(&hash) { - let req = DownloadRequest::untagged(HashAndFormat::raw(hash), vec![node]); + let req = DownloadRequest::new(HashAndFormat::raw(hash), vec![node]); let handle = self.downloader.queue(req).await; self.queued_hashes.insert(hash, namespace); diff --git a/iroh-docs/src/store/fs.rs b/iroh-docs/src/store/fs.rs index ab1171b756..981143ca86 100644 --- a/iroh-docs/src/store/fs.rs +++ b/iroh-docs/src/store/fs.rs @@ -154,6 +154,22 @@ impl Store { } } + /// Get an owned read-only snapshot of the database. + /// + /// This will open a new read transaction. The read transaction won't be reused for other + /// reads. + /// + /// This has the side effect of committing any open write transaction, + /// so it can be used as a way to ensure that the data is persisted. + pub fn snapshot_owned(&mut self) -> Result<ReadOnlyTables> { + // make sure the current transaction is committed + self.flush()?; + assert!(matches!(self.transaction, CurrentTransaction::None)); + let tx = self.db.begin_read()?; + let tables = ReadOnlyTables::new(tx)?; + Ok(tables) + } + /// Get access to the tables to read from them. /// /// The underlying transaction is a write transaction, but with a non-mut @@ -223,8 +239,6 @@ impl Store { } } -type AuthorsIter = std::vec::IntoIter<Result<Author>>; -type NamespaceIter = std::vec::IntoIter<Result<(NamespaceId, CapabilityKind)>>; type PeersIter = std::vec::IntoIter<PeerIdBytes>; impl Store { @@ -297,18 +311,16 @@ impl Store { } /// List all replica namespaces in this store. - pub fn list_namespaces(&mut self) -> Result<NamespaceIter> { - // TODO: avoid collect - let tables = self.tables()?; - let namespaces: Vec<_> = tables - .namespaces - .iter()? - .map(|res| { - let capability = parse_capability(res?.1.value())?; - Ok((capability.id(), capability.kind())) - }) - .collect(); - Ok(namespaces.into_iter()) + pub fn list_namespaces( + &mut self, + ) -> Result<impl Iterator<Item = Result<(NamespaceId, CapabilityKind)>>> { + let snapshot = self.snapshot()?; + let iter = snapshot.namespaces.range::<&'static [u8; 32]>(..)?; + let iter = iter.map(|res| { + let capability = parse_capability(res?.1.value())?; + Ok((capability.id(), capability.kind())) + }); + Ok(iter) } /// Get an author key from the store. @@ -340,19 +352,16 @@ impl Store { } /// List all author keys in this store. - pub fn list_authors(&mut self) -> Result<AuthorsIter> { - // TODO: avoid collect - let tables = self.tables()?; - let authors: Vec<_> = tables + pub fn list_authors(&mut self) -> Result<impl Iterator<Item = Result<Author>>> { + let tables = self.snapshot()?; + let iter = tables .authors - .iter()? + .range::<&'static [u8; 32]>(..)? .map(|res| match res { Ok((_key, value)) => Ok(Author::from_bytes(value.value())), Err(err) => Err(err.into()), - }) - .collect(); - - Ok(authors.into_iter()) + }); + Ok(iter) } /// Import a new replica namespace.
@@ -413,7 +422,8 @@ impl Store { namespace: NamespaceId, query: impl Into<Query>, ) -> Result<QueryIterator> { - QueryIterator::new(self.tables()?, namespace, query.into()) + let tables = self.snapshot_owned()?; + QueryIterator::new(tables, namespace, query.into()) } /// Get an entry by key and author. @@ -435,13 +445,8 @@ impl Store { /// Get all content hashes of all replicas in the store. pub fn content_hashes(&mut self) -> Result<ContentHashesIterator> { - // make sure the current transaction is committed - self.flush()?; - assert!(matches!(self.transaction, CurrentTransaction::None)); - let tx = self.db.begin_read()?; - let tables = ReadOnlyTables::new(tx)?; - let records = tables.records; - ContentHashesIterator::all(records) + let tables = self.snapshot_owned()?; + ContentHashesIterator::all(&tables.records) } /// Get the latest entry for each author in a namespace. @@ -870,14 +875,6 @@ impl Iterator for ParentIterator { } } -self_cell::self_cell!( - struct ContentHashesIteratorInner { - owner: RecordsTable, - #[covariant] - dependent: RecordsRange, - } -); - /// Iterator for all content hashes /// /// Note that you might get duplicate hashes. Also, the iterator will keep /// a database snapshot open until it is dropped. /// /// Also, this represents a snapshot of the database at the time of creation. /// It needs a copy of a redb::ReadOnlyTable to be self-contained. #[derive(derive_more::Debug)] -pub struct ContentHashesIterator(#[debug(skip)] ContentHashesIteratorInner); +pub struct ContentHashesIterator { + #[debug(skip)] + range: RecordsRange<'static>, +} impl ContentHashesIterator { /// Create a new iterator over all content hashes. - pub fn all(owner: RecordsTable) -> anyhow::Result<Self> { - let inner = ContentHashesIteratorInner::try_new(owner, |owner| RecordsRange::all(owner))?; - Ok(Self(inner)) + pub fn all(table: &RecordsTable) -> anyhow::Result<Self> { + let range = RecordsRange::all_static(table)?; + Ok(Self { range }) } } @@ -900,7 +900,7 @@ impl Iterator for ContentHashesIterator { type Item = Result<Hash>; fn next(&mut self) -> Option<Self::Item> { - let v = self.0.with_dependent_mut(|_, d| d.next())?; + let v = self.range.next()?; Some(v.map(|e| e.content_hash())) } } diff --git a/iroh-docs/src/store/fs/query.rs b/iroh-docs/src/store/fs/query.rs index a73dbcd8e7..f05b4ecfb3 100644 --- a/iroh-docs/src/store/fs/query.rs +++ b/iroh-docs/src/store/fs/query.rs @@ -3,6 +3,7 @@ use iroh_base::hash::Hash; use crate::{ store::{ + fs::tables::ReadOnlyTables, util::{IndexKind, LatestPerKeySelector, SelectorRes}, AuthorFilter, KeyFilter, Query, }, @@ -12,34 +13,33 @@ use crate::{ use super::{ bounds::{ByKeyBounds, RecordsBounds}, ranges::{RecordsByKeyRange, RecordsRange}, - tables::Tables, RecordsValue, }; /// A query iterator for entry queries.
#[derive(Debug)] -pub struct QueryIterator<'a> { - range: QueryRange<'a>, +pub struct QueryIterator { + range: QueryRange, query: Query, offset: u64, count: u64, } #[derive(Debug)] -enum QueryRange<'a> { +enum QueryRange { AuthorKey { - range: RecordsRange<'a>, + range: RecordsRange<'static>, key_filter: KeyFilter, }, KeyAuthor { - range: RecordsByKeyRange<'a>, + range: RecordsByKeyRange, author_filter: AuthorFilter, selector: Option, }, } -impl<'a> QueryIterator<'a> { - pub fn new(tables: &'a Tables<'a>, namespace: NamespaceId, query: Query) -> Result { +impl QueryIterator { + pub fn new(tables: ReadOnlyTables, namespace: NamespaceId, query: Query) -> Result { let index_kind = IndexKind::from(&query); let range = match index_kind { IndexKind::AuthorKey { range, key_filter } => { @@ -53,7 +53,7 @@ impl<'a> QueryIterator<'a> { // no author set => full table scan with the provided key filter AuthorFilter::Any => (RecordsBounds::namespace(namespace), key_filter), }; - let range = RecordsRange::with_bounds(&tables.records, bounds)?; + let range = RecordsRange::with_bounds_static(&tables.records, bounds)?; QueryRange::AuthorKey { range, key_filter: filter, @@ -65,11 +65,8 @@ impl<'a> QueryIterator<'a> { latest_per_key, } => { let bounds = ByKeyBounds::new(namespace, &range); - let range = RecordsByKeyRange::with_bounds( - &tables.records_by_key, - &tables.records, - bounds, - )?; + let range = + RecordsByKeyRange::with_bounds(tables.records_by_key, tables.records, bounds)?; let selector = latest_per_key.then(LatestPerKeySelector::default); QueryRange::KeyAuthor { author_filter, @@ -88,7 +85,7 @@ impl<'a> QueryIterator<'a> { } } -impl<'a> Iterator for QueryIterator<'a> { +impl Iterator for QueryIterator { type Item = Result; fn next(&mut self) -> Option> { diff --git a/iroh-docs/src/store/fs/ranges.rs b/iroh-docs/src/store/fs/ranges.rs index 9219c620ac..f28d95ae63 100644 --- a/iroh-docs/src/store/fs/ranges.rs +++ b/iroh-docs/src/store/fs/ranges.rs @@ -1,6 +1,6 @@ //! Ranges and helpers for working with [`redb`] tables -use redb::{Key, Range, ReadableTable, Table, Value}; +use redb::{Key, Range, ReadOnlyTable, ReadableTable, Value}; use crate::{store::SortDirection, SignedEntry}; @@ -74,14 +74,9 @@ impl<'a, K: Key + 'static, V: Value + 'static> RangeExt for Range<'a, K, V #[debug("RecordsRange")] pub struct RecordsRange<'a>(Range<'a, RecordsId<'static>, RecordsValue<'static>>); -impl<'a> RecordsRange<'a> { - pub(super) fn all( - records: &'a impl ReadableTable, RecordsValue<'static>>, - ) -> anyhow::Result { - let range = records.range::>(..)?; - Ok(Self(range)) - } +// pub type RecordsRange<'a> = Range<'a, RecordsId<'static>, RecordsValue<'static>>; +impl<'a> RecordsRange<'a> { pub(super) fn with_bounds( records: &'a impl ReadableTable, RecordsValue<'static>>, bounds: RecordsBounds, @@ -90,6 +85,7 @@ impl<'a> RecordsRange<'a> { Ok(Self(range)) } + // /// Get the next item in the range. /// /// Omit items for which the `matcher` function returns false. 
@@ -103,6 +99,22 @@ impl<'a> RecordsRange<'a> { } } +impl RecordsRange<'static> { + pub(super) fn all_static( + records: &ReadOnlyTable, RecordsValue<'static>>, + ) -> anyhow::Result { + let range = records.range::>(..)?; + Ok(Self(range)) + } + pub(super) fn with_bounds_static( + records: &ReadOnlyTable, RecordsValue<'static>>, + bounds: RecordsBounds, + ) -> anyhow::Result { + let range = records.range(bounds.as_ref())?; + Ok(Self(range)) + } +} + impl<'a> Iterator for RecordsRange<'a> { type Item = anyhow::Result; fn next(&mut self) -> Option { @@ -112,15 +124,15 @@ impl<'a> Iterator for RecordsRange<'a> { #[derive(derive_more::Debug)] #[debug("RecordsByKeyRange")] -pub struct RecordsByKeyRange<'a> { - records_table: &'a Table<'a, RecordsId<'static>, RecordsValue<'static>>, - by_key_range: Range<'a, RecordsByKeyId<'static>, ()>, +pub struct RecordsByKeyRange { + records_table: ReadOnlyTable, RecordsValue<'static>>, + by_key_range: Range<'static, RecordsByKeyId<'static>, ()>, } -impl<'a> RecordsByKeyRange<'a> { +impl RecordsByKeyRange { pub fn with_bounds( - records_by_key_table: &'a impl ReadableTable, ()>, - records_table: &'a Table<'a, RecordsId<'static>, RecordsValue<'static>>, + records_by_key_table: ReadOnlyTable, ()>, + records_table: ReadOnlyTable, RecordsValue<'static>>, bounds: ByKeyBounds, ) -> anyhow::Result { let by_key_range = records_by_key_table.range(bounds.as_ref())?; diff --git a/iroh-gossip/Cargo.toml b/iroh-gossip/Cargo.toml index cb6fa8567b..6d5c42f027 100644 --- a/iroh-gossip/Cargo.toml +++ b/iroh-gossip/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-gossip" -version = "0.17.0" +version = "0.18.0" edition = "2021" readme = "README.md" description = "gossip messages over broadcast trees" @@ -27,12 +27,12 @@ rand = { version = "0.8.5", features = ["std_rng"] } rand_core = "0.6.4" serde = { version = "1.0.164", features = ["derive"] } tracing = "0.1" -iroh-metrics = { version = "0.17.0", path = "../iroh-metrics" } -iroh-base = { version = "0.17.0", path = "../iroh-base" } +iroh-metrics = { version = "0.18.0", path = "../iroh-metrics" } +iroh-base = { version = "0.18.0", path = "../iroh-base" } # net dependencies (optional) futures-lite = { version = "2.3", optional = true } -iroh-net = { path = "../iroh-net", version = "0.17.0", optional = true, default-features = false, features = ["test-utils"] } +iroh-net = { path = "../iroh-net", version = "0.18.0", optional = true, default-features = false, features = ["test-utils"] } tokio = { version = "1", optional = true, features = ["io-util", "sync", "rt", "macros", "net", "fs"] } tokio-util = { version = "0.7.8", optional = true, features = ["codec"] } genawaiter = { version = "0.99.1", default-features = false, features = ["futures03"] } diff --git a/iroh-gossip/examples/chat.rs b/iroh-gossip/examples/chat.rs index 0bd9a0e1a3..f9bf38863f 100644 --- a/iroh-gossip/examples/chat.rs +++ b/iroh-gossip/examples/chat.rs @@ -108,13 +108,13 @@ async fn main() -> anyhow::Result<()> { .await?; println!("> our node id: {}", endpoint.node_id()); - let my_addr = endpoint.my_addr().await?; + let my_addr = endpoint.node_addr().await?; // create the gossip protocol let gossip = Gossip::from_endpoint(endpoint.clone(), Default::default(), &my_addr.info); // print a ticket that includes our own node id and endpoint addresses let ticket = { - let me = endpoint.my_addr().await?; + let me = endpoint.node_addr().await?; let peers = peers.iter().cloned().chain([me]).collect(); Ticket { topic, peers } }; @@ -206,11 +206,11 @@ async fn 
handle_connection( let alpn = conn.alpn().await?; let conn = conn.await?; let peer_id = iroh_net::endpoint::get_remote_node_id(&conn)?; - match alpn.as_bytes() { - GOSSIP_ALPN => gossip - .handle_connection(conn) - .await - .context(format!("connection to {peer_id} with ALPN {alpn} failed"))?, + match alpn.as_ref() { + GOSSIP_ALPN => gossip.handle_connection(conn).await.context(format!( + "connection to {peer_id} with ALPN {} failed", + String::from_utf8_lossy(&alpn) + ))?, _ => println!("> ignoring connection from {peer_id}: unsupported ALPN protocol"), } Ok(()) diff --git a/iroh-gossip/src/net.rs b/iroh-gossip/src/net.rs index 4083e3a113..13d5940703 100644 --- a/iroh-gossip/src/net.rs +++ b/iroh-gossip/src/net.rs @@ -26,10 +26,6 @@ pub mod util; /// ALPN protocol name pub const GOSSIP_ALPN: &[u8] = b"/iroh-gossip/0"; -/// Maximum message size is limited currently. The limit is more-or-less arbitrary. -// TODO: Make the limit configurable. -pub const MAX_MESSAGE_SIZE: usize = 4096; - /// Channel capacity for all subscription broadcast channels (single) const SUBSCRIBE_ALL_CAP: usize = 2048; /// Channel capacity for topic subscription broadcast channels (one per topic) @@ -42,6 +38,8 @@ const TO_ACTOR_CAP: usize = 64; const IN_EVENT_CAP: usize = 1024; /// Channel capacity for endpoint change message queue (single) const ON_ENDPOINTS_CAP: usize = 64; +/// Name used for logging when new node addresses are added from gossip. +const SOURCE_NAME: &str = "gossip"; /// Events emitted from the gossip protocol pub type Event = proto::Event; @@ -74,8 +72,9 @@ type ProtoMessage = proto::Message; #[derive(Debug, Clone)] pub struct Gossip { to_actor_tx: mpsc::Sender, - on_endpoints_tx: mpsc::Sender>, + on_direct_addrs_tx: mpsc::Sender>, _actor_handle: Arc>>, + max_message_size: usize, } impl Gossip { @@ -94,6 +93,7 @@ impl Gossip { let (on_endpoints_tx, on_endpoints_rx) = mpsc::channel(ON_ENDPOINTS_CAP); let me = endpoint.node_id().fmt_short(); + let max_message_size = state.max_message_size(); let actor = Actor { endpoint, state, @@ -101,7 +101,7 @@ impl Gossip { to_actor_rx, in_event_rx, in_event_tx, - on_endpoints_rx, + on_direct_addr_rx: on_endpoints_rx, conns: Default::default(), conn_send_tx: Default::default(), pending_sends: Default::default(), @@ -123,11 +123,17 @@ impl Gossip { ); Self { to_actor_tx, - on_endpoints_tx, + on_direct_addrs_tx: on_endpoints_tx, _actor_handle: Arc::new(actor_handle), + max_message_size, } } + /// Get the maximum message size configured for this gossip actor. + pub fn max_message_size(&self) -> usize { + self.max_message_size + } + /// Join a topic and connect to peers. /// /// @@ -237,16 +243,19 @@ impl Gossip { Ok(()) } - /// Set info on our local endpoints. + /// Set info on our direct addresses. /// /// This will be sent to peers on Neighbor and Join requests so that they can connect directly /// to us. /// /// This is only best effort, and will drop new events if backed up. 
- pub fn update_endpoints(&self, endpoints: &[iroh_net::config::Endpoint]) -> anyhow::Result<()> { - let endpoints = endpoints.to_vec(); - self.on_endpoints_tx - .try_send(endpoints) + pub fn update_direct_addresses( + &self, + addrs: &[iroh_net::endpoint::DirectAddr], + ) -> anyhow::Result<()> { + let addrs = addrs.to_vec(); + self.on_direct_addrs_tx + .try_send(addrs) .map_err(|_| anyhow!("endpoints channel dropped"))?; Ok(()) } @@ -338,7 +347,7 @@ struct Actor { /// Input events to the state (emitted from the connection loops) in_event_rx: mpsc::Receiver, /// Updates of discovered endpoint addresses - on_endpoints_rx: mpsc::Receiver>, + on_direct_addr_rx: mpsc::Receiver>, /// Queued timers timers: Timers, /// Currently opened quinn connections to peers @@ -371,10 +380,14 @@ impl Actor { } } }, - new_endpoints = self.on_endpoints_rx.recv() => { + new_endpoints = self.on_direct_addr_rx.recv() => { match new_endpoints { Some(endpoints) => { - let addr = self.endpoint.my_addr_with_endpoints(endpoints)?; + let addr = NodeAddr::from_parts( + self.endpoint.node_id(), + self.endpoint.home_relay(), + endpoints.into_iter().map(|x| x.addr).collect(), + ); let peer_data = encode_peer_data(&addr.info)?; self.handle_in_event(InEvent::UpdatePeerData(peer_data), Instant::now()).await?; } @@ -427,12 +440,23 @@ impl Actor { let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP); self.conn_send_tx.insert(peer_id, send_tx.clone()); + let max_message_size = self.state.max_message_size(); + // Spawn a task for this connection let in_event_tx = self.in_event_tx.clone(); tokio::spawn( async move { debug!("connection established"); - match connection_loop(peer_id, conn, origin, send_rx, &in_event_tx).await { + match connection_loop( + peer_id, + conn, + origin, + send_rx, + &in_event_tx, + max_message_size, + ) + .await + { Ok(()) => { debug!("connection closed without error") } @@ -556,7 +580,10 @@ impl Actor { Ok(info) => { debug!(peer = ?node_id, "add known addrs: {info:?}"); let node_addr = NodeAddr { node_id, info }; - if let Err(err) = self.endpoint.add_node_addr(node_addr) { + if let Err(err) = self + .endpoint + .add_node_addr_with_source(node_addr, SOURCE_NAME) + { debug!(peer = ?node_id, "add known failed: {err:?}"); } } @@ -605,6 +632,7 @@ async fn connection_loop( origin: ConnOrigin, mut send_rx: mpsc::Receiver, in_event_tx: &mpsc::Sender, + max_message_size: usize, ) -> anyhow::Result<()> { let (mut send, mut recv) = match origin { ConnOrigin::Accept => conn.accept_bi().await?, @@ -621,10 +649,10 @@ async fn connection_loop( // but the other side may still want to use it to // send data to us. Some(msg) = send_rx.recv(), if !send_rx.is_closed() => { - write_message(&mut send, &mut send_buf, &msg).await? + write_message(&mut send, &mut send_buf, &msg, max_message_size).await? } - msg = read_message(&mut recv, &mut recv_buf) => { + msg = read_message(&mut recv, &mut recv_buf, max_message_size) => { let msg = msg?; match msg { None => break, diff --git a/iroh-gossip/src/net/util.rs b/iroh-gossip/src/net/util.rs index 1101300292..2a45fa4961 100644 --- a/iroh-gossip/src/net/util.rs +++ b/iroh-gossip/src/net/util.rs @@ -11,16 +11,17 @@ use tokio::{ use crate::proto::util::TimerMap; -use super::{ProtoMessage, MAX_MESSAGE_SIZE}; +use super::ProtoMessage; /// Write a `ProtoMessage` as a length-prefixed, postcard-encoded message. 
pub async fn write_message( writer: &mut W, buffer: &mut BytesMut, frame: &ProtoMessage, + max_message_size: usize, ) -> Result<()> { let len = postcard::experimental::serialized_size(&frame)?; - ensure!(len < MAX_MESSAGE_SIZE); + ensure!(len < max_message_size); buffer.clear(); buffer.resize(len, 0u8); let slice = postcard::to_slice(&frame, buffer)?; @@ -33,8 +34,9 @@ pub async fn write_message( pub async fn read_message( reader: impl AsyncRead + Unpin, buffer: &mut BytesMut, + max_message_size: usize, ) -> Result> { - match read_lp(reader, buffer).await? { + match read_lp(reader, buffer, max_message_size).await? { None => Ok(None), Some(data) => { let message = postcard::from_bytes(&data)?; @@ -52,6 +54,7 @@ pub async fn read_message( pub async fn read_lp( mut reader: impl AsyncRead + Unpin, buffer: &mut BytesMut, + max_message_size: usize, ) -> Result> { let size = match reader.read_u32().await { Ok(size) => size, @@ -60,8 +63,8 @@ pub async fn read_lp( }; let mut reader = reader.take(size as u64); let size = usize::try_from(size).context("frame larger than usize")?; - if size > MAX_MESSAGE_SIZE { - bail!("Incoming message exceeds MAX_MESSAGE_SIZE"); + if size > max_message_size { + bail!("Incoming message exceeds the maximum message size of {max_message_size} bytes"); } buffer.reserve(size); loop { diff --git a/iroh-gossip/src/proto/state.rs b/iroh-gossip/src/proto/state.rs index f8b1ebd1e3..a841342014 100644 --- a/iroh-gossip/src/proto/state.rs +++ b/iroh-gossip/src/proto/state.rs @@ -196,6 +196,11 @@ impl State { .unwrap_or(false) } + /// Returns the maximum message size configured in the gossip protocol. + pub fn max_message_size(&self) -> usize { + self.config.max_message_size + } + /// Handle an [`InEvent`] /// /// This returns an iterator of [`OutEvent`]s that must be processed. diff --git a/iroh-gossip/src/proto/topic.rs b/iroh-gossip/src/proto/topic.rs index dc573fae45..64cffb783d 100644 --- a/iroh-gossip/src/proto/topic.rs +++ b/iroh-gossip/src/proto/topic.rs @@ -18,6 +18,10 @@ use super::{ }; use super::{PeerData, PeerIdentity}; +/// The default maximum size in bytes for a gossip message. +/// This is a sane but arbitrary default and can be changed in the [`Config`]. +pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 4096; + /// Input event to the topic state handler. #[derive(Clone, Debug)] pub enum InEvent { @@ -170,13 +174,32 @@ impl IO for VecDeque> { self.push_back(event.into()) } } + /// Protocol configuration -#[derive(Clone, Default, Debug)] +#[derive(Clone, Debug)] pub struct Config { /// Configuration for the swarm membership layer pub membership: hyparview::Config, /// Configuration for the gossip broadcast layer pub broadcast: plumtree::Config, + /// Max message size in bytes. + /// + /// This size should be the same across a network to ensure all nodes can transmit and read large messages. + /// + /// At minimum, this size should be large enough to send gossip control messages. This can vary, depending on the size of the [`PeerIdentity`] you use and the size of the [`PeerData`] you transmit in your messages. + /// + /// The default is [`DEFAULT_MAX_MESSAGE_SIZE`]. + pub max_message_size: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + membership: Default::default(), + broadcast: Default::default(), + max_message_size: DEFAULT_MAX_MESSAGE_SIZE, + } + } } /// The topic state maintains the swarm membership and broadcast tree for a particular topic. 
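With the hard-coded `MAX_MESSAGE_SIZE` constant gone, the limit now lives in the gossip `Config` and is threaded through `Gossip`, the actor, and the connection loop above. A hedged sketch of opting into a larger limit; the `Config` fields, `Gossip::from_endpoint`, and `max_message_size()` accessor come from this diff, while the 1 MiB figure and the exact `proto::Config` re-export path are assumptions:

```rust
use iroh_gossip::{net::Gossip, proto::Config};
use iroh_net::{AddrInfo, Endpoint};

fn spawn_gossip(endpoint: Endpoint, my_addr_info: &AddrInfo) -> Gossip {
    let config = Config {
        // Raise the limit from the 4096-byte default (illustrative value).
        max_message_size: 1024 * 1024,
        // Keep the default membership and broadcast tuning.
        ..Default::default()
    };
    let gossip = Gossip::from_endpoint(endpoint, config, my_addr_info);
    // The configured limit is carried by the `Gossip` handle and used by
    // the connection loops for framing; it can be read back at runtime.
    assert_eq!(gossip.max_message_size(), 1024 * 1024);
    gossip
}
```

Since `read_lp` rejects any frame above the receiver's limit, all nodes in a topic should be configured with the same value, as the new doc comment on `Config` warns.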
diff --git a/iroh-metrics/Cargo.toml b/iroh-metrics/Cargo.toml index 01ebfde25e..df85062175 100644 --- a/iroh-metrics/Cargo.toml +++ b/iroh-metrics/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-metrics" -version = "0.17.0" +version = "0.18.0" edition = "2021" readme = "README.md" description = "metrics for iroh" diff --git a/iroh-net/Cargo.toml b/iroh-net/Cargo.toml index 53b97a0377..6481dbba27 100644 --- a/iroh-net/Cargo.toml +++ b/iroh-net/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-net" -version = "0.17.0" +version = "0.18.0" edition = "2021" readme = "README.md" description = "networking support for iroh" @@ -41,7 +41,7 @@ http-body-util = "0.1.0" hyper = { version = "1", features = ["server", "client", "http1"] } hyper-util = "0.1.1" igd-next = { version = "0.14.3", features = ["aio_tokio"] } -iroh-base = { version = "0.17.0", path = "../iroh-base", features = ["key"] } +iroh-base = { version = "0.18.0", path = "../iroh-base", features = ["key"] } libc = "0.2.139" num_enum = "0.7" once_cell = "1.18.0" @@ -86,7 +86,7 @@ toml = { version = "0.8", optional = true } tracing-subscriber = { version = "0.3", features = ["env-filter"], optional = true } # metrics -iroh-metrics = { version = "0.17.0", path = "../iroh-metrics", default-features = false } +iroh-metrics = { version = "0.18.0", path = "../iroh-metrics", default-features = false } strum = { version = "0.26.2", features = ["derive"] } [target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies] diff --git a/iroh-net/bench/Cargo.toml b/iroh-net/bench/Cargo.toml index 8a16ce32c5..8d075337f8 100644 --- a/iroh-net/bench/Cargo.toml +++ b/iroh-net/bench/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-net-bench" -version = "0.17.0" +version = "0.18.0" edition = "2021" license = "MIT OR Apache-2.0" publish = false diff --git a/iroh-net/bench/src/iroh.rs b/iroh-net/bench/src/iroh.rs index a359be35b2..5a85952b7f 100644 --- a/iroh-net/bench/src/iroh.rs +++ b/iroh-net/bench/src/iroh.rs @@ -29,7 +29,7 @@ pub fn server_endpoint(rt: &tokio::runtime::Runtime, opt: &Opt) -> (NodeAddr, En .bind(0) .await .unwrap(); - let addr = ep.local_addr(); + let addr = ep.bound_sockets(); let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), addr.0.port()); let addr = NodeAddr::new(ep.node_id()).with_direct_addresses([addr]); (addr, ep) diff --git a/iroh-net/examples/connect-unreliable.rs b/iroh-net/examples/connect-unreliable.rs index b42ae683c3..3be041353c 100644 --- a/iroh-net/examples/connect-unreliable.rs +++ b/iroh-net/examples/connect-unreliable.rs @@ -60,7 +60,7 @@ async fn main() -> anyhow::Result<()> { println!("node id: {me}"); println!("node listening addresses:"); for local_endpoint in endpoint - .local_endpoints() + .direct_addresses() .next() .await .context("no endpoints")? @@ -69,7 +69,7 @@ async fn main() -> anyhow::Result<()> { } let relay_url = endpoint - .my_relay() + .home_relay() .expect("should be connected to a relay server, try calling `endpoint.local_endpoints()` or `endpoint.connect()` first, to ensure the endpoint has actually attempted a connection before checking for the connected relay server"); println!("node relay server url: {relay_url}\n"); // Build a `NodeAddr` from the node_id, relay url, and UDP addresses. 
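Most of the example churn above is renames on `Endpoint`: `local_endpoints()` is now `direct_addresses()`, `my_relay()` is `home_relay()`, `my_addr()` is `node_addr()`, and `local_addr()` is `bound_sockets()`. A short sketch of the new names together, following the pattern of the examples; binding to port `0` and the messages printed are illustrative:

```rust
use anyhow::Context;
use futures_lite::StreamExt;
use iroh_net::Endpoint;

async fn print_node_info() -> anyhow::Result<()> {
    // Bind on an OS-chosen port.
    let endpoint = Endpoint::builder().bind(0).await?;

    // Formerly `local_endpoints()`: our direct (UDP) addresses, delivered
    // as a stream that yields a fresh snapshot whenever the set changes.
    let direct_addrs = endpoint
        .direct_addresses()
        .next()
        .await
        .context("no direct addresses")?;
    for direct in direct_addrs {
        println!("direct addr: {}", direct.addr);
    }

    // Formerly `my_relay()`: the home relay, once we are connected to one.
    if let Some(relay_url) = endpoint.home_relay() {
        println!("home relay: {relay_url}");
    }

    // Formerly `my_addr()`: the full dialing information for this node.
    let node_addr = endpoint.node_addr().await?;
    println!("node addr: {node_addr:?}");
    Ok(())
}
```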
diff --git a/iroh-net/examples/connect.rs b/iroh-net/examples/connect.rs index 68740e9040..216a4e42eb 100644 --- a/iroh-net/examples/connect.rs +++ b/iroh-net/examples/connect.rs @@ -57,7 +57,7 @@ async fn main() -> anyhow::Result<()> { println!("node id: {me}"); println!("node listening addresses:"); for local_endpoint in endpoint - .local_endpoints() + .direct_addresses() .next() .await .context("no endpoints")? @@ -66,7 +66,7 @@ async fn main() -> anyhow::Result<()> { } let relay_url = endpoint - .my_relay() + .home_relay() .expect("should be connected to a relay server, try calling `endpoint.local_endpoints()` or `endpoint.connect()` first, to ensure the endpoint has actually attempted a connection before checking for the connected relay server"); println!("node relay server url: {relay_url}\n"); // Build a `NodeAddr` from the node_id, relay url, and UDP addresses. diff --git a/iroh-net/examples/listen-unreliable.rs b/iroh-net/examples/listen-unreliable.rs index 43fe12f81d..7dbc5e246d 100644 --- a/iroh-net/examples/listen-unreliable.rs +++ b/iroh-net/examples/listen-unreliable.rs @@ -38,7 +38,7 @@ async fn main() -> anyhow::Result<()> { println!("node listening addresses:"); let local_addrs = endpoint - .local_endpoints() + .direct_addresses() .next() .await .context("no endpoints")? @@ -52,7 +52,7 @@ async fn main() -> anyhow::Result<()> { .join(" "); let relay_url = endpoint - .my_relay() + .home_relay() .expect("should be connected to a relay server, try calling `endpoint.local_endpoints()` or `endpoint.connect()` first, to ensure the endpoint has actually attempted a connection before checking for the connected relay server"); println!("node relay server url: {relay_url}"); println!("\nin a separate terminal run:"); @@ -67,7 +67,8 @@ async fn main() -> anyhow::Result<()> { let conn = conn.await?; let node_id = iroh_net::endpoint::get_remote_node_id(&conn)?; info!( - "new (unreliable) connection from {node_id} with ALPN {alpn} (coming from {})", + "new (unreliable) connection from {node_id} with ALPN {} (coming from {})", + String::from_utf8_lossy(&alpn), conn.remote_address() ); // spawn a task to handle reading and writing off of the connection diff --git a/iroh-net/examples/listen.rs b/iroh-net/examples/listen.rs index 4d59472584..6f538534a4 100644 --- a/iroh-net/examples/listen.rs +++ b/iroh-net/examples/listen.rs @@ -38,7 +38,7 @@ async fn main() -> anyhow::Result<()> { println!("node listening addresses:"); let local_addrs = endpoint - .local_endpoints() + .direct_addresses() .next() .await .context("no endpoints")? 
@@ -52,7 +52,7 @@ async fn main() -> anyhow::Result<()> { .join(" "); let relay_url = endpoint - .my_relay() + .home_relay() .expect("should be connected to a relay server, try calling `endpoint.local_endpoints()` or `endpoint.connect()` first, to ensure the endpoint has actually attempted a connection before checking for the connected relay server"); println!("node relay server url: {relay_url}"); println!("\nin a separate terminal run:"); @@ -66,7 +66,8 @@ async fn main() -> anyhow::Result<()> { let conn = conn.await?; let node_id = iroh_net::endpoint::get_remote_node_id(&conn)?; info!( - "new connection from {node_id} with ALPN {alpn} (coming from {})", + "new connection from {node_id} with ALPN {} (coming from {})", + String::from_utf8_lossy(&alpn), conn.remote_address() ); diff --git a/iroh-net/src/bin/iroh-relay.rs b/iroh-net/src/bin/iroh-relay.rs index f9717a46a1..45e076e66c 100644 --- a/iroh-net/src/bin/iroh-relay.rs +++ b/iroh-net/src/bin/iroh-relay.rs @@ -1,49 +1,28 @@ -//! A simple relay server. +//! A simple relay server for iroh-net. //! -//! Based on /tailscale/cmd/derper +//! This handles only the CLI and config file loading, the server implementation lives in +//! [`iroh_net::relay::iroh_relay`]. -use std::{ - borrow::Cow, - future::Future, - net::{IpAddr, Ipv6Addr, SocketAddr}, - path::{Path, PathBuf}, - pin::Pin, - sync::Arc, -}; +use std::net::{Ipv6Addr, SocketAddr}; +use std::path::{Path, PathBuf}; use anyhow::{anyhow, bail, Context as _, Result}; use clap::Parser; -use futures_lite::StreamExt; -use http::{response::Builder as ResponseBuilder, HeaderMap}; -use hyper::body::Incoming; -use hyper::{Method, Request, Response, StatusCode}; -use iroh_metrics::inc; -use iroh_net::defaults::{DEFAULT_RELAY_STUN_PORT, NA_RELAY_HOSTNAME}; -use iroh_net::key::SecretKey; -use iroh_net::relay::http::{ - ServerBuilder as RelayServerBuilder, TlsAcceptor, TlsConfig as RelayTlsConfig, +use iroh_net::defaults::{ + DEFAULT_HTTPS_PORT, DEFAULT_HTTP_PORT, DEFAULT_METRICS_PORT, DEFAULT_STUN_PORT, }; -use iroh_net::relay::{self}; -use iroh_net::stun; +use iroh_net::key::SecretKey; +use iroh_net::relay::iroh_relay; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; -use tokio::net::{TcpListener, UdpSocket}; use tokio_rustls_acme::{caches::DirCache, AcmeConfig}; -use tracing::{debug, debug_span, error, info, info_span, trace, warn, Instrument}; +use tracing::{debug, info}; use tracing_subscriber::{prelude::*, EnvFilter}; -use metrics::StunMetrics; - -type BytesBody = http_body_util::Full; -type HyperError = Box; -type HyperResult = std::result::Result; +/// The default `http_bind_port` when using `--dev`. +const DEV_MODE_HTTP_PORT: u16 = 3340; -/// Creates a new [`BytesBody`] with no content. -fn body_empty() -> BytesBody { - http_body_util::Full::new(hyper::body::Bytes::new()) -} - -/// A simple relay server. +/// A relay server for iroh-net. #[derive(Parser, Debug, Clone)] #[clap(version, about, long_about = None)] struct Cli { @@ -54,7 +33,10 @@ struct Cli { /// Running in dev mode will ignore any config file fields pertaining to TLS. #[clap(long, default_value_t = false)] dev: bool, - /// Config file path. Generate a default configuration file by supplying a path. + /// Path to the configuration file. + /// + /// If provided and no configuration file exists the default configuration will be + /// written to the file. 
#[clap(long, short)] config_path: Option, } @@ -65,73 +47,6 @@ enum CertMode { LetsEncrypt, } -impl CertMode { - async fn gen_server_config( - &self, - hostname: String, - contact: String, - is_production: bool, - dir: PathBuf, - ) -> Result<(Arc, TlsAcceptor)> { - let config = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth(); - - match self { - CertMode::LetsEncrypt => { - let mut state = AcmeConfig::new(vec![hostname]) - .contact([format!("mailto:{contact}")]) - .cache_option(Some(DirCache::new(dir))) - .directory_lets_encrypt(is_production) - .state(); - - let config = config.with_cert_resolver(state.resolver()); - let acceptor = state.acceptor(); - - tokio::spawn( - async move { - while let Some(event) = state.next().await { - match event { - Ok(ok) => debug!("acme event: {:?}", ok), - Err(err) => error!("error: {:?}", err), - } - } - debug!("event stream finished"); - } - .instrument(info_span!("acme")), - ); - - Ok((Arc::new(config), TlsAcceptor::LetsEncrypt(acceptor))) - } - CertMode::Manual => { - // load certificates manually - let keyname = escape_hostname(&hostname); - let cert_path = dir.join(format!("{keyname}.crt")); - let key_path = dir.join(format!("{keyname}.key")); - - let (certs, secret_key) = tokio::task::spawn_blocking(move || { - let certs = load_certs(cert_path)?; - let key = load_secret_key(key_path)?; - anyhow::Ok((certs, key)) - }) - .await??; - - let config = config.with_single_cert(certs, secret_key)?; - let config = Arc::new(config); - let acceptor = tokio_rustls::TlsAcceptor::from(config.clone()); - - Ok((config, TlsAcceptor::Manual(acceptor))) - } - } - } -} - -fn escape_hostname(hostname: &str) -> Cow<'_, str> { - let unsafe_hostname_characters = - regex::Regex::new(r"[^a-zA-Z0-9-\.]").expect("regex manually checked"); - unsafe_hostname_characters.replace_all(hostname, "") -} - fn load_certs(filename: impl AsRef) -> Result> { let certfile = std::fs::File::open(filename).context("cannot open certificate file")?; let mut reader = std::io::BufReader::new(certfile); @@ -164,72 +79,203 @@ fn load_secret_key(filename: impl AsRef) -> Result { ); } +/// Configuration for the relay-server. +/// +/// This is (de)serialised to/from a TOML config file. #[serde_as] -#[derive(Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct Config { - /// [`SecretKey`] for this relay server. + /// The iroh [`SecretKey`] for this relay server. + /// + /// If not specified a new key will be generated and the config file will be re-written + /// using it. #[serde_as(as = "DisplayFromStr")] #[serde(default = "SecretKey::generate")] secret_key: SecretKey, - /// Server listen address. + /// Whether to enable the Relay server. /// - /// Defaults to `[::]:443`. + /// Defaults to `true`. /// - /// If the port address is 443, the relay server will issue a warning if it is started - /// without a `tls` config. - addr: SocketAddr, - - /// The UDP port on which to serve STUN. The listener is bound to the same IP (if any) as - /// specified in the `addr` field. Defaults to [`DEFAULT_RELAY_STUN_PORT`]. - stun_port: u16, - /// Certificate hostname. Defaults to [`NA_RELAY_HOSTNAME`]. - hostname: String, + /// Disabling will leave only the STUN server. The `http_bind_addr` and `tls` + /// configuration options will be ignored. + #[serde(default = "cfg_defaults::enable_relay")] + enable_relay: bool, + /// The socket address to bind the Relay HTTP server on. + /// + /// Defaults to `[::]:80`. 
+    ///
+    /// When running with `--dev` defaults to `[::]:3340`. If specified, this overrides these
+    /// defaults.
+    ///
+    /// The Relay server always starts an HTTP server; this specifies the socket it will
+    /// be bound on. If there is no `tls` configuration set all the HTTP relay services
+    /// will be bound on this socket. Otherwise most Relay HTTP services will run on the
+    /// `https_bind_addr` of the `tls` configuration section and only the captive portal
+    /// will be served from the HTTP socket.
+    http_bind_addr: Option<SocketAddr>,
+    /// TLS specific configuration.
+    ///
+    /// TLS is disabled if not present and the Relay server will serve all services over
+    /// plain HTTP.
+    ///
+    /// If disabled, all services will run on plain HTTP. The `--dev` option disables this,
+    /// regardless of what is in the configuration file.
+    tls: Option<TlsConfig>,
     /// Whether to run a STUN server. It will bind to the same IP as the `addr` field.
     ///
     /// Defaults to `true`.
+    #[serde(default = "cfg_defaults::enable_stun")]
     enable_stun: bool,
-    /// Whether to run a relay server. The only reason to set this false is if you're decommissioning a
-    /// server but want to keep its bootstrap DNS functionality still running.
+    /// The socket address to bind the STUN server on.
     ///
-    /// Defaults to `true`
-    enable_relay: bool,
-    /// TLS specific configuration
-    tls: Option<TlsConfig>,
-    /// Rate limiting configuration
+    /// Defaults to using the `http_bind_addr` with the port set to [`DEFAULT_STUN_PORT`].
+    stun_bind_addr: Option<SocketAddr>,
+    /// Rate limiting configuration.
+    ///
+    /// Disabled if not present.
     limits: Option<Limits>,
-    #[cfg(feature = "metrics")]
-    /// Metrics serve address. If not set, metrics are not served.
-    metrics_addr: Option<SocketAddr>,
+    /// Whether to run the metrics server.
+    ///
+    /// Defaults to `true` when the metrics feature is enabled.
+    #[serde(default = "cfg_defaults::enable_metrics")]
+    enable_metrics: bool,
+    /// Metrics serve address.
+    ///
+    /// Defaults to `http_bind_addr` with the port set to [`DEFAULT_METRICS_PORT`]
+    /// (`[::]:9090` when `http_bind_addr` is set to the default).
+    metrics_bind_addr: Option<SocketAddr>,
+}
+
+impl Config {
+    fn http_bind_addr(&self) -> SocketAddr {
+        self.http_bind_addr
+            .unwrap_or((Ipv6Addr::UNSPECIFIED, DEFAULT_HTTP_PORT).into())
+    }
+
+    fn stun_bind_addr(&self) -> SocketAddr {
+        self.stun_bind_addr
+            .unwrap_or_else(|| SocketAddr::new(self.http_bind_addr().ip(), DEFAULT_STUN_PORT))
+    }
+
+    fn metrics_bind_addr(&self) -> SocketAddr {
+        self.metrics_bind_addr
+            .unwrap_or_else(|| SocketAddr::new(self.http_bind_addr().ip(), DEFAULT_METRICS_PORT))
+    }
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            secret_key: SecretKey::generate(),
+            enable_relay: true,
+            http_bind_addr: None,
+            tls: None,
+            enable_stun: true,
+            stun_bind_addr: None,
+            limits: None,
+            enable_metrics: true,
+            metrics_bind_addr: None,
+        }
+    }
 }
 
-#[derive(Serialize, Deserialize)]
+/// Defaults for fields from [`Config`].
+///
+/// These are the defaults that serde will fill in. Other defaults depend on each other
+/// and can not immediately be substituted by serde.
+mod cfg_defaults {
+    pub(crate) fn enable_relay() -> bool {
+        true
+    }
+
+    pub(crate) fn enable_stun() -> bool {
+        true
+    }
+
+    pub(crate) fn enable_metrics() -> bool {
+        true
+    }
+
+    pub(crate) mod tls_config {
+        pub(crate) fn prod_tls() -> bool {
+            true
+        }
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
 struct TlsConfig {
-    /// Mode for getting a cert.
possible options: 'Manual', 'LetsEncrypt' - /// When using manual mode, a certificate will be read from `.crt` and a secret key from - /// `.key`, with the `` being the escaped hostname. + /// The socket address to bind the Relay HTTPS server on. + /// + /// Defaults to the `http_bind_addr` with the port set to `443`. + https_bind_addr: Option, + /// Certificate hostname when using LetsEncrypt. + hostname: Option, + /// Mode for getting a cert. + /// + /// Possible options: 'Manual', 'LetsEncrypt'. cert_mode: CertMode, + /// Directory to store LetsEncrypt certs or read manual certificates from. + /// + /// Defaults to the servers' current working directory. + cert_dir: Option, + /// Path of where to read the certificate from for the `Manual` `cert_mode`. + /// + /// Defaults to `/default.crt`. + /// + /// Only used when `cert_mode` is `Manual`. + manual_cert_path: Option, + /// Path of where to read the private key from for the `Manual` `cert_mode`. + /// + /// Defaults to `/default.key`. + /// + /// Only used when `cert_mode` is `Manual`. + manual_key_path: Option, /// Whether to use the LetsEncrypt production or staging server. /// - /// While in development, LetsEncrypt prefers you to use the staging server. However, the staging server seems to - /// only use `ECDSA` keys. In their current set up, you can only get intermediate certificates - /// for `ECDSA` keys if you are on their "allowlist". The production server uses `RSA` keys, - /// which allow for issuing intermediate certificates in all normal circumstances. - /// So, to have valid certificates, we must use the LetsEncrypt production server. - /// Read more here: - /// Default is true. This field is ignored if we are not using `cert_mode: CertMode::LetsEncrypt`. + /// Default is `true`. + /// + /// Only used when `cert_mode` is `LetsEncrypt`. + /// + /// While in development, LetsEncrypt prefers you to use the staging server. However, + /// the staging server seems to only use `ECDSA` keys. In their current set up, you can + /// only get intermediate certificates for `ECDSA` keys if you are on their + /// "allowlist". The production server uses `RSA` keys, which allow for issuing + /// intermediate certificates in all normal circumstances. So, to have valid + /// certificates, we must use the LetsEncrypt production server. Read more here: + /// . + #[serde(default = "cfg_defaults::tls_config::prod_tls")] prod_tls: bool, /// The contact email for the tls certificate. - contact: String, - /// Directory to store LetsEncrypt certs or read certificates from, if TLS is used. - cert_dir: Option, - /// The port on which to serve a response for the captive portal probe over HTTP. /// - /// The listener is bound to the same IP as specified in the `addr` field. Defaults to 80. - /// This field is only read in we are serving the relay server over HTTPS. In that case, we must listen for requests for the `/generate_204` over a non-TLS connection. - captive_portal_port: Option, + /// Used when `cert_mode` is `LetsEncrypt`. 
+ contact: Option, } -#[derive(Serialize, Deserialize)] +impl TlsConfig { + fn https_bind_addr(&self, cfg: &Config) -> SocketAddr { + self.https_bind_addr + .unwrap_or_else(|| SocketAddr::new(cfg.http_bind_addr().ip(), DEFAULT_HTTPS_PORT)) + } + + fn cert_dir(&self) -> PathBuf { + self.cert_dir.clone().unwrap_or_else(|| PathBuf::from(".")) + } + + fn cert_path(&self) -> PathBuf { + self.manual_cert_path + .clone() + .unwrap_or_else(|| self.cert_dir().join("default.crt")) + } + + fn key_path(&self) -> PathBuf { + self.manual_key_path + .clone() + .unwrap_or_else(|| self.cert_dir().join("default.key")) + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] struct Limits { /// Rate limit for accepting new connection. Unlimited if not set. accept_conn_limit: Option, @@ -237,23 +283,6 @@ struct Limits { accept_conn_burst: Option, } -impl Default for Config { - fn default() -> Self { - Self { - secret_key: SecretKey::generate(), - addr: (Ipv6Addr::UNSPECIFIED, 443).into(), - stun_port: DEFAULT_RELAY_STUN_PORT, - hostname: NA_RELAY_HOSTNAME.into(), - enable_stun: true, - enable_relay: true, - tls: None, - limits: None, - #[cfg(feature = "metrics")] - metrics_addr: None, - } - } -} - impl Config { async fn load(opts: &Cli) -> Result { let config_path = if let Some(config_path) = &opts.config_path { @@ -274,12 +303,12 @@ impl Config { async fn read_from_file(path: impl AsRef) -> Result { if !path.as_ref().is_file() { - bail!("config-path must be a valid toml file"); + bail!("config-path must be a file"); } let config_ser = tokio::fs::read_to_string(&path) .await .context("unable to read config")?; - let config: Self = toml::from_str(&config_ser).context("unable to decode config")?; + let config: Self = toml::from_str(&config_ser).context("config file must be valid toml")?; if !config_ser.contains("secret_key") { info!("generating new secret key and updating config file"); config.write_to_file(path).await?; @@ -307,36 +336,6 @@ impl Config { } } -#[cfg(feature = "metrics")] -pub fn init_metrics_collection( - metrics_addr: Option, -) -> Option> { - use iroh_metrics::core::Metric; - - let rt = tokio::runtime::Handle::current(); - - // doesn't start the server if the address is None - if let Some(metrics_addr) = metrics_addr { - iroh_metrics::core::Core::init(|reg, metrics| { - metrics.insert(iroh_net::metrics::RelayMetrics::new(reg)); - metrics.insert(StunMetrics::new(reg)); - }); - - return Some(rt.spawn(async move { - if let Err(e) = iroh_metrics::metrics::start_metrics_server(metrics_addr).await { - eprintln!("Failed to start metrics server: {e}"); - } - })); - } - tracing::info!("Metrics server not started, no address provided"); - None -} - -/// Only used when in `dev` mode & the given port is `443` -const DEV_PORT: u16 = 3340; -/// Only used when tls is enabled & a captive protal port is not given -const DEFAULT_CAPTIVE_PORTAL_PORT: u16 = 80; - #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::registry() @@ -345,470 +344,100 @@ async fn main() -> Result<()> { .init(); let cli = Cli::parse(); - let cfg = Config::load(&cli).await?; - - #[cfg(feature = "metrics")] - let metrics_fut = init_metrics_collection(cfg.metrics_addr); - - let r = run(cli.dev, cfg, None).await; - - #[cfg(feature = "metrics")] - if let Some(metrics_fut) = metrics_fut { - metrics_fut.abort(); - drop(metrics_fut); - } - r -} - -async fn run( - dev_mode: bool, - cfg: Config, - addr_sender: Option>, -) -> Result<()> { - let (addr, tls_config) = if dev_mode { - let port = if cfg.addr.port() != 443 { - 
cfg.addr.port() - } else { - DEV_PORT - }; - - let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port); - info!(%addr, "Running in dev mode."); - (addr, None) - } else { - (cfg.addr, cfg.tls) - }; - - if let Some(tls_config) = &tls_config { - if let Some(captive_portal_port) = tls_config.captive_portal_port { - if addr.port() == captive_portal_port { - bail!("The main listening address {addr:?} and the `captive_portal_port` have the same port number."); - } + let mut cfg = Config::load(&cli).await?; + if cli.dev { + cfg.tls = None; + if cfg.http_bind_addr.is_none() { + cfg.http_bind_addr = Some((Ipv6Addr::UNSPECIFIED, DEV_MODE_HTTP_PORT).into()); } - } else if addr.port() == 443 { - // no tls config, but the port is 443 - warn!("The address port is 443, which is typically the expected tls port, but you have not supplied any tls configuration.\nIf you meant to run the relay server with tls enabled, adjust the config file to include tls configuration."); } + let relay_config = build_relay_config(cfg).await?; + debug!("{relay_config:#?}"); - // set up relay configuration details - let secret_key = if cfg.enable_relay { - Some(cfg.secret_key) - } else { - None - }; + let mut relay = iroh_relay::Server::spawn(relay_config).await?; - // run stun - let stun_task = if cfg.enable_stun { - Some(tokio::task::spawn(async move { - serve_stun(addr.ip(), cfg.stun_port).await - })) - } else { - None - }; - - // set up tls configuration details - let (tls_config, headers, captive_portal_port) = if let Some(tls_config) = tls_config { - let contact = tls_config.contact; - let is_production = tls_config.prod_tls; - let (config, acceptor) = tls_config - .cert_mode - .gen_server_config( - cfg.hostname.clone(), - contact, - is_production, - tls_config.cert_dir.unwrap_or_else(|| PathBuf::from(".")), - ) - .await?; - let mut headers = HeaderMap::new(); - for (name, value) in TLS_HEADERS.iter() { - headers.insert(*name, value.parse()?); - } - ( - Some(RelayTlsConfig { config, acceptor }), - headers, - tls_config - .captive_portal_port - .unwrap_or(DEFAULT_CAPTIVE_PORTAL_PORT), - ) - } else { - (None, HeaderMap::new(), 0) - }; - - let mut builder = RelayServerBuilder::new(addr) - .secret_key(secret_key.map(Into::into)) - .headers(headers) - .tls_config(tls_config.clone()) - .relay_override(Box::new(relay_disabled_handler)) - .request_handler(Method::GET, "/", Box::new(root_handler)) - .request_handler(Method::GET, "/index.html", Box::new(root_handler)) - .request_handler(Method::GET, "/derp/probe", Box::new(probe_handler)) - .request_handler(Method::GET, "/robots.txt", Box::new(robots_handler)); - // if tls is enabled, we need to serve this endpoint from a non-tls connection - // which we check for below - if tls_config.is_none() { - builder = builder.request_handler( - Method::GET, - "/generate_204", - Box::new(serve_no_content_handler), - ); + tokio::select! 
{ + biased; + _ = tokio::signal::ctrl_c() => (), + _ = relay.task_handle() => (), } - let relay_server = builder.spawn().await?; - // captive portal detections must be served over HTTP - let captive_portal_task = if tls_config.is_some() { - let http_addr = SocketAddr::new(addr.ip(), captive_portal_port); - let task = serve_captive_portal_service(http_addr).await?; - Some(task) - } else { - None - }; - - if let Some(addr_sender) = addr_sender { - if let Err(e) = addr_sender.send(relay_server.addr()) { - bail!("Unable to send the local SocketAddr, the Sender was dropped - {e:?}"); - } - } - - tokio::signal::ctrl_c().await?; - // Shutdown all tasks - if let Some(task) = stun_task { - task.abort(); - } - if let Some(task) = captive_portal_task { - task.abort() - } - relay_server.shutdown().await; - - Ok(()) + relay.shutdown().await } -const NO_CONTENT_CHALLENGE_HEADER: &str = "X-Tailscale-Challenge"; -const NO_CONTENT_RESPONSE_HEADER: &str = "X-Tailscale-Response"; - -const NOTFOUND: &[u8] = b"Not Found"; -const RELAY_DISABLED: &[u8] = b"relay server disabled"; -const ROBOTS_TXT: &[u8] = b"User-agent: *\nDisallow: /\n"; -const INDEX: &[u8] = br#" -
-<html><body>
-<h1>RELAY</h1>
-<p>
-  This is an
-  <a href="https://iroh.computer/">Iroh Relay</a>
-  server.
-</p>
-</body></html>
-"#; - -const TLS_HEADERS: [(&str, &str); 2] = [ - ("Strict-Transport-Security", "max-age=63072000; includeSubDomains"), - ("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'; form-action 'none'; base-uri 'self'; block-all-mixed-content; plugin-types 'none'") -]; - -async fn serve_captive_portal_service(addr: SocketAddr) -> Result> { - let http_listener = TcpListener::bind(&addr) - .await - .context("failed to bind http")?; - let http_addr = http_listener.local_addr()?; - info!("[CaptivePortalService]: serving on {}", http_addr); - - let task = tokio::spawn( - async move { - loop { - match http_listener.accept().await { - Ok((stream, peer_addr)) => { - debug!( - "[CaptivePortalService] Connection opened from {}", - peer_addr - ); - let handler = CaptivePortalService; - - tokio::task::spawn(async move { - let stream = relay::MaybeTlsStreamServer::Plain(stream); - let stream = hyper_util::rt::TokioIo::new(stream); - if let Err(err) = hyper::server::conn::http1::Builder::new() - .serve_connection(stream, handler) - .with_upgrades() - .await - { - error!( - "[CaptivePortalService] Failed to serve connection: {:?}", - err - ); - } - }); - } - Err(err) => { - error!( - "[CaptivePortalService] failed to accept connection: {:#?}", - err - ); - } +/// Convert the TOML-loaded config to the [`iroh_relay::RelayConfig`] format. +async fn build_relay_config(cfg: Config) -> Result> { + let tls = match cfg.tls { + Some(ref tls) => { + let cert_config = match tls.cert_mode { + CertMode::Manual => { + let cert_path = tls.cert_path(); + let key_path = tls.key_path(); + // Could probably just do this blocking, we're only starting up. + let (private_key, certs) = tokio::task::spawn_blocking(move || { + let key = load_secret_key(key_path)?; + let certs = load_certs(cert_path)?; + anyhow::Ok((key, certs)) + }) + .await??; + iroh_relay::CertConfig::Manual { private_key, certs } } - } - } - .instrument(info_span!("captive-portal.service")), - ); - Ok(task) -} - -#[derive(Clone)] -struct CaptivePortalService; - -impl hyper::service::Service> for CaptivePortalService { - type Response = Response; - type Error = HyperError; - type Future = Pin> + Send>>; - - fn call(&self, req: Request) -> Self::Future { - match (req.method(), req.uri().path()) { - // Captive Portal checker - (&Method::GET, "/generate_204") => { - Box::pin(async move { serve_no_content_handler(req, Response::builder()) }) - } - _ => { - // Return 404 not found response. 
- let r = Response::builder() - .status(StatusCode::NOT_FOUND) - .body(NOTFOUND.into()) - .map_err(|err| Box::new(err) as HyperError); - Box::pin(async move { r }) - } - } - } -} - -fn relay_disabled_handler( - _r: Request, - response: ResponseBuilder, -) -> HyperResult> { - response - .status(StatusCode::NOT_FOUND) - .body(RELAY_DISABLED.into()) - .map_err(|err| Box::new(err) as HyperError) -} - -fn root_handler( - _r: Request, - response: ResponseBuilder, -) -> HyperResult> { - response - .status(StatusCode::OK) - .header("Content-Type", "text/html; charset=utf-8") - .body(INDEX.into()) - .map_err(|err| Box::new(err) as HyperError) -} - -/// HTTP latency queries -fn probe_handler( - _r: Request, - response: ResponseBuilder, -) -> HyperResult> { - response - .status(StatusCode::OK) - .header("Access-Control-Allow-Origin", "*") - .body(body_empty()) - .map_err(|err| Box::new(err) as HyperError) -} - -fn robots_handler( - _r: Request, - response: ResponseBuilder, -) -> HyperResult> { - response - .status(StatusCode::OK) - .body(ROBOTS_TXT.into()) - .map_err(|err| Box::new(err) as HyperError) -} - -/// For captive portal detection. -fn serve_no_content_handler( - r: Request, - mut response: ResponseBuilder, -) -> HyperResult> { - if let Some(challenge) = r.headers().get(NO_CONTENT_CHALLENGE_HEADER) { - if !challenge.is_empty() - && challenge.len() < 64 - && challenge - .as_bytes() - .iter() - .all(|c| is_challenge_char(*c as char)) - { - response = response.header( - NO_CONTENT_RESPONSE_HEADER, - format!("response {}", challenge.to_str()?), - ); - } - } - - response - .status(StatusCode::NO_CONTENT) - .body(body_empty()) - .map_err(|err| Box::new(err) as HyperError) -} - -fn is_challenge_char(c: char) -> bool { - // Semi-randomly chosen as a limited set of valid characters - c.is_ascii_lowercase() - || c.is_ascii_uppercase() - || c.is_ascii_digit() - || c == '.' 
- || c == '-' - || c == '_' -} - -async fn serve_stun(host: IpAddr, port: u16) { - match UdpSocket::bind((host, port)).await { - Ok(sock) => { - let addr = sock.local_addr().expect("socket just bound"); - info!(%addr, "running STUN server"); - server_stun_listener(sock) - .instrument(debug_span!("stun_server", %addr)) - .await; - } - Err(err) => { - error!( - "failed to open STUN listener at host {host} and port {port}: {:#?}", - err - ); - } - } -} - -async fn server_stun_listener(sock: UdpSocket) { - let sock = Arc::new(sock); - let mut buffer = vec![0u8; 64 << 10]; - loop { - match sock.recv_from(&mut buffer).await { - Ok((n, src_addr)) => { - inc!(StunMetrics, requests); - let pkt = buffer[..n].to_vec(); - let sock = sock.clone(); - tokio::task::spawn(async move { - if !stun::is(&pkt) { - debug!(%src_addr, "STUN: ignoring non stun packet"); - inc!(StunMetrics, bad_requests); - return; - } - match tokio::task::spawn_blocking(move || stun::parse_binding_request(&pkt)) - .await - { - Ok(Ok(txid)) => { - debug!(%src_addr, %txid, "STUN: received binding request"); - let res = match tokio::task::spawn_blocking(move || { - stun::response(txid, src_addr) - }) - .await - { - Ok(res) => res, - Err(err) => { - error!("JoinError: {err:#}"); - return; - } - }; - match sock.send_to(&res, src_addr).await { - Ok(len) => { - if len != res.len() { - warn!(%src_addr, %txid, "STUN: failed to write response sent: {}, but expected {}", len, res.len()); - } - match src_addr { - SocketAddr::V4(_) => { - inc!(StunMetrics, ipv4_success); - } - SocketAddr::V6(_) => { - inc!(StunMetrics, ipv6_success); - } - } - trace!(%src_addr, %txid, "STUN: sent {} bytes", len); - } - Err(err) => { - inc!(StunMetrics, failures); - warn!(%src_addr, %txid, "STUN: failed to write response: {:?}", err); - } - } - } - Ok(Err(err)) => { - inc!(StunMetrics, bad_requests); - warn!(%src_addr, "STUN: invalid binding request: {:?}", err); - } - Err(err) => error!("JoinError parsing STUN binding: {err:#}"), - } - }); - } - Err(err) => { - inc!(StunMetrics, failures); - warn!("STUN: failed to recv: {:?}", err); - } + CertMode::LetsEncrypt => { + let hostname = tls + .hostname + .clone() + .context("LetsEncrypt needs a hostname")?; + let contact = tls + .contact + .clone() + .context("LetsEncrypt needs a contact email")?; + let config = AcmeConfig::new(vec![hostname.clone()]) + .contact([format!("mailto:{}", contact)]) + .cache_option(Some(DirCache::new(tls.cert_dir()))) + .directory_lets_encrypt(tls.prod_tls); + iroh_relay::CertConfig::LetsEncrypt { config } + } + }; + Some(iroh_relay::TlsConfig { + https_bind_addr: tls.https_bind_addr(&cfg), + cert: cert_config, + }) } - } + None => None, + }; + let limits = iroh_relay::Limits { + accept_conn_limit: cfg + .limits + .as_ref() + .map(|l| l.accept_conn_limit) + .unwrap_or_default(), + accept_conn_burst: cfg + .limits + .as_ref() + .map(|l| l.accept_conn_burst) + .unwrap_or_default(), + }; + let relay_config = iroh_relay::RelayConfig { + secret_key: cfg.secret_key.clone(), + http_bind_addr: cfg.http_bind_addr(), + tls, + limits, + }; + let stun_config = iroh_relay::StunConfig { + bind_addr: cfg.stun_bind_addr(), + }; + Ok(iroh_relay::ServerConfig { + relay: Some(relay_config), + stun: Some(stun_config), + #[cfg(feature = "metrics")] + metrics_addr: if cfg.enable_metrics { + Some(cfg.metrics_bind_addr()) + } else { + None + }, + }) } -// var validProdHostname = regexp.MustCompile(`^relay([^.]*)\.tailscale\.com\.?$`) - -// func prodAutocertHostPolicy(_ context.Context, host string) error { -// 
if validProdHostname.MatchString(host) { -// return nil -// } -// return errors.New("invalid hostname") -// } - -// func rateLimitedListenAndServeTLS(srv *http.Server) error { -// addr := srv.Addr -// if addr == "" { -// addr = ":https" -// } -// ln, err := net.Listen("tcp", addr) -// if err != nil { -// return err -// } -// rln := newRateLimitedListener(ln, rate.Limit(*acceptConnLimit), *acceptConnBurst) -// expvar.Publish("tls_listener", rln.ExpVar()) -// defer rln.Close() -// return srv.ServeTLS(rln, "", "") -// } - -// type rateLimitedListener struct { -// // These are at the start of the struct to ensure 64-bit alignment -// // on 32-bit architecture regardless of what other fields may exist -// // in this package. -// numAccepts expvar.Int // does not include number of rejects -// numRejects expvar.Int - -// net.Listener - -// lim *rate.Limiter -// } - -// func newRateLimitedListener(ln net.Listener, limit rate.Limit, burst int) *rateLimitedListener { -// return &rateLimitedListener{Listener: ln, lim: rate.NewLimiter(limit, burst)} -// } - -// func (l *rateLimitedListener) ExpVar() expvar.Var { -// m := new(metrics.Set) -// m.Set("counter_accepted_connections", &l.numAccepts) -// m.Set("counter_rejected_connections", &l.numRejects) -// return m -// } - -// var errLimitedConn = errors.New("cannot accept connection; rate limited") - -// func (l *rateLimitedListener) Accept() (net.Conn, error) { -// // Even under a rate limited situation, we accept the connection immediately -// // and close it, rather than being slow at accepting new connections. -// // This provides two benefits: 1) it signals to the client that something -// // is going on on the server, and 2) it prevents new connections from -// // piling up and occupying resources in the OS kernel. -// // The client will retry as needing (with backoffs in place). 
-// cn, err := l.Listener.Accept() -// if err != nil { -// return nil, err -// } -// if !l.lim.Allow() { -// l.numRejects.Add(1) -// cn.Close() -// return nil, errLimitedConn -// } -// l.numAccepts.Add(1) -// return cn, nil -// } -// mod metrics { use iroh_metrics::{ core::{Counter, Metric}, @@ -856,206 +485,3 @@ mod metrics { } } } - -#[cfg(test)] -mod tests { - use super::*; - - use std::net::Ipv4Addr; - use std::time::Duration; - - use bytes::Bytes; - use http_body_util::BodyExt; - use iroh_base::node_addr::RelayUrl; - use iroh_net::relay::http::ClientBuilder; - use iroh_net::relay::ReceivedMessage; - use tokio::task::JoinHandle; - - #[tokio::test] - async fn test_serve_no_content_handler() { - let challenge = "123az__."; - let req = Request::builder() - .header(NO_CONTENT_CHALLENGE_HEADER, challenge) - .body(body_empty()) - .unwrap(); - - let res = serve_no_content_handler(req, Response::builder()).unwrap(); - assert_eq!(res.status(), StatusCode::NO_CONTENT); - - let header = res - .headers() - .get(NO_CONTENT_RESPONSE_HEADER) - .unwrap() - .to_str() - .unwrap(); - assert_eq!(header, format!("response {challenge}")); - assert!(res - .into_body() - .collect() - .await - .unwrap() - .to_bytes() - .is_empty()); - } - - #[test] - fn test_escape_hostname() { - assert_eq!( - escape_hostname("hello.host.name_foo-bar%baz"), - "hello.host.namefoo-barbaz" - ); - } - - struct DropServer { - server_task: JoinHandle<()>, - } - - impl Drop for DropServer { - fn drop(&mut self) { - self.server_task.abort(); - } - } - - #[tokio::test] - async fn test_relay_server_basic() -> Result<()> { - tracing_subscriber::registry() - .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr)) - .with(EnvFilter::from_default_env()) - .try_init() - .ok(); - // Binding to LOCALHOST to satisfy issues when binding to UNSPECIFIED in Windows for tests - // Binding to Ipv4 because, when binding to `IPv6::UNSPECIFIED`, it will also listen for - // IPv4 connections, but will not automatically do the same for `LOCALHOST`. In order to - // test STUN, which only listens on Ipv4, we must bind the whole relay server to Ipv4::LOCALHOST. - let cfg = Config { - addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), - ..Default::default() - }; - let (addr_send, addr_recv) = tokio::sync::oneshot::channel(); - let relay_server_task = tokio::spawn( - async move { - // dev mode will bind to IPv6::UNSPECIFIED, so setting it `false` - let res = run(false, cfg, Some(addr_send)).await; - if let Err(e) = res { - eprintln!("error starting relay server {e}"); - } - } - .instrument(debug_span!("relay server")), - ); - let _drop_server = DropServer { - server_task: relay_server_task, - }; - - let relay_server_addr = addr_recv.await?; - let relay_server_str_url = format!("http://{}", relay_server_addr); - let relay_server_url: RelayUrl = relay_server_str_url.parse().unwrap(); - - // set up clients - let a_secret_key = SecretKey::generate(); - let a_key = a_secret_key.public(); - let resolver = iroh_net::dns::default_resolver().clone(); - let (client_a, mut client_a_receiver) = - ClientBuilder::new(relay_server_url.clone()).build(a_secret_key, resolver); - let connect_client = client_a.clone(); - - // give the relay server some time to set up - if let Err(e) = tokio::time::timeout(Duration::from_secs(10), async move { - loop { - match connect_client.connect().await { - Ok(_) => break, - Err(e) => { - tracing::warn!("client a unable to connect to relay server: {e:?}. 
Attempting to dial again in 10ms"); - tokio::time::sleep(Duration::from_millis(100)).await - } - } - } - }) - .await - { - bail!("error connecting client a to relay server: {e:?}"); - } - - let b_secret_key = SecretKey::generate(); - let b_key = b_secret_key.public(); - let resolver = iroh_net::dns::default_resolver().clone(); - let (client_b, mut client_b_receiver) = - ClientBuilder::new(relay_server_url.clone()).build(b_secret_key, resolver); - client_b.connect().await?; - - let msg = Bytes::from("hello, b"); - client_a.send(b_key, msg.clone()).await?; - - let (res, _) = client_b_receiver.recv().await.unwrap()?; - if let ReceivedMessage::ReceivedPacket { source, data } = res { - assert_eq!(a_key, source); - assert_eq!(msg, data); - } else { - bail!("client_b received unexpected message {res:?}"); - } - - let msg = Bytes::from("howdy, a"); - client_b.send(a_key, msg.clone()).await?; - - let (res, _) = client_a_receiver.recv().await.unwrap()?; - if let ReceivedMessage::ReceivedPacket { source, data } = res { - assert_eq!(b_key, source); - assert_eq!(msg, data); - } else { - bail!("client_a received unexpected message {res:?}"); - } - - // run stun check - let stun_addr: SocketAddr = - SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), 3478); - - let txid = stun::TransactionId::default(); - let req = stun::request(txid); - let socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); - - let server_socket = socket.clone(); - let server_task = tokio::task::spawn(async move { - let mut buf = vec![0u8; 64000]; - let len = server_socket.recv(&mut buf).await.unwrap(); - dbg!(len); - buf.truncate(len); - buf - }); - - tracing::info!("sending stun request to {stun_addr}"); - if let Err(e) = socket.send_to(&req, stun_addr).await { - bail!("socket.send_to error: {e:?}"); - } - - let response = server_task.await.unwrap(); - let (txid_back, response_addr) = stun::parse_response(&response).unwrap(); - assert_eq!(txid, txid_back); - tracing::info!("got {response_addr}"); - - // get 200 home page response - tracing::info!("send request for homepage"); - let res = reqwest::get(relay_server_str_url).await?; - assert!(res.status().is_success()); - tracing::info!("got OK"); - - // test captive portal - tracing::info!("test captive portal response"); - - let url = relay_server_url.join("/generate_204")?; - let challenge = "123az__."; - let client = reqwest::Client::new(); - let res = client - .get(url) - .header(NO_CONTENT_CHALLENGE_HEADER, challenge) - .send() - .await?; - assert_eq!(StatusCode::NO_CONTENT.as_u16(), res.status().as_u16()); - let header = res.headers().get(NO_CONTENT_RESPONSE_HEADER).unwrap(); - assert_eq!(header.to_str().unwrap(), format!("response {challenge}")); - let body = res.bytes().await?; - assert!(body.is_empty()); - - tracing::info!("got successful captive portal response"); - - Ok(()) - } -} diff --git a/iroh-net/src/config.rs b/iroh-net/src/config.rs deleted file mode 100644 index 8c98749810..0000000000 --- a/iroh-net/src/config.rs +++ /dev/null @@ -1,128 +0,0 @@ -//! Configuration types. - -use std::{collections::BTreeMap, fmt::Display, net::SocketAddr}; - -use crate::relay::RelayUrl; - -use super::portmapper; - -// TODO: This re-uses "Endpoint" again, a term that already means "a quic endpoint" and "a -// magicsock endpoint". this time it means "an IP address on which our local magicsock -// endpoint is listening". Name this better. -/// An endpoint IPPort and an associated type. 
-#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct Endpoint { - /// The address of the endpoint. - pub addr: SocketAddr, - /// The kind of endpoint. - pub typ: EndpointType, -} - -/// Type of endpoint. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub enum EndpointType { - /// Endpoint kind has not been determined yet. - Unknown, - /// Endpoint is bound to a local address. - Local, - /// Endpoint has a publicly reachable address found via STUN. - Stun, - /// Endpoint uses a port mapping in the router. - Portmapped, - /// Hard NAT: STUN'ed IPv4 address + local fixed port. - Stun4LocalPort, -} - -impl Display for EndpointType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - EndpointType::Unknown => write!(f, "?"), - EndpointType::Local => write!(f, "local"), - EndpointType::Stun => write!(f, "stun"), - EndpointType::Portmapped => write!(f, "portmap"), - EndpointType::Stun4LocalPort => write!(f, "stun4localport"), - } - } -} - -/// Contains information about the host's network state. -#[derive(Debug, Clone, PartialEq)] -pub struct NetInfo { - /// Says whether the host's NAT mappings vary based on the destination IP. - pub mapping_varies_by_dest_ip: Option, - - /// If their router does hairpinning. It reports true even if there's no NAT involved. - pub hair_pinning: Option, - - /// Whether the host has IPv6 internet connectivity. - pub working_ipv6: Option, - - /// Whether the OS supports IPv6 at all, regardless of whether IPv6 internet connectivity is available. - pub os_has_ipv6: Option, - - /// Whether the host has UDP internet connectivity. - pub working_udp: Option, - - /// Whether ICMPv4 works, `None` means not checked. - pub working_icmp_v4: Option, - - /// Whether ICMPv6 works, `None` means not checked. - pub working_icmp_v6: Option, - - /// Whether we have an existing portmap open (UPnP, PMP, or PCP). - pub have_port_map: bool, - - /// Probe indicating the presence of port mapping protocols on the LAN. - pub portmap_probe: Option, - - /// This node's preferred relay server for incoming traffic. The node might be be temporarily - /// connected to multiple relay servers (to send to other nodes) - /// but PreferredRelay is the instance number that the node - /// subscribes to traffic at. Zero means disconnected or unknown. - pub preferred_relay: Option, - - /// LinkType is the current link type, if known. - pub link_type: Option, - - /// The fastest recent time to reach various relay STUN servers, in seconds. - /// - /// This should only be updated rarely, or when there's a - /// material change, as any change here also gets uploaded to the control plane. - pub relay_latency: BTreeMap, -} - -impl NetInfo { - /// reports whether `self` and `other` are basically equal, ignoring changes in relay ServerLatency & RelayLatency. 
- pub fn basically_equal(&self, other: &Self) -> bool { - let eq_icmp_v4 = match (self.working_icmp_v4, other.working_icmp_v4) { - (Some(slf), Some(other)) => slf == other, - _ => true, // ignore for comparison if only one report had this info - }; - let eq_icmp_v6 = match (self.working_icmp_v6, other.working_icmp_v6) { - (Some(slf), Some(other)) => slf == other, - _ => true, // ignore for comparison if only one report had this info - }; - self.mapping_varies_by_dest_ip == other.mapping_varies_by_dest_ip - && self.hair_pinning == other.hair_pinning - && self.working_ipv6 == other.working_ipv6 - && self.os_has_ipv6 == other.os_has_ipv6 - && self.working_udp == other.working_udp - && eq_icmp_v4 - && eq_icmp_v6 - && self.have_port_map == other.have_port_map - && self.portmap_probe == other.portmap_probe - && self.preferred_relay == other.preferred_relay - && self.link_type == other.link_type - } -} - -/// The type of link. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum LinkType { - /// A wired link (ethernet, fiber, etc). - Wired, - /// A WiFi link. - Wifi, - /// LTE, 4G, 3G, etc. - Mobile, -} diff --git a/iroh-net/src/defaults.rs b/iroh-net/src/defaults.rs index 22e6b1d6e9..7d0237a629 100644 --- a/iroh-net/src/defaults.rs +++ b/iroh-net/src/defaults.rs @@ -9,8 +9,20 @@ pub const NA_RELAY_HOSTNAME: &str = "use1-1.relay.iroh.network."; /// Hostname of the default EU relay. pub const EU_RELAY_HOSTNAME: &str = "euw1-1.relay.iroh.network."; -/// STUN port as defined by [RFC 8489]() -pub const DEFAULT_RELAY_STUN_PORT: u16 = 3478; +/// The default STUN port used by the Relay server. +/// +/// The STUN port as defined by [RFC +/// 8489]() +pub const DEFAULT_STUN_PORT: u16 = 3478; + +/// The default HTTP port used by the Relay server. +pub const DEFAULT_HTTP_PORT: u16 = 80; + +/// The default HTTPS port used by the Relay server. +pub const DEFAULT_HTTPS_PORT: u16 = 443; + +/// The default metrics port used by the Relay server. +pub const DEFAULT_METRICS_PORT: u16 = 9090; /// Get the default [`RelayMap`]. pub fn default_relay_map() -> RelayMap { @@ -27,7 +39,7 @@ pub fn default_na_relay_node() -> RelayNode { RelayNode { url: url.into(), stun_only: false, - stun_port: DEFAULT_RELAY_STUN_PORT, + stun_port: DEFAULT_STUN_PORT, } } @@ -40,6 +52,6 @@ pub fn default_eu_relay_node() -> RelayNode { RelayNode { url: url.into(), stun_only: false, - stun_port: DEFAULT_RELAY_STUN_PORT, + stun_port: DEFAULT_STUN_PORT, } } diff --git a/iroh-net/src/discovery.rs b/iroh-net/src/discovery.rs index 0b935621c2..60be8bd3ea 100644 --- a/iroh-net/src/discovery.rs +++ b/iroh-net/src/discovery.rs @@ -13,6 +13,9 @@ use crate::{AddrInfo, Endpoint, NodeId}; pub mod dns; pub mod pkarr_publish; +/// Name used for logging when new node addresses are added from discovery. +const SOURCE_NAME: &str = "discovery"; + /// Node discovery for [`super::Endpoint`]. 
/// /// The purpose of this trait is to hook up a node discovery mechanism that @@ -252,7 +255,7 @@ impl DiscoveryTask { info: r.addr_info, node_id, }; - ep.add_node_addr(addr).ok(); + ep.add_node_addr_with_source(addr, SOURCE_NAME).ok(); if let Some(tx) = on_first_tx.take() { tx.send(Ok(())).ok(); } @@ -416,7 +419,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.my_addr().await?; + ep1.node_addr().await?; let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?; Ok(()) } @@ -442,7 +445,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.my_addr().await?; + ep1.node_addr().await?; let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?; Ok(()) } @@ -472,7 +475,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.my_addr().await?; + ep1.node_addr().await?; let _conn = ep2.connect(ep1_addr, TEST_ALPN).await?; Ok(()) } @@ -495,7 +498,7 @@ mod tests { }; let ep1_addr = NodeAddr::new(ep1.node_id()); // wait for out address to be updated and thus published at least once - ep1.my_addr().await?; + ep1.node_addr().await?; let res = ep2.connect(ep1_addr, TEST_ALPN).await; assert!(res.is_err()); Ok(()) @@ -518,7 +521,7 @@ mod tests { new_endpoint(secret, disco).await }; // wait for out address to be updated and thus published at least once - ep1.my_addr().await?; + ep1.node_addr().await?; let ep1_wrong_addr = NodeAddr { node_id: ep1.node_id(), info: AddrInfo { @@ -669,11 +672,6 @@ mod test_dns_pkarr { // wait until our shared state received the update from pkarr publishing dns_pkarr_server.on_node(&ep1.node_id(), timeout).await?; - let node_addr = NodeAddr::new(ep1.node_id()); - - // add empty node address. We *should* launch discovery before attempting to dial. - ep2.add_node_addr(node_addr)?; - // we connect only by node id! let res = ep2.connect(ep1.node_id().into(), TEST_ALPN).await; assert!(res.is_ok(), "connection established"); diff --git a/iroh-net/src/dns.rs b/iroh-net/src/dns.rs index 1ac64c2f7f..bcd5ebc15a 100644 --- a/iroh-net/src/dns.rs +++ b/iroh-net/src/dns.rs @@ -387,17 +387,6 @@ pub(crate) mod tests { const TIMEOUT: Duration = Duration::from_secs(5); const STAGGERING_DELAYS: &[u64] = &[200, 300]; - #[tokio::test] - #[cfg_attr(target_os = "windows", ignore = "flaky")] - async fn test_dns_lookup_basic() { - let _logging = iroh_test::logging::setup(); - let resolver = default_resolver(); - let res = resolver.lookup_ip(NA_RELAY_HOSTNAME).await.unwrap(); - let res: Vec<_> = res.iter().collect(); - assert!(!res.is_empty()); - dbg!(res); - } - #[tokio::test] async fn test_dns_lookup_ipv4_ipv6() { let _logging = iroh_test::logging::setup(); diff --git a/iroh-net/src/endpoint.rs b/iroh-net/src/endpoint.rs index dd1a56d274..b741f47178 100644 --- a/iroh-net/src/endpoint.rs +++ b/iroh-net/src/endpoint.rs @@ -1,6 +1,15 @@ -//! An endpoint that leverages a [`quinn::Endpoint`] and transparently routes packages via direct -//! conenctions or a relay when necessary, optimizing the path to target nodes to ensure maximum -//! connectivity. +//! The [`Endpoint`] allows establishing connections to other iroh-net nodes. +//! +//! The [`Endpoint`] is the main API interface to manage a local iroh-net node. It allows +//! connecting to and accepting connections from other nodes. See the [module docs] for +//! 
more details on how iroh-net connections work.
+//!
+//! The main items in this module are:
+//!
+//! - [`Endpoint`] to establish iroh-net connections with other nodes.
+//! - [`Builder`] to create an [`Endpoint`].
+//!
+//! [module docs]: crate

 use std::any::Any;
 use std::future::Future;
@@ -19,7 +28,6 @@ use tracing::{debug, info_span, trace, warn};
 use url::Url;

 use crate::{
-    config,
     defaults::default_relay_map,
     discovery::{Discovery, DiscoveryTask},
     dns::{default_resolver, DnsResolver},
@@ -39,17 +47,25 @@ pub use quinn::{
 };

 pub use super::magicsock::{
-    ConnectionInfo, ConnectionType, ConnectionTypeStream, ControlMsg, DirectAddrInfo,
-    LocalEndpointsStream,
+    ConnectionInfo, ConnectionType, ConnectionTypeStream, ControlMsg, DirectAddr, DirectAddrInfo,
+    DirectAddrType, DirectAddrsStream,
 };

 pub use iroh_base::node_addr::{AddrInfo, NodeAddr};

-/// The delay we add before starting a discovery in [`Endpoint::connect`] if the user provided
-/// new direct addresses (to try these addresses before starting the discovery).
+/// The delay to fall back to discovery when direct addresses fail.
+///
+/// When a connection is attempted with a [`NodeAddr`] containing direct addresses, the
+/// [`Endpoint`] assumes one of those addresses probably works. If there is still no
+/// connection after this delay, the configured [`Discovery`] will be used instead.
 const DISCOVERY_WAIT_PERIOD: Duration = Duration::from_millis(500);

-/// Builder for [Endpoint]
+/// Builder for [`Endpoint`].
+///
+/// By default the endpoint will generate a new random [`SecretKey`], which will result in a
+/// new [`NodeId`].
+///
+/// To create the [`Endpoint`] call [`Builder::bind`].
 #[derive(Debug)]
 pub struct Builder {
     secret_key: Option<SecretKey>,
@@ -87,117 +103,136 @@ impl Default for Builder {
 }

 impl Builder {
-    /// Set a secret key to authenticate with other peers.
+    // The ordering of public methods is reflected directly in the documentation. This is
+    // roughly ordered by what is most commonly needed by users.
+
+    // # The final constructor that everyone needs.
+
+    /// Binds the magic endpoint on the specified socket address.
     ///
-    /// This secret key's public key will be the [PublicKey] of this endpoint.
+    /// The *bind_port* is the port that should be bound locally.
+    /// The port will be used to bind an IPv4 and, if supported, an IPv6 socket.
+    /// You can pass `0` to let the operating system choose a free port for you.
     ///
-    /// If not set, a new secret key will be generated.
-    pub fn secret_key(mut self, secret_key: SecretKey) -> Self {
-        self.secret_key = Some(secret_key);
-        self
-    }
+    /// NOTE: This will be improved soon to add support for binding on specific addresses.
+    pub async fn bind(self, bind_port: u16) -> Result<Endpoint> {
+        let relay_map = match self.relay_mode {
+            RelayMode::Disabled => RelayMap::empty(),
+            RelayMode::Default => default_relay_map(),
+            RelayMode::Custom(relay_map) => {
+                ensure!(!relay_map.is_empty(), "Empty custom relay server map",);
+                relay_map
+            }
+        };
+        let secret_key = self.secret_key.unwrap_or_else(SecretKey::generate);
+        let static_config = StaticConfig {
+            transport_config: Arc::new(self.transport_config.unwrap_or_default()),
+            keylog: self.keylog,
+            concurrent_connections: self.concurrent_connections,
+            secret_key: secret_key.clone(),
+        };
+        let dns_resolver = self
+            .dns_resolver
+            .unwrap_or_else(|| default_resolver().clone());

-    /// Set the ALPN protocols that this endpoint will accept on incoming connections.
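Taken together, the builder flow that this hunk rearranges amounts to the following. This is a minimal sketch against the APIs shown in the diff (`Endpoint::builder`, `secret_key`, `alpns`, `bind`); the ALPN byte string is a made-up placeholder.

```rust
use anyhow::Result;
use iroh_net::{key::SecretKey, Endpoint};

// Hypothetical application ALPN; any unique byte string works.
const MY_ALPN: &[u8] = b"my-app/0";

#[tokio::main]
async fn main() -> Result<()> {
    // Port 0 lets the OS pick a free port, as the `bind` docs above note.
    let ep = Endpoint::builder()
        .secret_key(SecretKey::generate())
        .alpns(vec![MY_ALPN.to_vec()])
        .bind(0)
        .await?;
    println!("our node id: {}", ep.node_id());
    Ok(())
}
```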
- pub fn alpns(mut self, alpn_protocols: Vec>) -> Self { - self.alpn_protocols = alpn_protocols; - self + let msock_opts = magicsock::Options { + port: bind_port, + secret_key, + relay_map, + nodes_path: self.peers_path, + discovery: self.discovery, + proxy_url: self.proxy_url, + dns_resolver, + #[cfg(any(test, feature = "test-utils"))] + insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, + }; + Endpoint::bind(static_config, msock_opts, self.alpn_protocols).await } - /// Set an explicit proxy url to proxy all HTTP(S) traffic through. - pub fn proxy_url(mut self, url: Url) -> Self { - self.proxy_url.replace(url); - self - } + // # The very common methods everyone basically needs. - /// Set the proxy url from the environment, in this order: + /// Sets a secret key to authenticate with other peers. /// - /// - `HTTP_PROXY` - /// - `http_proxy` - /// - `HTTPS_PROXY` - /// - `https_proxy` - pub fn proxy_from_env(mut self) -> Self { - self.proxy_url = proxy_url_from_env(); - self - } - - /// If *keylog* is `true` and the KEYLOGFILE environment variable is present it will be - /// considered a filename to which the TLS pre-master keys are logged. This can be useful - /// to be able to decrypt captured traffic for debugging purposes. - pub fn keylog(mut self, keylog: bool) -> Self { - self.keylog = keylog; + /// This secret key's public key will be the [`PublicKey`] of this endpoint and thus + /// also its [`NodeId`] + /// + /// If not set, a new secret key will be generated. + pub fn secret_key(mut self, secret_key: SecretKey) -> Self { + self.secret_key = Some(secret_key); self } - /// Skip verification of SSL certificates from relay servers + /// Sets the [ALPN] protocols that this endpoint will accept on incoming connections. /// - /// May only be used in tests. - #[cfg(any(test, feature = "test-utils"))] - pub fn insecure_skip_relay_cert_verify(mut self, skip_verify: bool) -> Self { - self.insecure_skip_relay_cert_verify = skip_verify; + /// Not setting this will still allow creating connections, but to accept incoming + /// connections the [ALPN] must be set. + /// + /// [ALPN]: https://en.wikipedia.org/wiki/Application-Layer_Protocol_Negotiation + pub fn alpns(mut self, alpn_protocols: Vec>) -> Self { + self.alpn_protocols = alpn_protocols; self } + // # Methods for common customisation items. + /// Sets the relay servers to assist in establishing connectivity. /// - /// relay servers are used to discover other peers by [`PublicKey`] and also help - /// establish connections between peers by being an initial relay for traffic while - /// assisting in holepunching to establish a direct connection between peers. + /// Relay servers are used to establish initial connection with another iroh-net node. + /// They also perform various functions related to hole punching, see the [crate docs] + /// for more details. + /// + /// By default the Number0 relay servers are used. /// /// When using [RelayMode::Custom], the provided `relay_map` must contain at least one /// configured relay node. If an invalid [`RelayMap`] is provided [`bind`] /// will result in an error. /// /// [`bind`]: Builder::bind + /// [crate docs]: crate pub fn relay_mode(mut self, relay_mode: RelayMode) -> Self { self.relay_mode = relay_mode; self } - /// Set a custom [quinn::TransportConfig] for this endpoint. + /// Optionally sets a discovery mechanism for this endpoint. /// - /// The transport config contains parameters governing the QUIC state machine. 
+ /// If you want to combine multiple discovery services, you can pass a + /// [`crate::discovery::ConcurrentDiscovery`]. /// - /// If unset, the default config is used. Default values should be suitable for most internet - /// applications. Applications protocols which forbid remotely-initiated streams should set - /// `max_concurrent_bidi_streams` and `max_concurrent_uni_streams` to zero. - pub fn transport_config(mut self, transport_config: quinn::TransportConfig) -> Self { - self.transport_config = Some(transport_config); - self - } - - /// Maximum number of simultaneous connections to accept. + /// If no discovery service is set, connecting to a node without providing its + /// direct addresses or relay URLs will fail. /// - /// New incoming connections are only accepted if the total number of incoming or outgoing - /// connections is less than this. Outgoing connections are unaffected. - pub fn concurrent_connections(mut self, concurrent_connections: u32) -> Self { - self.concurrent_connections = Some(concurrent_connections); + /// See the documentation of the [`Discovery`] trait for details. + pub fn discovery(mut self, discovery: Box) -> Self { + self.discovery = Some(discovery); self } - /// Optionally set the path where peer info should be stored. + /// Optionally sets the path where peer info should be stored. /// - /// If the file exists, it will be used to populate an initial set of peers. Peers will be - /// saved periodically and on shutdown to this path. + /// If the file exists, it will be used to populate an initial set of peers. Peers will + /// be saved periodically and on shutdown to this path. pub fn peers_data_path(mut self, path: PathBuf) -> Self { self.peers_path = Some(path); self } - /// Optionally set a discovery mechanism for this endpoint. - /// - /// If you want to combine multiple discovery services, you can pass a - /// [`crate::discovery::ConcurrentDiscovery`]. + // # Methods for more specialist customisation. + + /// Sets a custom [`quinn::TransportConfig`] for this endpoint. /// - /// If no discovery service is set, connecting to a node without providing its - /// direct addresses or relay URLs will fail. + /// The transport config contains parameters governing the QUIC state machine. /// - /// See the documentation of the [`Discovery`] trait for details. - pub fn discovery(mut self, discovery: Box) -> Self { - self.discovery = Some(discovery); + /// If unset, the default config is used. Default values should be suitable for most + /// internet applications. Applications protocols which forbid remotely-initiated + /// streams should set `max_concurrent_bidi_streams` and `max_concurrent_uni_streams` to + /// zero. + pub fn transport_config(mut self, transport_config: quinn::TransportConfig) -> Self { + self.transport_config = Some(transport_config); self } - /// Optionally set a custom DNS resolver to use for this endpoint. + /// Optionally sets a custom DNS resolver to use for this endpoint. /// /// The DNS resolver is used to resolve relay hostnames, and node addresses if /// [`crate::discovery::dns::DnsDiscovery`] is configured. @@ -210,104 +245,151 @@ impl Builder { self } - /// Bind the magic endpoint on the specified socket address. + /// Sets an explicit proxy url to proxy all HTTP(S) traffic through. + pub fn proxy_url(mut self, url: Url) -> Self { + self.proxy_url.replace(url); + self + } + + /// Sets the proxy url from the environment, in this order: /// - /// The *bind_port* is the port that should be bound locally. 
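The relocated `discovery` setter mentions combining services via `ConcurrentDiscovery`. A sketch of what that might look like, assuming the `DnsDiscovery::n0_dns` and `PkarrPublisher::n0_dns` constructors and `ConcurrentDiscovery::from_services`, which iroh-net's `discovery` submodules provide around this version:

```rust
use anyhow::Result;
use iroh_net::discovery::{
    dns::DnsDiscovery, pkarr_publish::PkarrPublisher, ConcurrentDiscovery,
};
use iroh_net::{key::SecretKey, Endpoint};

async fn endpoint_with_discovery() -> Result<Endpoint> {
    let secret_key = SecretKey::generate();
    // Publish our own address via pkarr, resolve other nodes via DNS.
    let discovery = ConcurrentDiscovery::from_services(vec![
        Box::new(PkarrPublisher::n0_dns(secret_key.clone())),
        Box::new(DnsDiscovery::n0_dns()),
    ]);
    Endpoint::builder()
        .secret_key(secret_key)
        .discovery(Box::new(discovery))
        .bind(0)
        .await
}
```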
- /// The port will be used to bind an IPv4 and, if supported, and IPv6 socket. - /// You can pass `0` to let the operating system choose a free port for you. - /// NOTE: This will be improved soon to add support for binding on specific addresses. - pub async fn bind(self, bind_port: u16) -> Result { - let relay_map = match self.relay_mode { - RelayMode::Disabled => RelayMap::empty(), - RelayMode::Default => default_relay_map(), - RelayMode::Custom(relay_map) => { - ensure!(!relay_map.is_empty(), "Empty custom relay server map",); - relay_map - } - }; - let secret_key = self.secret_key.unwrap_or_else(SecretKey::generate); + /// - `HTTP_PROXY` + /// - `http_proxy` + /// - `HTTPS_PROXY` + /// - `https_proxy` + pub fn proxy_from_env(mut self) -> Self { + self.proxy_url = proxy_url_from_env(); + self + } + + /// Enables saving the TLS pre-master key for connections. + /// + /// This key should normally remain secret but can be useful to debug networking issues + /// by decrypting captured traffic. + /// + /// If *keylog* is `true` then setting the `KEYLOGFILE` environment variable to a + /// filename will result in this file being used to log the TLS pre-master keys. + pub fn keylog(mut self, keylog: bool) -> Self { + self.keylog = keylog; + self + } + + /// Skip verification of SSL certificates from relay servers + /// + /// May only be used in tests. + #[cfg(any(test, feature = "test-utils"))] + pub fn insecure_skip_relay_cert_verify(mut self, skip_verify: bool) -> Self { + self.insecure_skip_relay_cert_verify = skip_verify; + self + } + + /// Maximum number of simultaneous connections to accept. + /// + /// New incoming connections are only accepted if the total number of incoming or + /// outgoing connections is less than this. Outgoing connections are unaffected. + pub fn concurrent_connections(mut self, concurrent_connections: u32) -> Self { + self.concurrent_connections = Some(concurrent_connections); + self + } +} + +/// Configuration for a [`quinn::Endpoint`] that cannot be changed at runtime. +#[derive(Debug)] +struct StaticConfig { + secret_key: SecretKey, + transport_config: Arc, + keylog: bool, + concurrent_connections: Option, +} + +impl StaticConfig { + /// Create a [`quinn::ServerConfig`] with the specified ALPN protocols. + fn create_server_config(&self, alpn_protocols: Vec>) -> Result { let mut server_config = make_server_config( - &secret_key, - self.alpn_protocols, - self.transport_config, + &self.secret_key, + alpn_protocols, + self.transport_config.clone(), self.keylog, )?; if let Some(c) = self.concurrent_connections { server_config.concurrent_connections(c); } - let dns_resolver = self - .dns_resolver - .unwrap_or_else(|| default_resolver().clone()); - - let msock_opts = magicsock::Options { - port: bind_port, - secret_key, - relay_map, - nodes_path: self.peers_path, - discovery: self.discovery, - proxy_url: self.proxy_url, - dns_resolver, - #[cfg(any(test, feature = "test-utils"))] - insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify, - }; - Endpoint::bind(Some(server_config), msock_opts, self.keylog).await + Ok(server_config) } } -/// Create a [`quinn::ServerConfig`] with the given secret key and limits. +/// Creates a [`quinn::ServerConfig`] with the given secret key and limits. 
 pub fn make_server_config(
     secret_key: &SecretKey,
     alpn_protocols: Vec<Vec<u8>>,
-    transport_config: Option<quinn::TransportConfig>,
+    transport_config: Arc<quinn::TransportConfig>,
     keylog: bool,
 ) -> Result<quinn::ServerConfig> {
     let tls_server_config = tls::make_server_config(secret_key, alpn_protocols, keylog)?;
     let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_server_config));
-    server_config.transport_config(Arc::new(transport_config.unwrap_or_default()));
-
+    server_config.transport_config(transport_config);
     Ok(server_config)
 }

-/// Iroh connectivity layer.
+/// Controls an iroh-net node, establishing connections with other nodes.
+///
+/// This is the main API interface to create connections to, and accept connections from
+/// other iroh-net nodes. The connections are peer-to-peer and encrypted; a relay server
+/// is used to make the connections reliable. See the [crate docs] for a more detailed
+/// overview of iroh-net.
+///
+/// It is recommended to only create a single instance per application. This ensures all
+/// the connections made share the same peer-to-peer connections to other iroh-net nodes,
+/// while still remaining independent connections. This will result in more optimal network
+/// behaviour.
+///
+/// New connections are typically created using the [`Endpoint::connect`] and
+/// [`Endpoint::accept`] methods. Once established, the [`Connection`] gives access to most
+/// [QUIC] features. Individual streams to send data to the peer are created using the
+/// [`Connection::open_bi`], [`Connection::accept_bi`], [`Connection::open_uni`] and
+/// [`Connection::accept_uni`] functions.
 ///
-/// This is responsible for routing packets to nodes based on node IDs, it will initially route
-/// packets via a relay and transparently try and establish a node-to-node connection and upgrade
-/// to it. It will also keep looking for better connections as the network details of both nodes
-/// change.
+/// Note that due to the light-weight properties of streams a stream will only be accepted
+/// once the initiating peer has sent some data on it.
 ///
-/// It is usually only necessary to use a single [`Endpoint`] instance in an application, it
-/// means any QUIC endpoints on top will be sharing as much information about nodes as possible.
+/// [QUIC]: https://quicwg.org
 #[derive(Clone, Debug)]
 pub struct Endpoint {
-    secret_key: Arc<SecretKey>,
     msock: Handle,
     endpoint: quinn::Endpoint,
     rtt_actor: Arc<rtt_actor::RttHandle>,
-    keylog: bool,
     cancel_token: CancellationToken,
+    static_config: Arc<StaticConfig>,
 }

 impl Endpoint {
-    /// Build an [`Endpoint`]
+    // The ordering of public methods is reflected directly in the documentation. This is
+    // roughly ordered by what is most commonly needed by users, but grouped in similar
+    // items.
+
+    // # Methods relating to construction.
+
+    /// Returns the builder for an [`Endpoint`].
     pub fn builder() -> Builder {
         Builder::default()
     }

-    /// Create a quinn endpoint backed by a magicsock.
+    /// Creates a quinn endpoint backed by a magicsock.
     ///
     /// This is for internal use, the public interface is the [`Builder`] obtained from
     /// [Self::builder]. See the methods on the builder for documentation of the parameters.
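As the new `Endpoint` docs note, a stream only becomes visible to the acceptor once the initiator sends data on it. A minimal sketch of the dialing side using the quinn stream API re-exported here (the ALPN is a placeholder):

```rust
use anyhow::Result;
use iroh_net::{Endpoint, NodeAddr};

const EXAMPLE_ALPN: &[u8] = b"example/0"; // hypothetical

async fn ping(ep: &Endpoint, addr: NodeAddr) -> Result<Vec<u8>> {
    let conn = ep.connect(addr, EXAMPLE_ALPN).await?;
    let (mut send, mut recv) = conn.open_bi().await?;
    // Write first: the remote's `accept_bi` will not resolve until data arrives.
    send.write_all(b"ping").await?;
    send.finish().await?;
    Ok(recv.read_to_end(64).await?)
}
```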
async fn bind( - server_config: Option, + static_config: StaticConfig, msock_opts: magicsock::Options, - keylog: bool, + initial_alpns: Vec>, ) -> Result { - let secret_key = msock_opts.secret_key.clone(); - let span = info_span!("magic_ep", me = %secret_key.public().fmt_short()); + let span = info_span!("magic_ep", me = %static_config.secret_key.public().fmt_short()); let _guard = span.enter(); let msock = magicsock::MagicSock::spawn(msock_opts).await?; trace!("created magicsock"); + let server_config = static_config.create_server_config(initial_alpns)?; + let mut endpoint_config = quinn::EndpointConfig::default(); // Setting this to false means that quinn will ignore packets that have the QUIC fixed bit // set to 0. The fixed bit is the 3rd bit of the first byte of a packet. @@ -318,180 +400,53 @@ impl Endpoint { let endpoint = quinn::Endpoint::new_with_abstract_socket( endpoint_config, - server_config, + Some(server_config), msock.clone(), Arc::new(quinn::TokioRuntime), )?; trace!("created quinn endpoint"); Ok(Self { - secret_key: Arc::new(secret_key), msock, endpoint, rtt_actor: Arc::new(rtt_actor::RttHandle::new()), - keylog, cancel_token: CancellationToken::new(), + static_config: Arc::new(static_config), }) } - /// Accept an incoming connection on the socket. - pub fn accept(&self) -> Accept<'_> { - Accept { - inner: self.endpoint.accept(), - magic_ep: self.clone(), - } - } - - /// Get the node id of this endpoint. - pub fn node_id(&self) -> NodeId { - self.secret_key.public() - } - - /// Get the secret_key of this endpoint. - pub fn secret_key(&self) -> &SecretKey { - &self.secret_key - } - - /// Optional reference to the discovery mechanism. - pub fn discovery(&self) -> Option<&dyn Discovery> { - self.msock.discovery() - } - - /// Get the local endpoint addresses on which the underlying magic socket is bound. - /// - /// Returns a tuple of the IPv4 and the optional IPv6 address. - pub fn local_addr(&self) -> (SocketAddr, Option) { - self.msock.local_addr() - } - - /// Returns the local endpoints as a stream. - /// - /// The [`Endpoint`] continuously monitors the local endpoints, the network - /// addresses it can listen on, for changes. Whenever changes are detected this stream - /// will yield a new list of endpoints. - /// - /// Upon the first creation, the first local endpoint discovery might still be underway, in - /// this case the first item of the stream will not be immediately available. Once this first - /// set of local endpoints are discovered the stream will always return the first set of - /// endpoints immediately, which are the most recently discovered endpoints. - /// - /// The list of endpoints yielded contains both the locally-bound addresses and the - /// endpoint's publicly-reachable addresses, if they could be discovered through STUN or - /// port mapping. - /// - /// # Examples - /// - /// To get the current endpoints, drop the stream after the first item was received: - /// ``` - /// use futures_lite::StreamExt; - /// use iroh_net::Endpoint; - /// - /// # let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); - /// # rt.block_on(async move { - /// let mep = Endpoint::builder().bind(0).await.unwrap(); - /// let _endpoints = mep.local_endpoints().next().await; - /// # }); - /// ``` - pub fn local_endpoints(&self) -> LocalEndpointsStream { - self.msock.local_endpoints() - } - - /// Get the relay url we are connected to with the lowest latency. - /// - /// Returns `None` if we are not connected to any relayer. 
- pub fn my_relay(&self) -> Option { - self.msock.my_relay() - } - - /// Get the [`NodeAddr`] for this endpoint. - pub async fn my_addr(&self) -> Result { - let addrs = self - .local_endpoints() - .next() - .await - .ok_or(anyhow!("No endpoints found"))?; - let relay = self.my_relay(); - let addrs = addrs.into_iter().map(|x| x.addr).collect(); - Ok(NodeAddr::from_parts(self.node_id(), relay, addrs)) - } - - /// Get the [`NodeAddr`] for this endpoint, while providing the endpoints. - pub fn my_addr_with_endpoints(&self, eps: Vec) -> Result { - let relay = self.my_relay(); - let addrs = eps.into_iter().map(|x| x.addr).collect(); - Ok(NodeAddr::from_parts(self.node_id(), relay, addrs)) - } - - /// Watch for changes to the home relay. - /// - /// Note that this can be used to wait for the initial home relay to be known. If the home - /// relay is known at this point, it will be the first item in the stream. - pub fn watch_home_relay(&self) -> impl Stream { - self.msock.watch_home_relay() - } - - /// Get information on all the nodes we have connection information about. - /// - /// Includes the node's [`PublicKey`], potential relay Url, its addresses with any known - /// latency, and its [`ConnectionType`], which let's us know if we are currently communicating - /// with that node over a `Direct` (UDP) or `Relay` (relay) connection. - /// - /// Connections are currently only pruned on user action (when we explicitly add a new address - /// to the internal addressbook through [`Endpoint::add_node_addr`]), so these connections - /// are not necessarily active connections. - pub fn connection_infos(&self) -> Vec { - self.msock.connection_infos() - } - - /// Get connection information about a specific node. - /// - /// Includes the node's [`PublicKey`], potential relay Url, its addresses with any known - /// latency, and its [`ConnectionType`], which let's us know if we are currently communicating - /// with that node over a `Direct` (UDP) or `Relay` (relay) connection. - pub fn connection_info(&self, node_id: PublicKey) -> Option { - self.msock.connection_info(node_id) - } - - pub(crate) fn cancelled(&self) -> WaitForCancellationFuture<'_> { - self.cancel_token.cancelled() - } - - /// Connect to a remote endpoint, using just the nodes's [`PublicKey`]. - pub async fn connect_by_node_id( - &self, - node_id: &PublicKey, - alpn: &[u8], - ) -> Result { - let addr = NodeAddr::new(*node_id); - self.connect(addr, alpn).await - } - - /// Returns a stream that reports changes in the [`ConnectionType`] for the given `node_id`. + /// Set the list of accepted ALPN protocols. /// - /// # Errors - /// - /// Will error if we do not have any address information for the given `node_id` - pub fn conn_type_stream(&self, node_id: &PublicKey) -> Result { - self.msock.conn_type_stream(node_id) + /// This will only affect new incoming connections. + /// Note that this *overrides* the current list of ALPNs. + pub fn set_alpns(&self, alpns: Vec>) -> Result<()> { + let server_config = self.static_config.create_server_config(alpns)?; + self.endpoint.set_server_config(Some(server_config)); + Ok(()) } - /// Connect to a remote endpoint. + // # Methods for establishing connectivity. + + /// Connects to a remote [`Endpoint`]. /// - /// A [`NodeAddr`] is required. It must contain the [`NodeId`] to dial and may also contain a - /// relay URL and direct addresses. If direct addresses are provided, they will be used to - /// try and establish a direct connection without involving a relay server. + /// A [`NodeAddr`] is required. 
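The new `set_alpns` shown above replaces the whole accepted-protocol list at runtime and only affects new incoming connections. A short usage sketch (the protocol names are placeholders):

```rust
use iroh_net::Endpoint;

fn enable_v2(ep: &Endpoint) -> anyhow::Result<()> {
    // This overrides the previous list, so include every ALPN
    // that should remain accepted.
    ep.set_alpns(vec![b"example/1".to_vec(), b"example/2".to_vec()])
}
```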
It must contain the [`NodeId`] to dial and may also
+    /// contain a [`RelayUrl`] and direct addresses. If direct addresses are provided, they
+    /// will be used to try and establish a direct connection without involving a relay
+    /// server.
     ///
-    /// The `alpn`, or application-level protocol identifier, is also required. The remote endpoint
-    /// must support this `alpn`, otherwise the connection attempt will fail with an error.
+    /// If neither a [`RelayUrl`] nor direct addresses are configured in the [`NodeAddr`], it
+    /// may still be possible to establish a connection. This depends on contact
+    /// information provided by earlier calls to [`Endpoint::add_node_addr`], or on the
+    /// [`Discovery`] service configured using [`Builder::discovery`]. The discovery
+    /// service will also be used if the remote node is not reachable on the provided direct
+    /// addresses and there is no [`RelayUrl`].
     ///
-    /// If the [`NodeAddr`] contains only [`NodeId`] and no direct addresses and no relay servers,
-    /// a discovery service will be invoked, if configured, to try and discover the node's
-    /// addressing information. The discovery services must be configured globally per [`Endpoint`]
-    /// with [`Builder::discovery`]. The discovery service will also be invoked if
-    /// none of the existing or provided direct addresses are reachable.
+    /// If addresses or relay servers are neither provided nor can be discovered, the
+    /// connection attempt will fail with an error.
     ///
-    /// If addresses or relay servers are neither provided nor can be discovered, the connection
-    /// attempt will fail with an error.
+    /// The `alpn`, or application-level protocol identifier, is also required. The remote
+    /// endpoint must support this `alpn`, otherwise the connection attempt will fail with
+    /// an error.
     pub async fn connect(&self, node_addr: NodeAddr, alpn: &[u8]) -> Result<quinn::Connection> {
         // Connecting to ourselves is not supported.
         if node_addr.node_id == self.node_id() {
@@ -531,6 +486,21 @@ impl Endpoint {
         conn
     }

+    /// Connects to a remote endpoint, using just the node's [`NodeId`].
+    ///
+    /// This is a convenience function for [`Endpoint::connect`]. It relies on addressing
+    /// information being provided by either the discovery service or using
+    /// [`Endpoint::add_node_addr`]. See [`Endpoint::connect`] for the details of how it
+    /// uses the discovery service to establish a connection to a remote node.
+    pub async fn connect_by_node_id(
+        &self,
+        node_id: &NodeId,
+        alpn: &[u8],
+    ) -> Result<quinn::Connection> {
+        let addr = NodeAddr::new(*node_id);
+        self.connect(addr, alpn).await
+    }
+
     async fn connect_quinn(
         &self,
         node_id: &PublicKey,
@@ -540,10 +510,10 @@ impl Endpoint {
         let client_config = {
             let alpn_protocols = vec![alpn.to_vec()];
             let tls_client_config = tls::make_client_config(
-                &self.secret_key,
+                &self.static_config.secret_key,
                 Some(*node_id),
                 alpn_protocols,
-                self.keylog,
+                self.static_config.keylog,
             )?;
             let mut client_config = quinn::ClientConfig::new(Arc::new(tls_client_config));
             let mut transport_config = quinn::TransportConfig::default();
@@ -561,7 +531,7 @@ impl Endpoint {

         let rtt_msg = RttMessage::NewConnection {
             connection: connection.weak_handle(),
-            conn_type_changes: self.conn_type_stream(node_id)?,
+            conn_type_changes: self.conn_type_stream(*node_id)?,
             node_id: *node_id,
         };
         if let Err(err) = self.rtt_actor.msg_tx.send(rtt_msg).await {
@@ -572,6 +542,293 @@ impl Endpoint {
         Ok(connection)
     }

+    /// Accepts an incoming connection on the endpoint.
+ /// + /// Only connections with the ALPNs configured in [`Builder::alpns`] will be accepted. + /// If multiple ALPNs have been configured the ALPN can be inspected before accepting + /// the connection using [`Connecting::alpn`]. + pub fn accept(&self) -> Accept<'_> { + Accept { + inner: self.endpoint.accept(), + magic_ep: self.clone(), + } + } + + // # Methods for manipulating the internal state about other nodes. + + /// Informs this [`Endpoint`] about addresses of the iroh-net node. + /// + /// This updates the local state for the remote node. If the provided [`NodeAddr`] + /// contains a [`RelayUrl`] this will be used as the new relay server for this node. If + /// it contains any new IP endpoints they will also be stored and tried when next + /// connecting to this node. Any address that matches this node's direct addresses will be + /// silently ignored. + /// + /// See also [`Endpoint::add_node_addr_with_source`]. + /// + /// # Errors + /// + /// Will return an error if we attempt to add our own [`PublicKey`] to the node map or if the + /// direct addresses are a subset of ours. + pub fn add_node_addr(&self, node_addr: NodeAddr) -> Result<()> { + self.add_node_addr_inner(node_addr, magicsock::Source::App) + } + + /// Informs this [`Endpoint`] about addresses of the iroh-net node, noting the source. + /// + /// This updates the local state for the remote node. If the provided [`NodeAddr`] contains a + /// [`RelayUrl`] this will be used as the new relay server for this node. If it contains any + /// new IP endpoints they will also be stored and tried when next connecting to this node. Any + /// address that matches this node's direct addresses will be silently ignored. The *source* is + /// used for logging exclusively and will not be stored. + /// + /// # Errors + /// + /// Will return an error if we attempt to add our own [`PublicKey`] to the node map or if the + /// direct addresses are a subset of ours. + pub fn add_node_addr_with_source( + &self, + node_addr: NodeAddr, + source: &'static str, + ) -> Result<()> { + self.add_node_addr_inner(node_addr, magicsock::Source::NamedApp { name: source }) + } + + fn add_node_addr_inner(&self, node_addr: NodeAddr, source: magicsock::Source) -> Result<()> { + // Connecting to ourselves is not supported. + if node_addr.node_id == self.node_id() { + bail!( + "Adding our own address is not supported ({} is the node id of this node)", + node_addr.node_id.fmt_short() + ); + } + self.msock.add_node_addr(node_addr, source) + } + + // # Getter methods for properties of this Endpoint itself. + + /// Returns the secret_key of this endpoint. + pub fn secret_key(&self) -> &SecretKey { + &self.static_config.secret_key + } + + /// Returns the node id of this endpoint. + /// + /// This ID is the unique addressing information of this node and other peers must know + /// it to be able to connect to this node. + pub fn node_id(&self) -> NodeId { + self.static_config.secret_key.public() + } + + /// Returns the current [`NodeAddr`] for this endpoint. + /// + /// The returned [`NodeAddr`] will have the current [`RelayUrl`] and local IP endpoints + /// as they would be returned by [`Endpoint::home_relay`] and + /// [`Endpoint::direct_addresses`]. 
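A sketch of feeding externally obtained addresses into the endpoint with the new source label, mirroring how the discovery task earlier in this diff tags its inserts with `"discovery"`; the `"app-cache"` label here is hypothetical:

```rust
use anyhow::Result;
use iroh_net::{Endpoint, NodeAddr};

fn seed_from_cache(ep: &Endpoint, addrs: Vec<NodeAddr>) -> Result<()> {
    for addr in addrs {
        // Errors if `addr` is our own node id, or if it only contains
        // our own direct addresses (both rejected by this diff).
        ep.add_node_addr_with_source(addr, "app-cache")?;
    }
    Ok(())
}
```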
+    pub async fn node_addr(&self) -> Result<NodeAddr> {
+        let addrs = self
+            .direct_addresses()
+            .next()
+            .await
+            .ok_or(anyhow!("No IP endpoints found"))?;
+        let relay = self.home_relay();
+        let addrs = addrs.into_iter().map(|x| x.addr).collect();
+        Ok(NodeAddr::from_parts(self.node_id(), relay, addrs))
+    }
+
+    /// Returns the [`RelayUrl`] of the Relay server used as home relay.
+    ///
+    /// Every endpoint has a home Relay server which it chooses as the server with the
+    /// lowest latency out of the configured servers provided by [`Builder::relay_mode`].
+    /// This is the server other iroh-net nodes can use to reliably establish a connection
+    /// to this node.
+    ///
+    /// Returns `None` if we are not connected to any Relay server.
+    ///
+    /// Note that this will be `None` right after the [`Endpoint`] is created since it takes
+    /// some time to find and connect to the home relay server. Use
+    /// [`Endpoint::watch_home_relay`] to wait until the home relay server is available.
+    pub fn home_relay(&self) -> Option<RelayUrl> {
+        self.msock.my_relay()
+    }
+
+    /// Watches for changes to the home relay.
+    ///
+    /// If there is currently a home relay it will be yielded immediately as the first item
+    /// in the stream. This makes it possible to use this function to wait for the initial
+    /// home relay to be known.
+    ///
+    /// Note that it is not guaranteed that a home relay will ever become available. If no
+    /// servers are configured with [`Builder::relay_mode`] this stream will never yield an
+    /// item.
+    pub fn watch_home_relay(&self) -> impl Stream<Item = RelayUrl> {
+        self.msock.watch_home_relay()
+    }
+
+    /// Returns the direct addresses of this [`Endpoint`].
+    ///
+    /// The direct addresses of the [`Endpoint`] are those that could be used by other
+    /// iroh-net nodes to establish direct connectivity, depending on the network
+    /// situation. The yielded lists of direct addresses contain both the locally-bound
+    /// addresses and the [`Endpoint`]'s publicly reachable addresses discovered through
+    /// mechanisms such as [STUN] and port mapping. Hence usually only a subset of these
+    /// will be applicable to a certain remote iroh-net node.
+    ///
+    /// The [`Endpoint`] continuously monitors the direct addresses for changes as its own
+    /// location in the network might change. Whenever changes are detected this stream
+    /// will yield a new list of direct addresses.
+    ///
+    /// When issuing the first call to this method the first direct address discovery might
+    /// still be underway; in this case the first item of the returned stream will not be
+    /// immediately available. Once this first set of local IP endpoints is discovered the
+    /// stream will always return the first set of IP endpoints immediately, which are the
+    /// most recently discovered IP endpoints.
+    ///
+    /// # Examples
+    ///
+    /// To get the current endpoints, drop the stream after the first item was received:
+    /// ```
+    /// use futures_lite::StreamExt;
+    /// use iroh_net::Endpoint;
+    ///
+    /// # let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
+    /// # rt.block_on(async move {
+    /// let mep = Endpoint::builder().bind(0).await.unwrap();
+    /// let _addrs = mep.direct_addresses().next().await;
+    /// # });
+    /// ```
+    ///
+    /// [STUN]: https://en.wikipedia.org/wiki/STUN
+    pub fn direct_addresses(&self) -> DirectAddrsStream {
+        self.msock.direct_addresses()
+    }
+
+    /// Returns the local socket addresses on which the underlying sockets are bound.
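To wait for the home relay as the `home_relay` docs above suggest, something along these lines should work with the APIs in this hunk; a minimal sketch:

```rust
use anyhow::{anyhow, Result};
use futures_lite::StreamExt;
use iroh_net::Endpoint;

async fn wait_for_home_relay(ep: &Endpoint) -> Result<()> {
    // Yields the current home relay immediately if one is already known.
    let relay_url = ep
        .watch_home_relay()
        .next()
        .await
        .ok_or_else(|| anyhow!("home relay stream ended"))?;
    println!("home relay: {relay_url}");
    Ok(())
}
```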
+    ///
+    /// The [`Endpoint`] always binds on an IPv4 address and also tries to bind on an IPv6
+    /// address if available.
+    pub fn bound_sockets(&self) -> (SocketAddr, Option<SocketAddr>) {
+        self.msock.local_addr()
+    }
+
+    // # Getter methods for information about other nodes.
+
+    /// Returns connection information about a specific node.
+    ///
+    /// The [`Endpoint`] stores some information about all the other iroh-net nodes it has
+    /// information about. This includes information about the relay server in use, any
+    /// known direct addresses, when there was last any contact with this node, and what
+    /// kind of connection this was.
+    pub fn connection_info(&self, node_id: NodeId) -> Option<ConnectionInfo> {
+        self.msock.connection_info(node_id)
+    }
+
+    /// Returns information on all the nodes we have connection information about.
+    ///
+    /// This returns the same information as [`Endpoint::connection_info`] for each node
+    /// known to this [`Endpoint`].
+    ///
+    /// Connections are currently only pruned on user action when using
+    /// [`Endpoint::add_node_addr`] so these connections are not necessarily active
+    /// connections.
+    pub fn connection_infos(&self) -> Vec<ConnectionInfo> {
+        self.msock.connection_infos()
+    }
+
+    // # Methods for less common getters.
+    //
+    // Mostly these return things that were passed into the builder.
+
+    /// Returns a stream that reports connection type changes for the remote node.
+    ///
+    /// This returns a stream of [`ConnectionType`] items, yielding an item each time the
+    /// underlying connection to a remote node changes. These connection changes occur
+    /// when the connection switches between using the Relay server and a direct connection.
+    ///
+    /// If there is currently a connection with the remote node the first item in the stream
+    /// will be returned immediately, containing the current connection type.
+    ///
+    /// Note that this does not guarantee each connection change is yielded in the stream.
+    /// If the connection type changes several times before this stream is polled only the
+    /// last recorded state is returned. This can be observed e.g. right at the start of a
+    /// connection when the switch from a relayed to a direct connection can be so fast that
+    /// the relayed state is never exposed.
+    ///
+    /// # Errors
+    ///
+    /// Will error if we do not have any address information for the given `node_id`.
+    pub fn conn_type_stream(&self, node_id: NodeId) -> Result<ConnectionTypeStream> {
+        self.msock.conn_type_stream(node_id)
+    }
+
+    /// Returns the DNS resolver used in this [`Endpoint`].
+    ///
+    /// See [`Builder::dns_resolver`].
+    pub fn dns_resolver(&self) -> &DnsResolver {
+        self.msock.dns_resolver()
+    }
+
+    /// Returns the discovery mechanism, if configured.
+    ///
+    /// See [`Builder::discovery`].
+    pub fn discovery(&self) -> Option<&dyn Discovery> {
+        self.msock.discovery()
+    }
+
+    // # Methods for less common state updates.
+
+    /// Notifies the system of potential network changes.
+    ///
+    /// On many systems iroh is able to detect network changes by itself; however,
+    /// some systems like Android do not expose this functionality to native code.
+    /// Android does however provide this functionality to Java code. This
+    /// function allows for notifying iroh of any potential network changes like
+    /// this.
+    ///
+    /// Even when the network did not change, or iroh was already able to detect
+    /// the network change itself, there is no harm in calling this function.
+    pub async fn network_change(&self) {
+        self.msock.network_change().await;
+    }
+
+    // # Methods for terminating the endpoint.
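The `conn_type_stream` getter above can be used to observe hole punching succeed, much like the reworked `endpoint_conn_type_stream` test later in this diff does. A sketch, assuming `ConnectionType` is reachable via the public `endpoint` module re-exports shown earlier:

```rust
use anyhow::{bail, Result};
use futures_lite::StreamExt;
use iroh_net::endpoint::ConnectionType;
use iroh_net::{Endpoint, NodeId};

async fn wait_until_direct(ep: &Endpoint, node_id: NodeId) -> Result<()> {
    let mut stream = ep.conn_type_stream(node_id)?;
    while let Some(conn_type) = stream.next().await {
        tracing::info!(dst = %node_id.fmt_short(), ?conn_type, "connection type");
        if matches!(conn_type, ConnectionType::Direct(_)) {
            return Ok(());
        }
    }
    bail!("stream ended before a direct connection was observed")
}
```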
+ + /// Closes the QUIC endpoint and the magic socket. + /// + /// This will close all open QUIC connections with the provided error_code and + /// reason. See [`quinn::Connection`] for details on how these are interpreted. + /// + /// It will then wait for all connections to actually be shutdown, and afterwards + /// close the magic socket. + /// + /// Returns an error if closing the magic socket failed. + /// TODO: Document error cases. + pub async fn close(self, error_code: VarInt, reason: &[u8]) -> Result<()> { + let Endpoint { + msock, + endpoint, + cancel_token, + .. + } = self; + cancel_token.cancel(); + tracing::debug!("Closing connections"); + endpoint.close(error_code, reason); + endpoint.wait_idle().await; + // In case this is the last clone of `Endpoint`, dropping the `quinn::Endpoint` will + // make it more likely that the underlying socket is not polled by quinn anymore after this + drop(endpoint); + tracing::debug!("Connections closed"); + + msock.close().await?; + Ok(()) + } + + // # Remaining private methods + + pub(crate) fn cancelled(&self) -> WaitForCancellationFuture<'_> { + self.cancel_token.cancelled() + } + /// Return the quic mapped address for this `node_id` and possibly start discovery /// services if discovery is enabled on this magic endpoint. /// @@ -594,7 +851,7 @@ impl Endpoint { // Only return a mapped addr if we have some way of dialing this node, in other // words, we have either a relay URL or at least one direct address. let addr = if self.msock.has_send_address(node_id) { - self.msock.get_mapping_addr(&node_id) + self.msock.get_mapping_addr(node_id) } else { None }; @@ -622,7 +879,7 @@ impl Endpoint { let mut discovery = DiscoveryTask::start(self.clone(), node_id)?; discovery.first_arrived().await?; if self.msock.has_send_address(node_id) { - let addr = self.msock.get_mapping_addr(&node_id).expect("checked"); + let addr = self.msock.get_mapping_addr(node_id).expect("checked"); Ok((addr, Some(discovery))) } else { bail!("Failed to retrieve the mapped address from the magic socket. Unable to dial node {node_id:?}"); @@ -631,77 +888,6 @@ impl Endpoint { } } - /// Inform the magic socket about addresses of the peer. - /// - /// This updates the magic socket's *netmap* with these addresses, which are used as candidates - /// when connecting to this peer (in addition to addresses obtained from a relay server). - /// - /// Note: updating the magic socket's *netmap* will also prune any connections that are *not* - /// present in the netmap. - /// - /// # Errors - /// Will return an error if we attempt to add our own [`PublicKey`] to the node map. - pub fn add_node_addr(&self, node_addr: NodeAddr) -> Result<()> { - // Connecting to ourselves is not supported. - if node_addr.node_id == self.node_id() { - bail!( - "Adding our own address is not supported ({} is the node id of this node)", - node_addr.node_id.fmt_short() - ); - } - self.msock.add_node_addr(node_addr); - Ok(()) - } - - /// Get a reference to the DNS resolver used in this [`Endpoint`]. - pub fn dns_resolver(&self) -> &DnsResolver { - self.msock.dns_resolver() - } - - /// Close the QUIC endpoint and the magic socket. - /// - /// This will close all open QUIC connections with the provided error_code and reason. See - /// [quinn::Connection] for details on how these are interpreted. - /// - /// It will then wait for all connections to actually be shutdown, and afterwards - /// close the magic socket. - /// - /// Returns an error if closing the magic socket failed. - /// TODO: Document error cases. 
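Shutting down with the relocated `close` method looks roughly like this; the error code is an arbitrary application-chosen value (`VarInt` converts from `u32`):

```rust
use iroh_net::Endpoint;

async fn shutdown(ep: Endpoint) -> anyhow::Result<()> {
    // Closes all connections, waits for them to drain,
    // then closes the magic socket.
    ep.close(0u32.into(), b"graceful shutdown").await
}
```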
- pub async fn close(self, error_code: VarInt, reason: &[u8]) -> Result<()> { - let Endpoint { - msock, - endpoint, - cancel_token, - .. - } = self; - cancel_token.cancel(); - tracing::debug!("Closing connections"); - endpoint.close(error_code, reason); - endpoint.wait_idle().await; - // In case this is the last clone of `Endpoint`, dropping the `quinn::Endpoint` will - // make it more likely that the underlying socket is not polled by quinn anymore after this - drop(endpoint); - tracing::debug!("Connections closed"); - - msock.close().await?; - Ok(()) - } - - /// Call to notify the system of potential network changes. - /// - /// On many systems iroh is able to detect network changes by itself, however - /// some systems like android do not expose this functionality to native code. - /// Android does however provide this functionality to Java code. This - /// function allows for notifying iroh of any potential network changes like - /// this. - /// - /// Even when the network did not change, or iroh was already able to detect - /// the network change itself, there is no harm in calling this function. - pub async fn network_change(&self) { - self.msock.network_change().await; - } - #[cfg(test)] pub(crate) fn magic_sock(&self) -> Handle { self.msock.clone() @@ -779,11 +965,11 @@ impl Connecting { /// Extracts the ALPN protocol from the peer's handshake data. // Note, we could totally provide this method to be on a Connection as well. But we'd // need to wrap Connection too. - pub async fn alpn(&mut self) -> Result { + pub async fn alpn(&mut self) -> Result> { let data = self.handshake_data().await?; match data.downcast::() { Ok(data) => match data.protocol { - Some(protocol) => std::string::String::from_utf8(protocol).map_err(Into::into), + Some(protocol) => Ok(protocol), None => bail!("no ALPN protocol available"), }, Err(_) => bail!("unknown handshake type"), @@ -838,7 +1024,7 @@ fn try_send_rtt_msg(conn: &quinn::Connection, magic_ep: &Endpoint) { warn!(?conn, "failed to get remote node id"); return; }; - let Ok(conn_type_changes) = magic_ep.conn_type_stream(&peer_id) else { + let Ok(conn_type_changes) = magic_ep.conn_type_stream(peer_id) else { warn!(?conn, "failed to create conn_type_stream"); return; }; @@ -938,7 +1124,7 @@ mod tests { .bind(0) .await .unwrap(); - let my_addr = ep.my_addr().await.unwrap(); + let my_addr = ep.node_addr().await.unwrap(); let res = ep.connect(my_addr.clone(), TEST_ALPN).await; assert!(res.is_err()); let err = res.err().unwrap(); @@ -1117,7 +1303,7 @@ mod tests { .bind(0) .await .unwrap(); - let eps = ep.local_addr(); + let eps = ep.bound_sockets(); info!(me = %ep.node_id().fmt_short(), ipv4=%eps.0, ipv6=?eps.1, "server bound"); for i in 0..n_clients { let now = Instant::now(); @@ -1162,7 +1348,7 @@ mod tests { .bind(0) .await .unwrap(); - let eps = ep.local_addr(); + let eps = ep.bound_sockets(); info!(me = %ep.node_id().fmt_short(), ipv4=%eps.0, ipv6=?eps.1, "client bound"); let node_addr = NodeAddr::new(server_node_id).with_relay_url(relay_url); info!(to = ?node_addr, "client connecting"); @@ -1212,8 +1398,8 @@ mod tests { .bind(0) .await .unwrap(); - let ep1_nodeaddr = ep1.my_addr().await.unwrap(); - let ep2_nodeaddr = ep2.my_addr().await.unwrap(); + let ep1_nodeaddr = ep1.node_addr().await.unwrap(); + let ep2_nodeaddr = ep2.node_addr().await.unwrap(); ep1.add_node_addr(ep2_nodeaddr.clone()).unwrap(); ep2.add_node_addr(ep1_nodeaddr.clone()).unwrap(); let ep1_nodeid = ep1.node_id(); @@ -1236,7 +1422,7 @@ mod tests { let conn = incoming.await.unwrap(); 
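With `Connecting::alpn` now returning raw bytes instead of a `String` (see the hunk above), an accept loop compares byte slices directly. A sketch under that assumption, with a placeholder ALPN:

```rust
use anyhow::Result;
use iroh_net::Endpoint;

const EXAMPLE_ALPN: &[u8] = b"example/0"; // hypothetical

async fn accept_loop(ep: Endpoint) -> Result<()> {
    while let Some(mut connecting) = ep.accept().await {
        // `alpn()` now yields `Vec<u8>`; no UTF-8 decoding involved.
        if connecting.alpn().await? != EXAMPLE_ALPN {
            continue; // not our protocol
        }
        let conn = connecting.await?;
        tracing::info!("accepted connection from {:?}", conn.remote_address());
        // ... hand `conn` off to protocol handling ...
    }
    Ok(())
}
```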
let node_id = get_remote_node_id(&conn).unwrap(); assert_eq!(node_id, src); - assert_eq!(alpn.as_bytes(), TEST_ALPN); + assert_eq!(alpn, TEST_ALPN); let (mut send, mut recv) = conn.accept_bi().await.unwrap(); let m = recv.read_to_end(100).await.unwrap(); assert_eq!(m, b"hello"); @@ -1257,8 +1443,9 @@ mod tests { #[tokio::test] async fn endpoint_conn_type_stream() { + const TIMEOUT: Duration = std::time::Duration::from_secs(15); let _logging_guard = iroh_test::logging::setup(); - let (relay_map, relay_url, _relay_guard) = run_relay_server().await.unwrap(); + let (relay_map, _relay_url, _relay_guard) = run_relay_server().await.unwrap(); let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); let ep1_secret_key = SecretKey::generate_with_rng(&mut rng); let ep2_secret_key = SecretKey::generate_with_rng(&mut rng); @@ -1279,76 +1466,62 @@ mod tests { .await .unwrap(); - async fn handle_direct_conn(ep: Endpoint, node_id: PublicKey) -> Result<()> { - let node_addr = NodeAddr::new(node_id); - ep.add_node_addr(node_addr)?; - let stream = ep.conn_type_stream(&node_id)?; - async fn get_direct_event( - src: &PublicKey, - dst: &PublicKey, - mut stream: ConnectionTypeStream, - ) -> Result<()> { - let src = src.fmt_short(); - let dst = dst.fmt_short(); - while let Some(conn_type) = stream.next().await { - tracing::info!(me = %src, dst = %dst, conn_type = ?conn_type); - if matches!(conn_type, ConnectionType::Direct(_)) { - return Ok(()); - } + async fn handle_direct_conn(ep: &Endpoint, node_id: PublicKey) -> Result<()> { + let mut stream = ep.conn_type_stream(node_id)?; + let src = ep.node_id().fmt_short(); + let dst = node_id.fmt_short(); + while let Some(conn_type) = stream.next().await { + tracing::info!(me = %src, dst = %dst, conn_type = ?conn_type); + if matches!(conn_type, ConnectionType::Direct(_)) { + return Ok(()); } - anyhow::bail!("conn_type stream ended before `ConnectionType::Direct`"); } - tokio::time::timeout( - Duration::from_secs(15), - get_direct_event(&ep.node_id(), &node_id, stream), - ) - .await??; - Ok(()) + anyhow::bail!("conn_type stream ended before `ConnectionType::Direct`"); + } + + async fn accept(ep: &Endpoint) -> NodeId { + let incoming = ep.accept().await.unwrap(); + let conn = incoming.await.unwrap(); + let node_id = get_remote_node_id(&conn).unwrap(); + tracing::info!(node_id=%node_id.fmt_short(), "accepted connection"); + node_id } let ep1_nodeid = ep1.node_id(); let ep2_nodeid = ep2.node_id(); - let ep1_nodeaddr = ep1.my_addr().await.unwrap(); + let ep1_nodeaddr = ep1.node_addr().await.unwrap(); tracing::info!( "node id 1 {ep1_nodeid}, relay URL {:?}", ep1_nodeaddr.relay_url() ); tracing::info!("node id 2 {ep2_nodeid}"); - let res_ep1 = tokio::spawn(handle_direct_conn(ep1.clone(), ep2_nodeid)); + let ep1_side = async move { + accept(&ep1).await; + handle_direct_conn(&ep1, ep2_nodeid).await + }; + + let ep2_side = async move { + ep2.connect(ep1_nodeaddr, TEST_ALPN).await.unwrap(); + handle_direct_conn(&ep2, ep1_nodeid).await + }; + + let res_ep1 = tokio::spawn(tokio::time::timeout(TIMEOUT, ep1_side)); let ep1_abort_handle = res_ep1.abort_handle(); let _ep1_guard = CallOnDrop::new(move || { ep1_abort_handle.abort(); }); - let res_ep2 = tokio::spawn(handle_direct_conn(ep2.clone(), ep1_nodeid)); + let res_ep2 = tokio::spawn(tokio::time::timeout(TIMEOUT, ep2_side)); let ep2_abort_handle = res_ep2.abort_handle(); let _ep2_guard = CallOnDrop::new(move || { ep2_abort_handle.abort(); }); - async fn accept(ep: Endpoint) -> NodeId { - let incoming = ep.accept().await.unwrap(); - 
let conn = incoming.await.unwrap(); - get_remote_node_id(&conn).unwrap() - } - - // create a node addr with no direct connections - let ep1_nodeaddr = NodeAddr::from_parts(ep1_nodeid, Some(relay_url), vec![]); - - let accept_res = tokio::spawn(accept(ep1.clone())); - let accept_abort_handle = accept_res.abort_handle(); - let _accept_guard = CallOnDrop::new(move || { - accept_abort_handle.abort(); - }); - - let _conn_2 = ep2.connect(ep1_nodeaddr, TEST_ALPN).await.unwrap(); - - let got_id = accept_res.await.unwrap(); - assert_eq!(ep2_nodeid, got_id); - res_ep1.await.unwrap().unwrap(); - res_ep2.await.unwrap().unwrap(); + let (r1, r2) = tokio::try_join!(res_ep1, res_ep2).unwrap(); + r1.expect("ep1 timeout").unwrap(); + r2.expect("ep2 timeout").unwrap(); } } diff --git a/iroh-net/src/lib.rs b/iroh-net/src/lib.rs index 5cba9c3892..8e54ac70e2 100644 --- a/iroh-net/src/lib.rs +++ b/iroh-net/src/lib.rs @@ -117,7 +117,6 @@ #![recursion_limit = "256"] #![deny(missing_docs, rustdoc::broken_intra_doc_links)] -pub mod config; pub mod defaults; pub mod dialer; mod disco; diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index 080e99b985..37e6da695a 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -16,7 +16,7 @@ //! however, read any packets that come off the UDP sockets. use std::{ - collections::HashMap, + collections::{BTreeMap, HashMap}, fmt::Display, io, net::{IpAddr, Ipv6Addr, SocketAddr}, @@ -33,6 +33,7 @@ use std::{ use anyhow::{anyhow, Context as _, Result}; use bytes::Bytes; use futures_lite::{FutureExt, Stream, StreamExt}; +use iroh_base::key::NodeId; use iroh_metrics::{inc, inc_by}; use quinn::AsyncUdpSocket; use rand::{seq::SliceRandom, Rng, SeedableRng}; @@ -50,7 +51,6 @@ use url::Url; use watchable::Watchable; use crate::{ - config, disco::{self, SendAddr}, discovery::Discovery, dns::DnsResolver, @@ -80,6 +80,7 @@ pub use self::node_map::{ ConnectionType, ConnectionTypeStream, ControlMsg, DirectAddrInfo, NodeInfo as ConnectionInfo, }; pub(super) use self::timer::Timer; +pub(crate) use node_map::Source; /// How long we consider a STUN-derived endpoint valid for. UDP NAT mappings typically /// expire at 30 seconds, so this is a few seconds shy of that. @@ -95,37 +96,37 @@ const NETCHECK_REPORT_TIMEOUT: Duration = Duration::from_secs(10); /// Contains options for `MagicSock::listen`. #[derive(derive_more::Debug)] -pub(super) struct Options { +pub(crate) struct Options { /// The port to listen on. /// Zero means to pick one automatically. - pub port: u16, + pub(crate) port: u16, /// Secret key for this node. - pub secret_key: SecretKey, + pub(crate) secret_key: SecretKey, /// The [`RelayMap`] to use, leave empty to not use a relay server. - pub relay_map: RelayMap, + pub(crate) relay_map: RelayMap, /// Path to store known nodes. - pub nodes_path: Option, + pub(crate) nodes_path: Option, /// Optional node discovery mechanism. - pub discovery: Option>, + pub(crate) discovery: Option>, /// A DNS resolver to use for resolving relay URLs. /// /// You can use [`crate::dns::default_resolver`] for a resolver that uses the system's DNS /// configuration. - pub dns_resolver: DnsResolver, + pub(crate) dns_resolver: DnsResolver, /// Proxy configuration. - pub proxy_url: Option, + pub(crate) proxy_url: Option, /// Skip verification of SSL certificates from relay servers /// /// May only be used in tests. 
#[cfg(any(test, feature = "test-utils"))] - pub insecure_skip_relay_cert_verify: bool, + pub(crate) insecure_skip_relay_cert_verify: bool, } impl Default for Options { @@ -146,13 +147,13 @@ impl Default for Options { /// Contents of a relay message. Use a SmallVec to avoid allocations for the very /// common case of a single packet. -pub(super) type RelayContents = SmallVec<[Bytes; 1]>; +type RelayContents = SmallVec<[Bytes; 1]>; /// Handle for [`MagicSock`]. /// /// Dereferences to [`MagicSock`], and handles closing. #[derive(Clone, Debug, derive_more::Deref)] -pub(super) struct Handle { +pub(crate) struct Handle { #[deref(forward)] msock: Arc, // Empty when closed @@ -170,7 +171,7 @@ pub(super) struct Handle { /// means any QUIC endpoints on top will be sharing as much information about nodes as /// possible. #[derive(derive_more::Debug)] -pub(super) struct MagicSock { +pub(crate) struct MagicSock { actor_sender: mpsc::Sender, relay_actor_sender: mpsc::Sender, /// String representation of the node_id of this node. @@ -246,19 +247,19 @@ pub(super) struct MagicSock { impl MagicSock { /// Creates a magic [`MagicSock`] listening on [`Options::port`]. - pub async fn spawn(opts: Options) -> Result { + pub(crate) async fn spawn(opts: Options) -> Result { Handle::new(opts).await } /// Returns the relay node we are connected to, that has the best latency. /// /// If `None`, then we are not connected to any relay nodes. - pub fn my_relay(&self) -> Option { + pub(crate) fn my_relay(&self) -> Option { self.my_relay.get() } /// Get the current proxy configuration. - pub fn proxy_url(&self) -> Option<&Url> { + pub(crate) fn proxy_url(&self) -> Option<&Url> { self.proxy_url.as_ref() } @@ -282,42 +283,43 @@ impl MagicSock { } /// Get the cached version of the Ipv4 and Ipv6 addrs of the current connection. - pub fn local_addr(&self) -> (SocketAddr, Option) { + pub(crate) fn local_addr(&self) -> (SocketAddr, Option) { *self.local_addrs.read().expect("not poisoned") } /// Returns `true` if we have at least one candidate address where we can send packets to. - pub fn has_send_address(&self, node_key: PublicKey) -> bool { + pub(crate) fn has_send_address(&self, node_key: PublicKey) -> bool { self.connection_info(node_key) .map(|info| info.has_send_address()) .unwrap_or(false) } /// Retrieve connection information about nodes in the network. - pub fn connection_infos(&self) -> Vec { + pub(crate) fn connection_infos(&self) -> Vec { self.node_map.node_infos(Instant::now()) } /// Retrieve connection information about a node in the network. - pub fn connection_info(&self, node_key: PublicKey) -> Option { - self.node_map.node_info(&node_key) + pub(crate) fn connection_info(&self, node_id: NodeId) -> Option { + self.node_map.node_info(node_id) } - /// Returns the local endpoints as a stream. + /// Returns the direct addresses as a stream. /// - /// The [`MagicSock`] continuously monitors the local endpoints, the network addresses - /// it can listen on, for changes. Whenever changes are detected this stream will yield - /// a new list of endpoints. + /// The [`MagicSock`] continuously monitors the direct addresses, the network addresses + /// it might be able to be contacted on, for changes. Whenever changes are detected + /// this stream will yield a new list of addresses. /// /// Upon the first creation on the [`MagicSock`] it may not yet have completed a first - /// local endpoint discovery, in this case the first item of the stream will not be - /// immediately available. 
Once this first set of local endpoints are discovered the - /// stream will always return the first set of endpoints immediately, which are the most - /// recently discovered endpoints. + /// direct addresses discovery, in this case the first item of the stream will not be + /// immediately available. Once this first set of direct addresses are discovered the + /// stream will always return the first set of addresses immediately, which are the most + /// recently discovered addresses. /// - /// To get the current endpoints, drop the stream after the first item was received. - pub fn local_endpoints(&self) -> LocalEndpointsStream { - LocalEndpointsStream { + /// To get the current direct addresses, drop the stream after the first item was + /// received. + pub(crate) fn direct_addresses(&self) -> DirectAddrsStream { + DirectAddrsStream { initial: Some(self.endpoints.get()), inner: self.endpoints.watch().into_stream(), } @@ -327,7 +329,7 @@ impl MagicSock { /// /// Note that this can be used to wait for the initial home relay to be known. If the home /// relay is known at this point, it will be the first item in the stream. - pub fn watch_home_relay(&self) -> impl Stream { + pub(crate) fn watch_home_relay(&self) -> impl Stream { let current = futures_lite::stream::iter(self.my_relay()); let changes = self .my_relay @@ -350,7 +352,7 @@ impl MagicSock { /// /// Will return an error if there is no address information known about the /// given `node_id`. - pub fn conn_type_stream(&self, node_id: &PublicKey) -> Result { + pub(crate) fn conn_type_stream(&self, node_id: NodeId) -> Result { self.node_map.conn_type_stream(node_id) } @@ -358,30 +360,61 @@ impl MagicSock { /// /// Note this is a user-facing API and does not wrap the [`SocketAddr`] in a /// [`QuicMappedAddr`] as we do internally. - pub fn get_mapping_addr(&self, node_key: &PublicKey) -> Option { + pub(crate) fn get_mapping_addr(&self, node_id: NodeId) -> Option { self.node_map - .get_quic_mapped_addr_for_node_key(node_key) + .get_quic_mapped_addr_for_node_key(node_id) .map(|a| a.0) } /// Add addresses for a node to the magic socket's addresbook. #[instrument(skip_all, fields(me = %self.me))] - pub fn add_node_addr(&self, addr: NodeAddr) { - self.node_map.add_node_addr(addr); + pub fn add_node_addr(&self, mut addr: NodeAddr, source: node_map::Source) -> Result<()> { + let my_addresses = self.endpoints.get().last_endpoints; + let mut pruned = 0; + for my_addr in my_addresses.into_iter().map(|ep| ep.addr) { + if addr.info.direct_addresses.remove(&my_addr) { + warn!(node_id=addr.node_id.fmt_short(), %my_addr, %source, "not adding our addr for node"); + pruned += 1; + } + } + if !addr.info.is_empty() { + self.node_map.add_node_addr(addr, source); + Ok(()) + } else if pruned != 0 { + Err(anyhow::anyhow!( + "empty addressing info, {pruned} direct addresses have been pruned" + )) + } else { + Err(anyhow::anyhow!("empty addressing info")) + } + } + + /// Updates our direct addresses. + /// + /// On a successful update, our address is published to discovery. + pub(super) fn update_direct_addresses(&self, eps: Vec) { + let updated = self.endpoints.update(DiscoveredEndpoints::new(eps)).is_ok(); + if updated { + let eps = self.endpoints.read(); + eps.log_endpoint_change(); + self.node_map + .on_direct_addr_discovered(eps.iter().map(|ep| ep.addr)); + self.publish_my_addr(); + } } /// Get a reference to the DNS resolver used in this [`MagicSock`]. 
- pub fn dns_resolver(&self) -> &DnsResolver { + pub(crate) fn dns_resolver(&self) -> &DnsResolver { &self.dns_resolver } /// Reference to optional discovery service - pub fn discovery(&self) -> Option<&dyn Discovery> { + pub(crate) fn discovery(&self) -> Option<&dyn Discovery> { self.discovery.as_ref().map(Box::as_ref) } /// Call to notify the system of potential network changes. - pub async fn network_change(&self) { + pub(crate) async fn network_change(&self) { self.actor_sender .send(ActorMessage::NetworkChange) .await @@ -468,7 +501,7 @@ impl MagicSock { let mut transmits_sent = 0; match self .node_map - .get_send_addrs(&dest, self.ipv6_reported.load(Ordering::Relaxed)) + .get_send_addrs(dest, self.ipv6_reported.load(Ordering::Relaxed)) { Some((public_key, udp_addr, relay_url, mut msgs)) => { let mut pings_sent = false; @@ -531,12 +564,20 @@ impl MagicSock { } if udp_addr.is_none() && relay_url.is_none() { - // Handle no addresses being available - warn!(node = %public_key.fmt_short(), "failed to send: no UDP or relay addr"); - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::NotConnected, - "no UDP or relay address available for node", - ))); + // Returning an error here would lock up the entire `Endpoint`. + // + // If we returned `Poll::Pending`, the waker driving the `poll_send` will never get woken up. + // + // Our best bet here is to log an error and return `Poll::Ready(Ok(n))`. + // + // `n` is the number of consecutive transmits in this batch that are meant for the same destination (a destination that we have no addresses for, and so we can never actually send). + // + // When we return `Poll::Ready(Ok(n))`, we are effectively dropping those n messages, by lying to QUIC and saying they were sent. + // (If we returned `Poll::Ready(Ok(0))` instead, QUIC would loop to attempt to re-send those messages, blocking other traffic.) + // + // When `QUIC` gets no `ACK`s for those messages, the connection will eventually timeout. + error!(node = %public_key.fmt_short(), "failed to send: no UDP or relay addr"); + return Poll::Ready(Ok(n)); } if (udp_addr.is_none() || udp_pending) && (relay_url.is_none() || relay_pending) { @@ -549,14 +590,16 @@ impl MagicSock { } if !relay_sent && !udp_sent && !pings_sent { - warn!(node = %public_key.fmt_short(), "failed to send: no UDP or relay addr"); + // Returning an error here would lock up the entire `Endpoint`. + // Instead, log an error and return `Poll::Pending`, the connection will timeout. let err = udp_error.unwrap_or_else(|| { io::Error::new( io::ErrorKind::NotConnected, "no UDP or relay address available for node", ) }); - return Poll::Ready(Err(err)); + error!(node = %public_key.fmt_short(), "{err:?}"); + return Poll::Pending; } trace!( @@ -1447,7 +1490,7 @@ impl Handle { /// Polling the socket ([`AsyncUdpSocket::poll_recv`]) will return [`Poll::Pending`] /// indefinitely after this call. #[instrument(skip_all, fields(me = %self.msock.me))] - pub async fn close(&self) -> Result<()> { + pub(crate) async fn close(&self) -> Result<()> { if self.msock.is_closed() { return Ok(()); } @@ -1482,13 +1525,13 @@ impl Handle { /// Stream returning local endpoints as they change. 
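 ///
 /// A sketch of taking one snapshot from this stream (assuming `endpoint` is an
 /// `Endpoint` from this diff, whose `direct_addresses()` method returns this stream):
 ///
 /// ```ignore
 /// use futures_lite::StreamExt;
 /// // To get only the current addresses, take the first item, then drop the stream.
 /// if let Some(addrs) = endpoint.direct_addresses().next().await {
 ///     println!("our direct addresses: {addrs:?}");
 /// }
 /// ```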
#[derive(Debug)] -pub struct LocalEndpointsStream { +pub struct DirectAddrsStream { initial: Option, inner: watchable::WatcherStream, } -impl Stream for LocalEndpointsStream { - type Item = Vec; +impl Stream for DirectAddrsStream { + type Item = Vec; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = &mut *self; @@ -1571,7 +1614,7 @@ enum DiscoBoxError { type RelayRecvResult = Result<(PublicKey, quinn_udp::RecvMeta, Bytes), io::Error>; /// Reports whether x and y represent the same set of endpoints. The order doesn't matter. -fn endpoint_sets_equal(xs: &[config::Endpoint], ys: &[config::Endpoint]) -> bool { +fn endpoint_sets_equal(xs: &[DirectAddr], ys: &[DirectAddr]) -> bool { if xs.is_empty() && ys.is_empty() { return true; } @@ -1587,7 +1630,7 @@ fn endpoint_sets_equal(xs: &[config::Endpoint], ys: &[config::Endpoint]) -> bool return true; } } - let mut m: HashMap<&config::Endpoint, usize> = HashMap::new(); + let mut m: HashMap<&DirectAddr, usize> = HashMap::new(); for x in xs { *m.entry(x).or_default() |= 1; } @@ -1656,7 +1699,7 @@ struct Actor { /// When set, is an AfterFunc timer that will call MagicSock::do_periodic_stun. periodic_re_stun_timer: time::Interval, /// The `NetInfo` provided in the last call to `net_info_func`. It's used to deduplicate calls to netInfoFunc. - net_info_last: Option, + net_info_last: Option, /// Path where connection info from [`MagicSock::node_map`] is persisted. nodes_path: Option, @@ -1951,7 +1994,7 @@ impl Actor { #[allow(clippy::map_entry)] if !$already.contains_key(&$ipp) { $already.insert($ipp, $et); - $eps.push(config::Endpoint { + $eps.push(DirectAddr { addr: $ipp, typ: $et, }); @@ -1962,13 +2005,13 @@ impl Actor { let maybe_port_mapped = *portmap_watcher.borrow(); if let Some(portmap_ext) = maybe_port_mapped.map(SocketAddr::V4) { - add_addr!(already, eps, portmap_ext, config::EndpointType::Portmapped); + add_addr!(already, eps, portmap_ext, DirectAddrType::Portmapped); self.set_net_info_have_port_map().await; } if let Some(nr) = nr { if let Some(global_v4) = nr.global_v4 { - add_addr!(already, eps, global_v4.into(), config::EndpointType::Stun); + add_addr!(already, eps, global_v4.into(), DirectAddrType::Stun); // If they're behind a hard NAT and are using a fixed // port locally, assume they might've added a static @@ -1978,16 +2021,11 @@ impl Actor { if nr.mapping_varies_by_dest_ip.unwrap_or_default() && port != 0 { let mut addr = global_v4; addr.set_port(port); - add_addr!( - already, - eps, - addr.into(), - config::EndpointType::Stun4LocalPort - ); + add_addr!(already, eps, addr.into(), DirectAddrType::Stun4LocalPort); } } if let Some(global_v6) = nr.global_v6 { - add_addr!(already, eps, global_v6.into(), config::EndpointType::Stun); + add_addr!(already, eps, global_v6.into(), DirectAddrType::Stun); } } let local_addr_v4 = self.pconn4.local_addr().ok(); @@ -2045,7 +2083,7 @@ impl Actor { already, eps, SocketAddr::new(ip, port), - config::EndpointType::Local + DirectAddrType::Local ); } } @@ -2055,7 +2093,7 @@ impl Actor { already, eps, SocketAddr::new(ip, port), - config::EndpointType::Local + DirectAddrType::Local ); } } @@ -2067,7 +2105,7 @@ impl Actor { if let Some(addr) = local_addr_v4 { // Our local endpoint is bound to a particular address. // Do not offer addresses on other local interfaces. 
- add_addr!(already, eps, addr, config::EndpointType::Local); + add_addr!(already, eps, addr, DirectAddrType::Local); } } @@ -2075,7 +2113,7 @@ impl Actor { if let Some(addr) = local_addr_v6 { // Our local endpoint is bound to a particular address. // Do not offer addresses on other local interfaces. - add_addr!(already, eps, addr, config::EndpointType::Local); + add_addr!(already, eps, addr, DirectAddrType::Local); } } @@ -2089,15 +2127,7 @@ impl Actor { // The STUN address(es) are always first. // Despite this sorting, clients are not relying on this sorting for decisions; - let updated = msock - .endpoints - .update(DiscoveredEndpoints::new(eps)) - .is_ok(); - if updated { - let eps = msock.endpoints.read(); - eps.log_endpoint_change(); - msock.publish_my_addr(); - } + msock.update_direct_addresses(eps); // Regardless of whether our local endpoints changed, we now want to send any queued // call-me-maybe messages. @@ -2134,7 +2164,7 @@ impl Actor { } #[instrument(level = "debug", skip_all)] - async fn call_net_info_callback(&mut self, ni: config::NetInfo) { + async fn call_net_info_callback(&mut self, ni: NetInfo) { if let Some(ref net_info_last) = self.net_info_last { if ni.basically_equal(net_info_last) { return; @@ -2209,7 +2239,7 @@ impl Actor { self.no_v4_send = !r.ipv4_can_send; let have_port_map = self.port_mapper.watch_external_address().borrow().is_some(); - let mut ni = config::NetInfo { + let mut ni = NetInfo { relay_latency: Default::default(), mapping_varies_by_dest_ip: r.mapping_varies_by_dest_ip, hair_pinning: r.hair_pinning, @@ -2221,7 +2251,6 @@ impl Actor { working_icmp_v4: r.icmpv4, working_icmp_v6: r.icmpv6, preferred_relay: r.preferred_relay.clone(), - link_type: None, }; for (rid, d) in r.relay_v4_latency.iter() { ni.relay_latency @@ -2397,7 +2426,7 @@ fn bind(port: u16) -> Result<(UdpConn, Option)> { struct DiscoveredEndpoints { /// Records the endpoints found during the previous /// endpoint discovery. It's used to avoid duplicate endpoint change notifications. - last_endpoints: Vec, + last_endpoints: Vec, /// The last time the endpoints were updated, even if there was no change. last_endpoints_time: Option, @@ -2410,18 +2439,18 @@ impl PartialEq for DiscoveredEndpoints { } impl DiscoveredEndpoints { - fn new(endpoints: Vec) -> Self { + fn new(endpoints: Vec) -> Self { Self { last_endpoints: endpoints, last_endpoints_time: Some(Instant::now()), } } - fn into_iter(self) -> impl Iterator { + fn into_iter(self) -> impl Iterator { self.last_endpoints.into_iter() } - fn iter(&self) -> impl Iterator + '_ { + fn iter(&self) -> impl Iterator + '_ { self.last_endpoints.iter() } @@ -2477,7 +2506,7 @@ fn split_packets(transmits: &[quinn_udp::Transmit]) -> RelayContents { /// Splits a packet into its component items. #[derive(Debug)] -pub(super) struct PacketSplitIter { +struct PacketSplitIter { bytes: Bytes, } @@ -2485,7 +2514,7 @@ impl PacketSplitIter { /// Create a new PacketSplitIter from a packet. /// /// Returns an error if the packet is too big. - pub fn new(bytes: Bytes) -> Self { + fn new(bytes: Bytes) -> Self { Self { bytes } } @@ -2581,8 +2610,133 @@ fn disco_message_sent(msg: &disco::Message) { } } +/// A *direct address* on which an iroh-node might be contactable. +/// +/// Direct addresses are UDP socket addresses on which an iroh-net node could potentially be +/// contacted. These can come from various sources depending on the network topology of the +/// iroh-net node, see [`DirectAddrType`] for the several kinds of sources. 
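+///
+/// As a sketch (using the fields defined just below), a NATed node with a
+/// STUN-discovered public address might advertise both of these:
+///
+/// ```ignore
+/// DirectAddr { addr: "192.168.1.7:51820".parse().unwrap(), typ: DirectAddrType::Local }
+/// DirectAddr { addr: "203.0.113.9:41852".parse().unwrap(), typ: DirectAddrType::Stun }
+/// ```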
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub struct DirectAddr {
+    /// The address.
+    pub addr: SocketAddr,
+    /// The origin of this direct address.
+    pub typ: DirectAddrType,
+}
+
+/// The type of direct address.
+///
+/// These are the various sources or origins from which an iroh-net node might have found a
+/// possible [`DirectAddr`].
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub enum DirectAddrType {
+    /// Not yet determined.
+    Unknown,
+    /// A locally bound socket address.
+    Local,
+    /// Public internet address discovered via STUN.
+    ///
+    /// When possible an iroh-net node will perform STUN to discover which is the address
+    /// from which it sends data on the public internet. This can be different from locally
+    /// bound addresses when the node is on a local network which performs NAT or similar.
+    Stun,
+    /// An address assigned by the router using port mapping.
+    ///
+    /// When possible an iroh-net node will request a port mapping from the local router to
+    /// get a publicly routable direct address.
+    Portmapped,
+    /// Hard NAT: STUN'ed IPv4 address + local fixed port.
+    ///
+    /// It is possible to configure iroh-net to bind to a specific port and independently
+    /// configure the router to forward this port to the iroh-net node. This variant
+    /// indicates such a situation, which still uses STUN to discover the public address.
+    Stun4LocalPort,
+}
+
+impl Display for DirectAddrType {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            DirectAddrType::Unknown => write!(f, "?"),
+            DirectAddrType::Local => write!(f, "local"),
+            DirectAddrType::Stun => write!(f, "stun"),
+            DirectAddrType::Portmapped => write!(f, "portmap"),
+            DirectAddrType::Stun4LocalPort => write!(f, "stun4localport"),
+        }
+    }
+}
+
+/// Contains information about the host's network state.
+#[derive(Debug, Clone, PartialEq)]
+struct NetInfo {
+    /// Says whether the host's NAT mappings vary based on the destination IP.
+    mapping_varies_by_dest_ip: Option<bool>,
+
+    /// If their router does hairpinning. It reports true even if there's no NAT involved.
+    hair_pinning: Option<bool>,
+
+    /// Whether the host has IPv6 internet connectivity.
+    working_ipv6: Option<bool>,
+
+    /// Whether the OS supports IPv6 at all, regardless of whether IPv6 internet connectivity is available.
+    os_has_ipv6: Option<bool>,
+
+    /// Whether the host has UDP internet connectivity.
+    working_udp: Option<bool>,
+
+    /// Whether ICMPv4 works, `None` means not checked.
+    working_icmp_v4: Option<bool>,
+
+    /// Whether ICMPv6 works, `None` means not checked.
+    working_icmp_v6: Option<bool>,
+
+    /// Whether we have an existing portmap open (UPnP, PMP, or PCP).
+    have_port_map: bool,
+
+    /// Probe indicating the presence of port mapping protocols on the LAN.
+    portmap_probe: Option<portmapper::ProbeOutput>,
+
+    /// This node's preferred relay server for incoming traffic.
+    ///
+    /// The node might be temporarily connected to multiple relay servers (to send to
+    /// other nodes) but this is the relay on which you can always contact this node. Also
+    /// known as the home relay.
+    preferred_relay: Option<RelayUrl>,
+
+    /// The fastest recent time to reach various relay STUN servers, in seconds.
+    ///
+    /// This should only be updated rarely, or when there's a
+    /// material change, as any change here also gets uploaded to the control plane.
+    relay_latency: BTreeMap<String, f64>,
+}
+
+impl NetInfo {
+    /// Checks if this is probably still the same network as *other*.
+ /// + /// This tries to compare the network situation, without taking into account things + /// expected to change a little like e.g. latency to the relay server. + fn basically_equal(&self, other: &Self) -> bool { + let eq_icmp_v4 = match (self.working_icmp_v4, other.working_icmp_v4) { + (Some(slf), Some(other)) => slf == other, + _ => true, // ignore for comparison if only one report had this info + }; + let eq_icmp_v6 = match (self.working_icmp_v6, other.working_icmp_v6) { + (Some(slf), Some(other)) => slf == other, + _ => true, // ignore for comparison if only one report had this info + }; + self.mapping_varies_by_dest_ip == other.mapping_varies_by_dest_ip + && self.hair_pinning == other.hair_pinning + && self.working_ipv6 == other.working_ipv6 + && self.os_has_ipv6 == other.os_has_ipv6 + && self.working_udp == other.working_udp + && eq_icmp_v4 + && eq_icmp_v6 + && self.have_port_map == other.have_port_map + && self.portmap_probe == other.portmap_probe + && self.preferred_relay == other.preferred_relay + } +} + #[cfg(test)] -pub(crate) mod tests { +mod tests { use anyhow::Context; use futures_lite::StreamExt; use iroh_test::CallOnDrop; @@ -2592,6 +2746,14 @@ pub(crate) mod tests { use super::*; + impl MagicSock { + #[track_caller] + pub fn add_test_addr(&self, node_addr: NodeAddr) { + self.add_node_addr(node_addr, Source::NamedApp { name: "test" }) + .unwrap() + } + } + /// Magicsock plus wrappers for sending packets #[derive(Clone)] struct MagicStack { @@ -2647,7 +2809,7 @@ pub(crate) mod tests { #[instrument(skip_all)] async fn mesh_stacks(stacks: Vec) -> Result { /// Registers endpoint addresses of a node to all other nodes. - fn update_eps(stacks: &[MagicStack], my_idx: usize, new_eps: Vec) { + fn update_direct_addrs(stacks: &[MagicStack], my_idx: usize, new_addrs: Vec) { let me = &stacks[my_idx]; for (i, m) in stacks.iter().enumerate() { if i == my_idx { @@ -2658,10 +2820,10 @@ pub(crate) mod tests { node_id: me.public(), info: crate::AddrInfo { relay_url: None, - direct_addresses: new_eps.iter().map(|ep| ep.addr).collect(), + direct_addresses: new_addrs.iter().map(|ep| ep.addr).collect(), }, }; - m.endpoint.magic_sock().add_node_addr(addr); + m.endpoint.magic_sock().add_test_addr(addr); } } @@ -2673,10 +2835,10 @@ pub(crate) mod tests { let stacks = stacks.clone(); tasks.spawn(async move { let me = m.endpoint.node_id().fmt_short(); - let mut stream = m.endpoint.local_endpoints(); + let mut stream = m.endpoint.direct_addresses(); while let Some(new_eps) = stream.next().await { info!(%me, "conn{} endpoints update: {:?}", my_idx + 1, new_eps); - update_eps(&stacks, my_idx, new_eps); + update_direct_addrs(&stacks, my_idx, new_eps); } }); } @@ -3342,13 +3504,13 @@ pub(crate) mod tests { let ms = Handle::new(Default::default()).await.unwrap(); // See if we can get endpoints. - let mut eps0 = ms.local_endpoints().next().await.unwrap(); + let mut eps0 = ms.direct_addresses().next().await.unwrap(); eps0.sort(); println!("{eps0:?}"); assert!(!eps0.is_empty()); // Getting the endpoints again immediately should give the same results. 
- let mut eps1 = ms.local_endpoints().next().await.unwrap(); + let mut eps1 = ms.direct_addresses().next().await.unwrap(); eps1.sort(); println!("{eps1:?}"); assert_eq!(eps0, eps1); diff --git a/iroh-net/src/magicsock/node_map.rs b/iroh-net/src/magicsock/node_map.rs index c17cfccaeb..89037450b0 100644 --- a/iroh-net/src/magicsock/node_map.rs +++ b/iroh-net/src/magicsock/node_map.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{hash_map::Entry, HashMap}, hash::Hash, net::{IpAddr, SocketAddr}, path::Path, @@ -17,7 +17,10 @@ use stun_rs::TransactionId; use tokio::io::AsyncWriteExt; use tracing::{debug, info, instrument, trace, warn}; -use self::node_state::{NodeState, Options, PingHandled}; +use self::{ + best_addr::ClearReason, + node_state::{NodeState, Options, PingHandled}, +}; use super::{ metrics::Metrics as MagicsockMetrics, ActorMessage, DiscoMessageSource, QuicMappedAddr, }; @@ -74,11 +77,29 @@ pub(super) struct NodeMapInner { /// You can look up entries in [`NodeMap`] with various keys, depending on the context you /// have for the node. These are all the keys the [`NodeMap`] can use. #[derive(Clone)] -enum NodeStateKey<'a> { - Idx(&'a usize), - NodeId(&'a NodeId), - QuicMappedAddr(&'a QuicMappedAddr), - IpPort(&'a IpPort), +enum NodeStateKey { + Idx(usize), + NodeId(NodeId), + QuicMappedAddr(QuicMappedAddr), + IpPort(IpPort), +} + +/// Source for a new node. +/// +/// This is used for debugging purposes. +#[derive(strum::Display, Debug)] +#[strum(serialize_all = "kebab-case")] +pub(crate) enum Source { + /// Node was loaded from the fs. + Saved, + /// Node communicated with us first via UDP. + Udp, + /// Node communicated with us first via relay. + Relay, + /// Application layer added the node directly. + App, + #[strum(serialize = "{name}")] + NamedApp { name: &'static str }, } impl NodeMap { @@ -99,8 +120,8 @@ impl NodeMap { } /// Add the contact information for a node. - pub(super) fn add_node_addr(&self, node_addr: NodeAddr) { - self.inner.lock().add_node_addr(node_addr) + pub(super) fn add_node_addr(&self, node_addr: NodeAddr, source: Source) { + self.inner.lock().add_node_addr(node_addr, source) } /// Number of nodes currently listed. 
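At a call site the new provenance argument reads as below (a sketch using names from
this diff; `node_map` and `addr` are assumed to be in scope, mirroring the
`add_test_addr` helper added further down):

    // Contact info handed over by the application layer:
    node_map.add_node_addr(addr, Source::App);
    // Or with a label; `NamedApp` displays as the label itself:
    node_map.add_node_addr(addr, Source::NamedApp { name: "cli" });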
@@ -112,8 +133,8 @@ impl NodeMap { self.inner.lock().receive_udp(udp_addr) } - pub(super) fn receive_relay(&self, relay_url: &RelayUrl, src: PublicKey) -> QuicMappedAddr { - self.inner.lock().receive_relay(relay_url, &src) + pub(super) fn receive_relay(&self, relay_url: &RelayUrl, src: NodeId) -> QuicMappedAddr { + self.inner.lock().receive_relay(relay_url, src) } pub(super) fn notify_ping_sent( @@ -124,20 +145,20 @@ impl NodeMap { purpose: DiscoPingPurpose, msg_sender: tokio::sync::mpsc::Sender, ) { - if let Some(ep) = self.inner.lock().get_mut(NodeStateKey::Idx(&id)) { + if let Some(ep) = self.inner.lock().get_mut(NodeStateKey::Idx(id)) { ep.ping_sent(dst, tx_id, purpose, msg_sender); } } pub(super) fn notify_ping_timeout(&self, id: usize, tx_id: stun::TransactionId) { - if let Some(ep) = self.inner.lock().get_mut(NodeStateKey::Idx(&id)) { + if let Some(ep) = self.inner.lock().get_mut(NodeStateKey::Idx(id)) { ep.ping_timeout(tx_id); } } pub(super) fn get_quic_mapped_addr_for_node_key( &self, - node_key: &PublicKey, + node_key: NodeId, ) -> Option { self.inner .lock() @@ -172,7 +193,7 @@ impl NodeMap { #[allow(clippy::type_complexity)] pub(super) fn get_send_addrs( &self, - addr: &QuicMappedAddr, + addr: QuicMappedAddr, have_ipv6: bool, ) -> Option<( PublicKey, @@ -223,16 +244,13 @@ impl NodeMap { /// /// Will return an error if there is not an entry in the [`NodeMap`] for /// the `public_key` - pub(super) fn conn_type_stream( - &self, - public_key: &PublicKey, - ) -> anyhow::Result { - self.inner.lock().conn_type_stream(public_key) + pub(super) fn conn_type_stream(&self, node_id: NodeId) -> anyhow::Result { + self.inner.lock().conn_type_stream(node_id) } /// Get the [`NodeInfo`]s for each endpoint - pub(super) fn node_info(&self, public_key: &PublicKey) -> Option { - self.inner.lock().node_info(public_key) + pub(super) fn node_info(&self, node_id: NodeId) -> Option { + self.inner.lock().node_info(node_id) } /// Saves the known node info to the given path, returning the number of nodes persisted. @@ -289,6 +307,13 @@ impl NodeMap { pub(super) fn prune_inactive(&self) { self.inner.lock().prune_inactive(); } + + pub(crate) fn on_direct_addr_discovered( + &self, + discovered: impl Iterator>, + ) { + self.inner.lock().on_direct_addr_discovered(discovered); + } } impl NodeMapInner { @@ -312,7 +337,7 @@ impl NodeMapInner { while !slice.is_empty() { let (node_addr, next_contents) = postcard::take_from_bytes(slice).context("failed to load node data")?; - me.add_node_addr(node_addr); + me.add_node_addr(node_addr, Source::Saved); slice = next_contents; } Ok(me) @@ -320,13 +345,14 @@ impl NodeMapInner { /// Add the contact information for a node. #[instrument(skip_all, fields(node = %node_addr.node_id.fmt_short()))] - fn add_node_addr(&mut self, node_addr: NodeAddr) { + fn add_node_addr(&mut self, node_addr: NodeAddr, source: Source) { let NodeAddr { node_id, info } = node_addr; - let node_state = self.get_or_insert_with(NodeStateKey::NodeId(&node_id), || Options { + let node_state = self.get_or_insert_with(NodeStateKey::NodeId(node_id), || Options { node_id, relay_url: info.relay_url.clone(), active: false, + source, }); node_state.update_from_node_addr(&info); @@ -336,12 +362,40 @@ impl NodeMapInner { } } + /// Prunes direct addresses from nodes that claim to share an address we know points to us. 
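+    ///
+    /// This is defensive: a peer, or stale state loaded from disk, may list one of
+    /// *our* direct addresses for some node; sending there would only loop traffic
+    /// back to our own socket.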
+ pub(super) fn on_direct_addr_discovered( + &mut self, + discovered: impl Iterator>, + ) { + for addr in discovered { + self.remove_by_ipp(addr.into(), ClearReason::MatchesOurLocalAddr) + } + } + + /// Removes a direct address from a node. + fn remove_by_ipp(&mut self, ipp: IpPort, reason: ClearReason) { + if let Some(id) = self.by_ip_port.remove(&ipp) { + if let Entry::Occupied(mut entry) = self.by_id.entry(id) { + let node = entry.get_mut(); + node.remove_direct_addr(&ipp, reason); + if node.direct_addresses().count() == 0 { + let node_id = node.public_key(); + let mapped_addr = node.quic_mapped_addr(); + self.by_node_key.remove(node_id); + self.by_quic_mapped_addr.remove(mapped_addr); + debug!(node_id=%node_id.fmt_short(), ?reason, "removing node"); + entry.remove(); + } + } + } + } + fn get_id(&self, id: NodeStateKey) -> Option { match id { - NodeStateKey::Idx(id) => Some(*id), - NodeStateKey::NodeId(node_key) => self.by_node_key.get(node_key).copied(), - NodeStateKey::QuicMappedAddr(addr) => self.by_quic_mapped_addr.get(addr).copied(), - NodeStateKey::IpPort(ipp) => self.by_ip_port.get(ipp).copied(), + NodeStateKey::Idx(id) => Some(id), + NodeStateKey::NodeId(node_key) => self.by_node_key.get(&node_key).copied(), + NodeStateKey::QuicMappedAddr(addr) => self.by_quic_mapped_addr.get(&addr).copied(), + NodeStateKey::IpPort(ipp) => self.by_ip_port.get(&ipp).copied(), } } @@ -373,7 +427,7 @@ impl NodeMapInner { /// Marks the node we believe to be at `ipp` as recently used. fn receive_udp(&mut self, udp_addr: SocketAddr) -> Option<(NodeId, QuicMappedAddr)> { let ip_port: IpPort = udp_addr.into(); - let Some(node_state) = self.get_mut(NodeStateKey::IpPort(&ip_port)) else { + let Some(node_state) = self.get_mut(NodeStateKey::IpPort(ip_port)) else { info!(src=%udp_addr, "receive_udp: no node_state found for addr, ignore"); return None; }; @@ -382,13 +436,14 @@ impl NodeMapInner { } #[instrument(skip_all, fields(src = %src.fmt_short()))] - fn receive_relay(&mut self, relay_url: &RelayUrl, src: &PublicKey) -> QuicMappedAddr { + fn receive_relay(&mut self, relay_url: &RelayUrl, src: NodeId) -> QuicMappedAddr { let node_state = self.get_or_insert_with(NodeStateKey::NodeId(src), || { trace!("packets from unknown node, insert into node map"); Options { - node_id: *src, + node_id: src, relay_url: Some(relay_url.clone()), active: true, + source: Source::Relay, } }); node_state.receive_relay(relay_url, src, Instant::now()); @@ -409,8 +464,8 @@ impl NodeMapInner { } /// Get the [`NodeInfo`]s for each endpoint - fn node_info(&self, public_key: &PublicKey) -> Option { - self.get(NodeStateKey::NodeId(public_key)) + fn node_info(&self, node_id: NodeId) -> Option { + self.get(NodeStateKey::NodeId(node_id)) .map(|ep| ep.info(Instant::now())) } @@ -423,18 +478,18 @@ impl NodeMapInner { /// /// Will return an error if there is not an entry in the [`NodeMap`] for /// the `public_key` - fn conn_type_stream(&self, public_key: &PublicKey) -> anyhow::Result { - match self.get(NodeStateKey::NodeId(public_key)) { + fn conn_type_stream(&self, node_id: NodeId) -> anyhow::Result { + match self.get(NodeStateKey::NodeId(node_id)) { Some(ep) => Ok(ConnectionTypeStream { initial: Some(ep.conn_type()), inner: ep.conn_type_stream(), }), - None => anyhow::bail!("No endpoint for {public_key:?} found"), + None => anyhow::bail!("No endpoint for {node_id:?} found"), } } - fn handle_pong(&mut self, sender: PublicKey, src: &DiscoMessageSource, pong: Pong) { - if let Some(ns) = self.get_mut(NodeStateKey::NodeId(&sender)).as_mut() { + 
fn handle_pong(&mut self, sender: NodeId, src: &DiscoMessageSource, pong: Pong) { + if let Some(ns) = self.get_mut(NodeStateKey::NodeId(sender)).as_mut() { let insert = ns.handle_pong(&pong, src.into()); if let Some((src, key)) = insert { self.set_node_key_for_ip_port(src, &key); @@ -446,8 +501,8 @@ impl NodeMapInner { } #[must_use = "actions must be handled"] - fn handle_call_me_maybe(&mut self, sender: PublicKey, cm: CallMeMaybe) -> Vec { - let ns_id = NodeStateKey::NodeId(&sender); + fn handle_call_me_maybe(&mut self, sender: NodeId, cm: CallMeMaybe) -> Vec { + let ns_id = NodeStateKey::NodeId(sender); if let Some(id) = self.get_id(ns_id.clone()) { for number in &cm.my_numbers { // ensure the new addrs are known @@ -468,18 +523,19 @@ impl NodeMapInner { } } - fn handle_ping( - &mut self, - sender: PublicKey, - src: SendAddr, - tx_id: TransactionId, - ) -> PingHandled { - let node_state = self.get_or_insert_with(NodeStateKey::NodeId(&sender), || { + fn handle_ping(&mut self, sender: NodeId, src: SendAddr, tx_id: TransactionId) -> PingHandled { + let node_state = self.get_or_insert_with(NodeStateKey::NodeId(sender), || { debug!("received ping: node unknown, add to node map"); + let source = if src.is_relay() { + Source::Relay + } else { + Source::Udp + }; Options { node_id: sender, relay_url: src.relay_url(), active: true, + source, } }); @@ -497,6 +553,7 @@ impl NodeMapInner { info!( node = %options.node_id.fmt_short(), relay_url = ?options.relay_url, + source = %options.source, "inserting new node in NodeMap", ); let id = self.next_id; @@ -644,6 +701,13 @@ mod tests { use crate::{endpoint::AddrInfo, key::SecretKey}; use std::net::Ipv4Addr; + impl NodeMap { + #[track_caller] + fn add_test_addr(&self, node_addr: NodeAddr) { + self.add_node_addr(node_addr, Source::NamedApp { name: "test" }) + } + } + /// Test persisting and loading of known nodes. 
#[tokio::test] async fn load_save_node_data() { @@ -669,10 +733,10 @@ mod tests { let node_addr_c = NodeAddr::new(node_c).with_direct_addresses(direct_addresses_c); let node_addr_d = NodeAddr::new(node_d); - node_map.add_node_addr(node_addr_a); - node_map.add_node_addr(node_addr_b); - node_map.add_node_addr(node_addr_c); - node_map.add_node_addr(node_addr_d); + node_map.add_test_addr(node_addr_a); + node_map.add_test_addr(node_addr_b); + node_map.add_test_addr(node_addr_c); + node_map.add_test_addr(node_addr_d); let root = testdir::testdir!(); let path = root.join("nodes.postcard"); @@ -705,7 +769,7 @@ mod tests { let node_addr_a = NodeAddr::new(node_a).with_direct_addresses(direct_addrs_a); let node_map = NodeMap::default(); - node_map.add_node_addr(node_addr_a.clone()); + node_map.add_test_addr(node_addr_a.clone()); // unused endpoints are included let list = node_map.node_addresses_for_storage(); @@ -738,6 +802,7 @@ mod tests { node_id: public_key, relay_url: None, active: false, + source: Source::NamedApp { name: "test" }, }) .id(); @@ -751,7 +816,7 @@ mod tests { let addr = SocketAddr::new(LOCALHOST, 5000 + i as u16); let node_addr = NodeAddr::new(public_key).with_direct_addresses([addr]); // add address - node_map.add_node_addr(node_addr); + node_map.add_test_addr(node_addr); // make it active node_map.inner.lock().receive_udp(addr); } @@ -760,7 +825,7 @@ mod tests { for i in 0..MAX_INACTIVE_DIRECT_ADDRESSES * 2 { let addr = SocketAddr::new(LOCALHOST, 6000 + i as u16); let node_addr = NodeAddr::new(public_key).with_direct_addresses([addr]); - node_map.add_node_addr(node_addr); + node_map.add_test_addr(node_addr); } let mut node_map_inner = node_map.inner.lock(); @@ -801,12 +866,12 @@ mod tests { // add one active node and more than MAX_INACTIVE_NODES inactive nodes let active_node = SecretKey::generate().public(); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 167); - node_map.add_node_addr(NodeAddr::new(active_node).with_direct_addresses([addr])); + node_map.add_test_addr(NodeAddr::new(active_node).with_direct_addresses([addr])); node_map.inner.lock().receive_udp(addr).expect("registered"); for _ in 0..MAX_INACTIVE_NODES + 1 { let node = SecretKey::generate().public(); - node_map.add_node_addr(NodeAddr::new(node)); + node_map.add_test_addr(NodeAddr::new(node)); } assert_eq!(node_map.node_count(), MAX_INACTIVE_NODES + 2); @@ -815,7 +880,7 @@ mod tests { node_map .inner .lock() - .get(NodeStateKey::NodeId(&active_node)) + .get(NodeStateKey::NodeId(active_node)) .expect("should not be pruned"); } } diff --git a/iroh-net/src/magicsock/node_map/best_addr.rs b/iroh-net/src/magicsock/node_map/best_addr.rs index 6378108708..95b47b361f 100644 --- a/iroh-net/src/magicsock/node_map/best_addr.rs +++ b/iroh-net/src/magicsock/node_map/best_addr.rs @@ -60,11 +60,12 @@ pub(super) enum State<'a> { Empty, } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum ClearReason { Reset, Inactive, PongTimeout, + MatchesOurLocalAddr, } impl BestAddr { diff --git a/iroh-net/src/magicsock/node_map/node_state.rs b/iroh-net/src/magicsock/node_map/node_state.rs index c72b1118bb..6883c727b2 100644 --- a/iroh-net/src/magicsock/node_map/node_state.rs +++ b/iroh-net/src/magicsock/node_map/node_state.rs @@ -140,12 +140,14 @@ pub(super) struct NodeState { conn_type: Watchable, } +/// Options for creating a new [`NodeState`]. #[derive(Debug)] pub(super) struct Options { pub(super) node_id: NodeId, pub(super) relay_url: Option, /// Is this endpoint currently active (sending data)? 
     pub(super) active: bool,
+    pub(super) source: super::Source,
 }
 
 impl NodeState {
@@ -309,7 +311,24 @@ impl NodeState {
         (best_addr, relay_url)
     }
 
-    /// Fixup best_addr from candidates.
+    /// Removes a direct address for this node.
+    ///
+    /// If this is also the best address, it will be cleared as well.
+    pub(super) fn remove_direct_addr(&mut self, ip_port: &IpPort, reason: ClearReason) {
+        let Some(state) = self.direct_addr_state.remove(ip_port) else {
+            return;
+        };
+
+        match state.last_alive().map(|instant| instant.elapsed()) {
+            Some(last_alive) => debug!(%ip_port, ?last_alive, ?reason, "pruning address"),
+            None => debug!(%ip_port, last_seen=%"never", ?reason, "pruning address"),
+        }
+
+        self.best_addr
+            .clear_if_equals((*ip_port).into(), reason, self.relay_url.is_some());
+    }
+
+    /// Fixup best_addr from candidates.
     ///
     /// If somehow we end up in a state where we failed to set a best_addr, while we do have
     /// valid candidates, this will choose a candidate and set best_addr again. Most likely
@@ -760,19 +779,8 @@ impl NodeState {
         // used ones) last
         prune_candidates.sort_unstable_by_key(|(_ip_port, last_alive)| *last_alive);
         prune_candidates.truncate(prune_count);
-        for (ip_port, last_alive) in prune_candidates.into_iter() {
-            self.direct_addr_state.remove(&ip_port);
-
-            match last_alive.map(|instant| instant.elapsed()) {
-                Some(last_alive) => debug!(%ip_port, ?last_alive, "pruning address"),
-                None => debug!(%ip_port, last_seen=%"never", "pruning address"),
-            }
-
-            self.best_addr.clear_if_equals(
-                ip_port.into(),
-                ClearReason::Inactive,
-                self.relay_url.is_some(),
-            );
+        for (ip_port, _last_alive) in prune_candidates.into_iter() {
+            self.remove_direct_addr(&ip_port, ClearReason::Inactive)
         }
         debug!(
             paths = %summarize_node_paths(&self.direct_addr_state),
@@ -962,7 +970,7 @@ impl NodeState {
             .reconfirm_if_used(addr.into(), Source::Udp, now);
     }
 
-    pub(super) fn receive_relay(&mut self, url: &RelayUrl, _src: &PublicKey, now: Instant) {
+    pub(super) fn receive_relay(&mut self, url: &RelayUrl, _src: NodeId, now: Instant) {
         match self.relay_url.as_mut() {
             Some((current_home, state)) if current_home == url => {
                 // We received on the expected url. update state.
@@ -1719,6 +1727,7 @@ mod tests {
             node_id: key.public(),
             relay_url: None,
             active: true,
+            source: crate::magicsock::Source::NamedApp { name: "test" },
         };
         let mut ep = NodeState::new(0, opts);
diff --git a/iroh-net/src/netcheck.rs b/iroh-net/src/netcheck.rs
index 062368c1c8..391e174202 100644
--- a/iroh-net/src/netcheck.rs
+++ b/iroh-net/src/netcheck.rs
@@ -785,7 +785,7 @@ mod tests {
     use tokio::time;
     use tracing::info;
 
-    use crate::defaults::{DEFAULT_RELAY_STUN_PORT, EU_RELAY_HOSTNAME};
+    use crate::defaults::{DEFAULT_STUN_PORT, EU_RELAY_HOSTNAME};
     use crate::ping::Pinger;
     use crate::relay::RelayNode;
 
@@ -795,11 +795,11 @@ mod tests {
     async fn test_basic() -> Result<()> {
         let _guard = iroh_test::logging::setup();
         let (stun_addr, stun_stats, _cleanup_guard) =
-            stun::test::serve("0.0.0.0".parse().unwrap()).await?;
+            stun::tests::serve("127.0.0.1".parse().unwrap()).await?;
 
         let resolver = crate::dns::default_resolver();
         let mut client = Client::new(None, resolver.clone())?;
 
-        let dm = stun::test::relay_map_of([stun_addr].into_iter());
+        let dm = stun::tests::relay_map_of([stun_addr].into_iter());
 
         // Note that the ProbePlan will change with each iteration.
for i in 0..5 { @@ -842,7 +842,7 @@ mod tests { let dm = RelayMap::from_nodes([RelayNode { url: url.clone(), stun_only: true, - stun_port: DEFAULT_RELAY_STUN_PORT, + stun_port: DEFAULT_STUN_PORT, }]) .expect("hardcoded"); @@ -890,7 +890,7 @@ mod tests { // the STUN server being blocked will look like from the client's perspective. let blackhole = tokio::net::UdpSocket::bind("127.0.0.1:0").await?; let stun_addr = blackhole.local_addr()?; - let dm = stun::test::relay_map_of_opts([(stun_addr, false)].into_iter()); + let dm = stun::tests::relay_map_of_opts([(stun_addr, false)].into_iter()); // Now create a client and generate a report. let resolver = crate::dns::default_resolver().clone(); @@ -1127,8 +1127,8 @@ mod tests { // can easily use to identify the packet. // Setup STUN server and create relay_map. - let (stun_addr, _stun_stats, _done) = stun::test::serve_v4().await?; - let dm = stun::test::relay_map_of([stun_addr].into_iter()); + let (stun_addr, _stun_stats, _done) = stun::tests::serve_v4().await?; + let dm = stun::tests::relay_map_of([stun_addr].into_iter()); dbg!(&dm); let resolver = crate::dns::default_resolver().clone(); diff --git a/iroh-net/src/netcheck/reportgen.rs b/iroh-net/src/netcheck/reportgen.rs index 665ea66e54..b683d500a1 100644 --- a/iroh-net/src/netcheck/reportgen.rs +++ b/iroh-net/src/netcheck/reportgen.rs @@ -31,7 +31,7 @@ use tokio::time::{self, Instant}; use tracing::{debug, debug_span, error, info_span, trace, warn, Instrument, Span}; use super::NetcheckMetrics; -use crate::defaults::DEFAULT_RELAY_STUN_PORT; +use crate::defaults::DEFAULT_STUN_PORT; use crate::dns::{DnsResolver, ResolverExt}; use crate::net::interfaces; use crate::net::ip; @@ -935,7 +935,7 @@ async fn get_relay_addr( proto: ProbeProto, ) -> Result { let port = if relay_node.stun_port == 0 { - DEFAULT_RELAY_STUN_PORT + DEFAULT_STUN_PORT } else { relay_node.stun_port }; diff --git a/iroh-net/src/relay.rs b/iroh-net/src/relay.rs index 13dc332f75..88213f0635 100644 --- a/iroh-net/src/relay.rs +++ b/iroh-net/src/relay.rs @@ -15,6 +15,7 @@ pub(crate) mod client_conn; pub(crate) mod clients; mod codec; pub mod http; +pub mod iroh_relay; mod map; mod metrics; pub(crate) mod server; diff --git a/iroh-net/src/relay/http.rs b/iroh-net/src/relay/http.rs index e73da2de73..cd3d7519bf 100644 --- a/iroh-net/src/relay/http.rs +++ b/iroh-net/src/relay/http.rs @@ -6,32 +6,10 @@ mod server; pub(crate) mod streams; pub use self::client::{Client, ClientBuilder, ClientError, ClientReceiver}; -pub use self::server::{Server, ServerBuilder, TlsAcceptor, TlsConfig}; +pub use self::server::{Server, ServerBuilder, ServerHandle, TlsAcceptor, TlsConfig}; pub(crate) const HTTP_UPGRADE_PROTOCOL: &str = "iroh derp http"; -#[cfg(any(test, feature = "test-utils"))] -pub(crate) fn make_tls_config() -> TlsConfig { - let subject_alt_names = vec!["localhost".to_string()]; - - let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap(); - let rustls_certificate = rustls::Certificate(cert.serialize_der().unwrap()); - let rustls_key = rustls::PrivateKey(cert.get_key_pair().serialize_der()); - let config = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(vec![(rustls_certificate)], rustls_key) - .unwrap(); - - let config = std::sync::Arc::new(config); - let acceptor = tokio_rustls::TlsAcceptor::from(config.clone()); - - TlsConfig { - config, - acceptor: TlsAcceptor::Manual(acceptor), - } -} - #[cfg(test)] mod tests { use super::*; @@ -47,6 +25,27 @@ mod tests { use 
crate::key::{PublicKey, SecretKey}; use crate::relay::ReceivedMessage; + pub(crate) fn make_tls_config() -> TlsConfig { + let subject_alt_names = vec!["localhost".to_string()]; + + let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap(); + let rustls_certificate = rustls::Certificate(cert.serialize_der().unwrap()); + let rustls_key = rustls::PrivateKey(cert.get_key_pair().serialize_der()); + let config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![(rustls_certificate)], rustls_key) + .unwrap(); + + let config = std::sync::Arc::new(config); + let acceptor = tokio_rustls::TlsAcceptor::from(config.clone()); + + TlsConfig { + config, + acceptor: TlsAcceptor::Manual(acceptor), + } + } + #[tokio::test] async fn test_http_clients_and_server() -> Result<()> { let _guard = iroh_test::logging::setup(); @@ -115,7 +114,7 @@ mod tests { client_a_task.abort(); client_b.close().await?; client_b_task.abort(); - server.shutdown().await; + server.shutdown(); Ok(()) } @@ -186,7 +185,7 @@ mod tests { let tls_config = make_tls_config(); // start server - let server = ServerBuilder::new("127.0.0.1:0".parse().unwrap()) + let mut server = ServerBuilder::new("127.0.0.1:0".parse().unwrap()) .secret_key(Some(server_key)) .tls_config(Some(tls_config)) .spawn() @@ -232,7 +231,8 @@ mod tests { assert_eq!(b_key, got_key); assert_eq!(msg, got_msg); - server.shutdown().await; + server.shutdown(); + server.task_handle().await?; client_a.close().await?; client_a_task.abort(); client_b.close().await?; diff --git a/iroh-net/src/relay/http/server.rs b/iroh-net/src/relay/http/server.rs index a102458a3e..eaf6ffd70a 100644 --- a/iroh-net/src/relay/http/server.rs +++ b/iroh-net/src/relay/http/server.rs @@ -18,7 +18,7 @@ use tokio::net::{TcpListener, TcpStream}; use tokio::task::JoinHandle; use tokio_rustls_acme::AcmeAcceptor; use tokio_util::sync::CancellationToken; -use tracing::{debug, error, info, info_span, warn, Instrument}; +use tracing::{debug, error, info, info_span, Instrument}; use crate::key::SecretKey; use crate::relay::http::HTTP_UPGRADE_PROTOCOL; @@ -70,30 +70,68 @@ async fn relay_connection_handler( conn_handler.accept(io).await } -/// A Relay Server handler. Created using [`ServerBuilder::spawn`], it starts a relay server -/// listening over HTTP or HTTPS. +/// The Relay HTTP server. +/// +/// A running HTTP server serving the relay endpoint and optionally a number of additional +/// HTTP services added with [`ServerBuilder::request_handler`]. If configured using +/// [`ServerBuilder::tls_config`] the server will handle TLS as well. +/// +/// Created using [`ServerBuilder::spawn`]. #[derive(Debug)] pub struct Server { addr: SocketAddr, - server: Option, http_server_task: JoinHandle<()>, cancel_server_loop: CancellationToken, } impl Server { - /// Close the underlying relay server and the HTTP(S) server task - pub async fn shutdown(self) { - if let Some(server) = self.server { - server.close().await; + /// Returns a handle for this server. + /// + /// The server runs in the background as several async tasks. This allows controlling + /// the server, in particular it allows gracefully shutting down the server. + pub fn handle(&self) -> ServerHandle { + ServerHandle { + addr: self.addr, + cancel_token: self.cancel_server_loop.clone(), } + } + /// Closes the underlying relay server and the HTTP(S) server tasks. 
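+    ///
+    /// Shutdown only signals the tasks to stop; to wait for completion, await the
+    /// task handle, as sketched here (method names from this diff, error handling
+    /// elided):
+    ///
+    /// ```ignore
+    /// server.shutdown();
+    /// server.task_handle().await?; // resolves once all server tasks finished
+    /// ```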
+ pub fn shutdown(&self) { self.cancel_server_loop.cancel(); - if let Err(e) = self.http_server_task.await { - warn!("Error shutting down server: {e:?}"); - } } - /// Get the local address of this server. + /// Returns the [`JoinHandle`] for the supervisor task managing the server. + /// + /// This is the root of all the tasks for the server. Aborting it will abort all the + /// other tasks for the server. Awaiting it will complete when all the server tasks are + /// completed. + pub fn task_handle(&mut self) -> &mut JoinHandle<()> { + &mut self.http_server_task + } + + /// Returns the local address of this server. + pub fn addr(&self) -> SocketAddr { + self.addr + } +} + +/// A handle for the [`Server`]. +/// +/// This does not allow access to the task but can communicate with it. +#[derive(Debug, Clone)] +pub struct ServerHandle { + addr: SocketAddr, + cancel_token: CancellationToken, +} + +impl ServerHandle { + /// Gracefully shut down the server. + pub fn shutdown(&self) { + self.cancel_token.cancel() + } + + /// Returns the address the server is bound on. pub fn addr(&self) -> SocketAddr { self.addr } @@ -108,13 +146,15 @@ pub struct TlsConfig { pub acceptor: TlsAcceptor, } -/// Build a Relay Server that communicates over HTTP or HTTPS, on a given address. +/// Builder for the Relay HTTP Server. /// -/// Defaults to handling relay requests on the "/derp" endpoint. +/// Defaults to handling relay requests on the "/derp" endpoint. Other HTTP endpoints can +/// be added using [`ServerBuilder::request_handler`]. /// -/// If no [`SecretKey`] is provided, it is assumed that you will provide a `relay_override` function -/// that handles requests to the relay endpoint. Not providing a `relay_override` in this case will -/// result in an error on `spawn`. +/// If no [`SecretKey`] is provided, it is assumed that you will provide a +/// [`ServerBuilder::relay_override`] function that handles requests to the relay +/// endpoint. Not providing a [`ServerBuilder::relay_override`] in this case will result in +/// an error on `spawn`. #[derive(derive_more::Debug)] pub struct ServerBuilder { /// The secret key for this Server. @@ -128,18 +168,21 @@ pub struct ServerBuilder { /// /// When `None`, the server will serve HTTP, otherwise it will serve HTTPS. tls_config: Option, - /// A map of request handlers to routes. Used when certain routes in your server should be made - /// available at the same port as the relay server, and so must be handled along side requests - /// to the relay endpoint. + /// A map of request handlers to routes. + /// + /// Used when certain routes in your server should be made available at the same port as + /// the relay server, and so must be handled along side requests to the relay endpoint. handlers: Handlers, /// Defaults to `GET` request at "/derp". relay_endpoint: &'static str, - /// Use a custom relay response handler. Typically used when you want to disable any relay connections. + /// Use a custom relay response handler. + /// + /// Typically used when you want to disable any relay connections. #[debug("{}", relay_override.as_ref().map_or("None", |_| "Some(Box, ResponseBuilder) -> Result + Send + Sync + 'static>)"))] relay_override: Option, - /// Headers to use for HTTP or HTTPS messages. + /// Headers to use for HTTP responses. headers: HeaderMap, - /// 404 not found response + /// 404 not found response. /// /// When `None`, a default is provided. 
#[debug("{}", not_found_fn.as_ref().map_or("None", |_| "Some(Box Result> + Send + Sync + 'static>)"))] @@ -147,7 +190,7 @@ pub struct ServerBuilder { } impl ServerBuilder { - /// Create a new [ServerBuilder] + /// Creates a new [ServerBuilder]. pub fn new(addr: SocketAddr) -> Self { Self { secret_key: None, @@ -161,20 +204,21 @@ impl ServerBuilder { } } - /// The [`SecretKey`] identity for this relay server. When set to `None`, the builder assumes - /// you do not want to run a relay service. + /// The [`SecretKey`] identity for this relay server. + /// + /// When set to `None`, the builder assumes you do not want to run a relay service. pub fn secret_key(mut self, secret_key: Option) -> Self { self.secret_key = secret_key; self } - /// Serve relay content using TLS. + /// Serves all requests content using TLS. pub fn tls_config(mut self, config: Option) -> Self { self.tls_config = config; self } - /// Add a custom handler for a specific Method & URI. + /// Adds a custom handler for a specific Method & URI. pub fn request_handler( mut self, method: Method, @@ -185,26 +229,29 @@ impl ServerBuilder { self } - /// Pass in a custom "404" handler. + /// Sets a custom "404" handler. pub fn not_found_handler(mut self, handler: HyperHandler) -> Self { self.not_found_fn = Some(handler); self } - /// Handle the relay endpoint in a custom way. This is required if no [`SecretKey`] was provided - /// to the builder. + /// Handles the relay endpoint in a custom way. + /// + /// This is required if no [`SecretKey`] was provided to the builder. pub fn relay_override(mut self, handler: HyperHandler) -> Self { self.relay_override = Some(handler); self } - /// Change the relay endpoint from "/derp" to `endpoint`. + /// Sets a custom endpoint for the relay handler. + /// + /// The default is `/derp`. pub fn relay_endpoint(mut self, endpoint: &'static str) -> Self { self.relay_endpoint = endpoint; self } - /// Add http headers. + /// Adds HTTP headers to responses. pub fn headers(mut self, headers: HeaderMap) -> Self { for (k, v) in headers.iter() { self.headers.insert(k.clone(), v.clone()); @@ -212,10 +259,14 @@ impl ServerBuilder { self } - /// Build and spawn an HTTP(S) relay Server + /// Builds and spawns an HTTP(S) Relay Server. pub async fn spawn(self) -> Result { - ensure!(self.secret_key.is_some() || self.relay_override.is_some(), "Must provide a `SecretKey` for the relay server OR pass in an override function for the 'relay' endpoint"); + ensure!( + self.secret_key.is_some() || self.relay_override.is_some(), + "Must provide a `SecretKey` for the relay server OR pass in an override function for the 'relay' endpoint" + ); let (relay_handler, relay_server) = if let Some(secret_key) = self.secret_key { + // spawns a server actor/task let server = crate::relay::server::Server::new(secret_key.clone()); ( RelayHandler::ConnHandler(server.client_conn_handler(self.headers.clone())), @@ -258,6 +309,7 @@ impl ServerBuilder { service, }; + // Spawns some server tasks, we only wait till all tasks are started. server_state.serve().await } } @@ -274,13 +326,19 @@ impl ServerState { // Binds a TCP listener on `addr` and handles content using HTTPS. // Returns the local [`SocketAddr`] on which the server is listening. 
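    // For the `ServerBuilder` above, a hypothetical call-site sketch (secret key
    // handling and the handler body are elided; `request_handler` is used the same
    // way elsewhere in this diff):
    //
    //     let server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
    //         .secret_key(Some(secret_key))
    //         .request_handler(Method::GET, "/health", Box::new(health_handler))
    //         .spawn()
    //         .await?;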
async fn serve(self) -> Result { - let listener = TcpListener::bind(&self.addr) + let ServerState { + addr, + tls_config, + server, + service, + } = self; + let listener = TcpListener::bind(&addr) .await - .context("failed to bind https")?; + .context("failed to bind server socket")?; // we will use this cancel token to stop the infinite loop in the `listener.accept() task` let cancel_server_loop = CancellationToken::new(); let addr = listener.local_addr()?; - let http_str = self.tls_config.as_ref().map_or("HTTP", |_| "HTTPS"); + let http_str = tls_config.as_ref().map_or("HTTP", |_| "HTTPS"); info!("[{http_str}] relay: serving on {addr}"); let cancel = cancel_server_loop.clone(); let task = tokio::task::spawn(async move { @@ -295,8 +353,8 @@ impl ServerState { res = listener.accept() => match res { Ok((stream, peer_addr)) => { debug!("[{http_str}] relay: Connection opened from {peer_addr}"); - let tls_config = self.tls_config.clone(); - let service = self.service.clone(); + let tls_config = tls_config.clone(); + let service = service.clone(); // spawn a task to handle the connection set.spawn(async move { if let Err(error) = service @@ -320,13 +378,17 @@ impl ServerState { } } } + if let Some(server) = server { + // TODO: if the task this is running in is aborted this server is not shut + // down. + server.close().await; + } set.shutdown().await; debug!("[{http_str}] relay: server has been shutdown."); }.instrument(info_span!("relay-http-serve"))); Ok(Server { addr, - server: self.server, http_server_task: task, cancel_server_loop, }) diff --git a/iroh-net/src/relay/iroh_relay.rs b/iroh-net/src/relay/iroh_relay.rs new file mode 100644 index 0000000000..928cdbaa8c --- /dev/null +++ b/iroh-net/src/relay/iroh_relay.rs @@ -0,0 +1,909 @@ +//! A full-fledged iroh-relay server. +//! +//! This module provides an API to run a full fledged iroh-relay server. It is primarily +//! used by the `iroh-relay` binary in this crate. It can be used to run a relay server in +//! other locations however. +//! +//! This code is fully written in a form of structured-concurrency: every spawned task is +//! always attached to a handle and when the handle is dropped the tasks abort. So tasks +//! can not outlive their handle. It is also always possible to await for completion of a +//! task. Some tasks additionally have a method to do graceful shutdown. + +use std::fmt; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; + +use anyhow::{anyhow, bail, Context, Result}; +use futures_lite::StreamExt; +use http::response::Builder as ResponseBuilder; +use http::{HeaderMap, Method, Request, Response, StatusCode}; +use hyper::body::Incoming; +use iroh_metrics::inc; +use tokio::net::{TcpListener, UdpSocket}; +use tokio::task::JoinSet; +use tracing::{debug, error, info, info_span, instrument, trace, warn, Instrument}; + +use crate::key::SecretKey; +use crate::relay; +use crate::relay::http::{ServerBuilder as RelayServerBuilder, TlsAcceptor}; +use crate::stun; +use crate::util::AbortingJoinHandle; + +// Module defined in this file. +use metrics::StunMetrics; + +const NO_CONTENT_CHALLENGE_HEADER: &str = "X-Tailscale-Challenge"; +const NO_CONTENT_RESPONSE_HEADER: &str = "X-Tailscale-Response"; +const NOTFOUND: &[u8] = b"Not Found"; +const RELAY_DISABLED: &[u8] = b"relay server disabled"; +const ROBOTS_TXT: &[u8] = b"User-agent: *\nDisallow: /\n"; +const INDEX: &[u8] = br#" +

+<html><body>
+<h1>Iroh Relay</h1>
+<p>
+    This is an Iroh Relay server.
+</p>
+</body></html>
+"#;
+const TLS_HEADERS: [(&str, &str); 2] = [
+    ("Strict-Transport-Security", "max-age=63072000; includeSubDomains"),
+    ("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'; form-action 'none'; base-uri 'self'; block-all-mixed-content; plugin-types 'none'")
+];
+
+type BytesBody = http_body_util::Full<hyper::body::Bytes>;
+type HyperError = Box<dyn std::error::Error + Send + Sync>;
+type HyperResult<T> = std::result::Result<T, HyperError>;
+
+/// Creates a new [`BytesBody`] with no content.
+fn body_empty() -> BytesBody {
+    http_body_util::Full::new(hyper::body::Bytes::new())
+}
+
+/// Configuration for the full Relay & STUN server.
+///
+/// Be aware the generic parameters are for when using the Let's Encrypt TLS configuration.
+/// If not used dummy ones need to be provided, e.g. `ServerConfig::<(), ()>::default()`.
+#[derive(Debug, Default)]
+pub struct ServerConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
+    /// Configuration for the Relay server, disabled if `None`.
+    pub relay: Option<RelayConfig<EC, EA>>,
+    /// Configuration for the STUN server, disabled if `None`.
+    pub stun: Option<StunConfig>,
+    /// Socket to serve metrics on.
+    #[cfg(feature = "metrics")]
+    pub metrics_addr: Option<SocketAddr>,
+}
+
+/// Configuration for the Relay HTTP and HTTPS server.
+///
+/// This includes the HTTP services hosted by the Relay server; the Relay `/derp` HTTP
+/// endpoint is only one of the services served.
+#[derive(Debug)]
+pub struct RelayConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
+    /// The iroh secret key of the Relay server.
+    pub secret_key: SecretKey,
+    /// The socket address on which the Relay HTTP server should bind.
+    ///
+    /// Normally you'd choose port `80`. The bind address for the HTTPS server is
+    /// configured in [`RelayConfig::tls`].
+    ///
+    /// If [`RelayConfig::tls`] is `None` then this serves all the HTTP services without
+    /// TLS.
+    pub http_bind_addr: SocketAddr,
+    /// TLS configuration for the HTTPS server.
+    ///
+    /// If *None* all the HTTP services that would be served here are served from
+    /// [`RelayConfig::http_bind_addr`].
+    pub tls: Option<TlsConfig<EC, EA>>,
+    /// Rate limits.
+    pub limits: Limits,
+}
+
+/// Configuration for the STUN server.
+#[derive(Debug)]
+pub struct StunConfig {
+    /// The socket address on which the STUN server should bind.
+    ///
+    /// Normally you'd choose port `3478`, see [`crate::defaults::DEFAULT_STUN_PORT`].
+    pub bind_addr: SocketAddr,
+}
+
+/// TLS configuration for the Relay server.
+///
+/// Normally the Relay server accepts connections on both HTTPS and HTTP.
+#[derive(Debug)]
+pub struct TlsConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
+    /// The socket address on which to serve the HTTPS server.
+    ///
+    /// Since the captive portal probe has to run over plain text HTTP and TLS is used for
+    /// the main relay server this has to be on a different port. When TLS is not enabled
+    /// this is served on the [`RelayConfig::http_bind_addr`] socket address.
+    ///
+    /// Normally you'd choose port `443`.
+    pub https_bind_addr: SocketAddr,
+    /// Mode for getting a cert.
+    pub cert: CertConfig<EC, EA>,
+}
+
+/// Rate limits.
+#[derive(Debug, Default)]
+pub struct Limits {
+    /// Rate limit for accepting new connections. Unlimited if not set.
+    pub accept_conn_limit: Option<f64>,
+    /// Burst limit for accepting new connections. Unlimited if not set.
+    pub accept_conn_burst: Option<usize>,
+}
+
+/// TLS certificate configuration.
+#[derive(derive_more::Debug)]
+pub enum CertConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
+    /// Use Let's Encrypt.
+    LetsEncrypt {
+        /// Configuration for Let's Encrypt certificates.
+        #[debug("AcmeConfig")]
+        config: tokio_rustls_acme::AcmeConfig<EC, EA>,
+    },
+    /// Use a static TLS key and certificate chain.
+    Manual {
+        /// The TLS private key.
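+        /// (DER bytes, as wrapped by [`rustls::PrivateKey`].)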
+        private_key: rustls::PrivateKey,
+        /// The TLS certificate chain.
+        certs: Vec<rustls::Certificate>,
+    },
+}
+
+/// A running Relay + STUN server.
+///
+/// This is a full Relay server, including STUN, Relay and various associated HTTP services.
+///
+/// Dropping this will stop the server.
+#[derive(Debug)]
+pub struct Server {
+    /// The address of the HTTP server, if configured.
+    http_addr: Option<SocketAddr>,
+    /// The address of the STUN server, if configured.
+    stun_addr: Option<SocketAddr>,
+    /// The address of the HTTPS server, if the relay server is using TLS.
+    ///
+    /// If the Relay server is not using TLS then it is served from the
+    /// [`Server::http_addr`].
+    https_addr: Option<SocketAddr>,
+    /// Handle to the relay server.
+    relay_handle: Option<relay::http::ServerHandle>,
+    /// The main task running the server.
+    supervisor: AbortingJoinHandle<Result<()>>,
+}
+
+impl Server {
+    /// Starts the server.
+    pub async fn spawn<EC, EA>(config: ServerConfig<EC, EA>) -> Result<Self>
+    where
+        EC: fmt::Debug + 'static,
+        EA: fmt::Debug + 'static,
+    {
+        let mut tasks = JoinSet::new();
+
+        #[cfg(feature = "metrics")]
+        if let Some(addr) = config.metrics_addr {
+            debug!("Starting metrics server");
+            use iroh_metrics::core::Metric;
+
+            iroh_metrics::core::Core::init(|reg, metrics| {
+                metrics.insert(crate::metrics::RelayMetrics::new(reg));
+                metrics.insert(StunMetrics::new(reg));
+            });
+            tasks.spawn(
+                iroh_metrics::metrics::start_metrics_server(addr)
+                    .instrument(info_span!("metrics-server")),
+            );
+        }
+
+        // Start the STUN server.
+        let stun_addr = match config.stun {
+            Some(stun) => {
+                debug!("Starting STUN server");
+                match UdpSocket::bind(stun.bind_addr).await {
+                    Ok(sock) => {
+                        let addr = sock.local_addr()?;
+                        info!("STUN server bound on {addr}");
+                        tasks.spawn(
+                            server_stun_listener(sock).instrument(info_span!("stun-server", %addr)),
+                        );
+                        Some(addr)
+                    }
+                    Err(err) => bail!("failed to bind STUN listener: {err:#?}"),
+                }
+            }
+            None => None,
+        };
+
+        // Start the Relay server.
+ let (relay_server, http_addr) = match config.relay { + Some(relay_config) => { + debug!("Starting Relay server"); + let mut headers = HeaderMap::new(); + for (name, value) in TLS_HEADERS.iter() { + headers.insert(*name, value.parse()?); + } + let relay_bind_addr = match relay_config.tls { + Some(ref tls) => tls.https_bind_addr, + None => relay_config.http_bind_addr, + }; + let mut builder = RelayServerBuilder::new(relay_bind_addr) + .secret_key(Some(relay_config.secret_key)) + .headers(headers) + .relay_override(Box::new(relay_disabled_handler)) + .request_handler(Method::GET, "/", Box::new(root_handler)) + .request_handler(Method::GET, "/index.html", Box::new(root_handler)) + .request_handler(Method::GET, "/derp/probe", Box::new(probe_handler)) + .request_handler(Method::GET, "/robots.txt", Box::new(robots_handler)); + let http_addr = match relay_config.tls { + Some(tls_config) => { + let server_config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth(); + let server_tls_config = match tls_config.cert { + CertConfig::LetsEncrypt { config } => { + let mut state = config.state(); + let server_config = + server_config.with_cert_resolver(state.resolver()); + let acceptor = TlsAcceptor::LetsEncrypt(state.acceptor()); + tasks.spawn( + async move { + while let Some(event) = state.next().await { + match event { + Ok(ok) => debug!("acme event: {ok:?}"), + Err(err) => error!("error: {err:?}"), + } + } + Err(anyhow!("acme event stream finished")) + } + .instrument(info_span!("acme")), + ); + Some(relay::http::TlsConfig { + config: Arc::new(server_config), + acceptor, + }) + } + CertConfig::Manual { private_key, certs } => { + let server_config = server_config + .with_single_cert(certs.clone(), private_key.clone())?; + let server_config = Arc::new(server_config); + let acceptor = + tokio_rustls::TlsAcceptor::from(server_config.clone()); + let acceptor = TlsAcceptor::Manual(acceptor); + Some(relay::http::TlsConfig { + config: server_config, + acceptor, + }) + } + }; + builder = builder.tls_config(server_tls_config); + + // Some services always need to be served over HTTP without TLS. Run + // these standalone. + let http_listener = TcpListener::bind(&relay_config.http_bind_addr) + .await + .context("failed to bind http")?; + let http_addr = http_listener.local_addr()?; + tasks.spawn( + run_captive_portal_service(http_listener) + .instrument(info_span!("http-service", addr = %http_addr)), + ); + Some(http_addr) + } + None => { + // If running Relay without TLS add the plain HTTP server directly + // to the Relay server. + builder = builder.request_handler( + Method::GET, + "/generate_204", + Box::new(serve_no_content_handler), + ); + None + } + }; + let relay_server = builder.spawn().await?; + (Some(relay_server), http_addr) + } + None => (None, None), + }; + // If http_addr is Some then relay_server is serving HTTPS. If http_addr is None + // relay_server is serving HTTP, including the /generate_204 service. + let relay_addr = relay_server.as_ref().map(|srv| srv.addr()); + let relay_handle = relay_server.as_ref().map(|srv| srv.handle()); + let relay_server = relay_server.map(RelayHttpServerGuard); + let task = tokio::spawn(relay_supervisor(tasks, relay_server)); + Ok(Self { + http_addr: http_addr.or(relay_addr), + stun_addr, + https_addr: http_addr.and(relay_addr), + relay_handle, + supervisor: AbortingJoinHandle::from(task), + }) + } + + /// Requests graceful shutdown. + /// + /// Returns once all server tasks have stopped. 
+    pub async fn shutdown(self) -> Result<()> {
+        // Only the Relay server needs shutting down; the supervisor will abort the tasks in
+        // the JoinSet when the server terminates.
+        if let Some(handle) = self.relay_handle {
+            handle.shutdown();
+        }
+        self.supervisor.await?
+    }
+
+    /// Returns the handle for the task.
+    ///
+    /// This allows waiting for the server's supervisor task to finish. Can be useful in
+    /// case there is an error in the server before it is shut down.
+    pub fn task_handle(&mut self) -> &mut AbortingJoinHandle<Result<()>> {
+        &mut self.supervisor
+    }
+
+    /// The socket address the HTTPS server is listening on.
+    pub fn https_addr(&self) -> Option<SocketAddr> {
+        self.https_addr
+    }
+
+    /// The socket address the HTTP server is listening on.
+    pub fn http_addr(&self) -> Option<SocketAddr> {
+        self.http_addr
+    }
+
+    /// The socket address the STUN server is listening on.
+    pub fn stun_addr(&self) -> Option<SocketAddr> {
+        self.stun_addr
+    }
+}
+
+/// Horrible hack to make [`relay::http::Server`] behave somewhat.
+///
+/// We need this server to abort on drop to achieve structured concurrency.
+// TODO: could consider building this directly into the relay::http::Server
+#[derive(Debug)]
+struct RelayHttpServerGuard(relay::http::Server);
+
+impl Drop for RelayHttpServerGuard {
+    fn drop(&mut self) {
+        self.0.task_handle().abort();
+    }
+}
+
+/// Supervisor for the relay server tasks.
+///
+/// As soon as one of the tasks exits, all other tasks are stopped and the server stops.
+/// The supervisor finishes once all tasks are finished.
+#[instrument(skip_all)]
+async fn relay_supervisor(
+    mut tasks: JoinSet<Result<()>>,
+    mut relay_http_server: Option<RelayHttpServerGuard>,
+) -> Result<()> {
+    let res = match (relay_http_server.as_mut(), tasks.len()) {
+        (None, _) => tasks
+            .join_next()
+            .await
+            .unwrap_or_else(|| Ok(Err(anyhow!("Nothing to supervise")))),
+        (Some(relay), 0) => relay.0.task_handle().await.map(anyhow::Ok),
+        (Some(relay), _) => {
+            tokio::select! {
+                biased;
+                Some(ret) = tasks.join_next() => ret,
+                ret = relay.0.task_handle() => ret.map(anyhow::Ok),
+                else => Ok(Err(anyhow!("Empty JoinSet (unreachable)"))),
+            }
+        }
+    };
+    let ret = match res {
+        Ok(Ok(())) => {
+            debug!("Task exited");
+            Ok(())
+        }
+        Ok(Err(err)) => {
+            error!(%err, "Task failed");
+            Err(err.context("task failed"))
+        }
+        Err(err) => {
+            if let Ok(panic) = err.try_into_panic() {
+                error!("Task panicked");
+                std::panic::resume_unwind(panic);
+            }
+            debug!("Task cancelled");
+            Err(anyhow!("task cancelled"))
+        }
+    };
+
+    // Ensure the HTTP server terminated; there is no harm in calling this after it is
+    // already shut down. The JoinSet is aborted on drop.
+    if let Some(server) = relay_http_server {
+        server.0.shutdown();
+    }
+
+    tasks.shutdown().await;
+
+    ret
+}
+
+/// Runs a STUN server.
+///
+/// When the future is dropped, the server stops.
+async fn server_stun_listener(sock: UdpSocket) -> Result<()> {
+    info!(addr = ?sock.local_addr().ok(), "running STUN server");
+    let sock = Arc::new(sock);
+    let mut buffer = vec![0u8; 64 << 10];
+    let mut tasks = JoinSet::new();
+    loop {
+        tokio::select!
{ + biased; + _ = tasks.join_next(), if !tasks.is_empty() => (), + res = sock.recv_from(&mut buffer) => { + match res { + Ok((n, src_addr)) => { + inc!(StunMetrics, requests); + let pkt = &buffer[..n]; + if !stun::is(pkt) { + debug!(%src_addr, "STUN: ignoring non stun packet"); + inc!(StunMetrics, bad_requests); + continue; + } + let pkt = pkt.to_vec(); + tasks.spawn(handle_stun_request(src_addr, pkt, sock.clone())); + } + Err(err) => { + inc!(StunMetrics, failures); + warn!("failed to recv: {err:#}"); + } + } + } + } + } +} + +/// Handles a single STUN request, doing all logging required. +async fn handle_stun_request(src_addr: SocketAddr, pkt: Vec, sock: Arc) { + let handle = AbortingJoinHandle::from(tokio::task::spawn_blocking(move || { + match stun::parse_binding_request(&pkt) { + Ok(txid) => { + debug!(%src_addr, %txid, "STUN: received binding request"); + Some((txid, stun::response(txid, src_addr))) + } + Err(err) => { + inc!(StunMetrics, bad_requests); + warn!(%src_addr, "STUN: invalid binding request: {:?}", err); + None + } + } + })); + let (txid, response) = match handle.await { + Ok(Some(val)) => val, + Ok(None) => return, + Err(err) => { + error!("{err:#}"); + return; + } + }; + match sock.send_to(&response, src_addr).await { + Ok(len) => { + if len != response.len() { + warn!( + %src_addr, + %txid, + "failed to write response, {len}/{} bytes sent", + response.len() + ); + } else { + match src_addr { + SocketAddr::V4(_) => inc!(StunMetrics, ipv4_success), + SocketAddr::V6(_) => inc!(StunMetrics, ipv6_success), + } + } + trace!(%src_addr, %txid, "sent {len} bytes"); + } + Err(err) => { + inc!(StunMetrics, failures); + warn!(%src_addr, %txid, "failed to write response: {err:#}"); + } + } +} + +fn relay_disabled_handler( + _r: Request, + response: ResponseBuilder, +) -> HyperResult> { + response + .status(StatusCode::NOT_FOUND) + .body(RELAY_DISABLED.into()) + .map_err(|err| Box::new(err) as HyperError) +} + +fn root_handler( + _r: Request, + response: ResponseBuilder, +) -> HyperResult> { + response + .status(StatusCode::OK) + .header("Content-Type", "text/html; charset=utf-8") + .body(INDEX.into()) + .map_err(|err| Box::new(err) as HyperError) +} + +/// HTTP latency queries +fn probe_handler( + _r: Request, + response: ResponseBuilder, +) -> HyperResult> { + response + .status(StatusCode::OK) + .header("Access-Control-Allow-Origin", "*") + .body(body_empty()) + .map_err(|err| Box::new(err) as HyperError) +} + +fn robots_handler( + _r: Request, + response: ResponseBuilder, +) -> HyperResult> { + response + .status(StatusCode::OK) + .body(ROBOTS_TXT.into()) + .map_err(|err| Box::new(err) as HyperError) +} + +/// For captive portal detection. +fn serve_no_content_handler( + r: Request, + mut response: ResponseBuilder, +) -> HyperResult> { + if let Some(challenge) = r.headers().get(NO_CONTENT_CHALLENGE_HEADER) { + if !challenge.is_empty() + && challenge.len() < 64 + && challenge + .as_bytes() + .iter() + .all(|c| is_challenge_char(*c as char)) + { + response = response.header( + NO_CONTENT_RESPONSE_HEADER, + format!("response {}", challenge.to_str()?), + ); + } + } + + response + .status(StatusCode::NO_CONTENT) + .body(body_empty()) + .map_err(|err| Box::new(err) as HyperError) +} + +fn is_challenge_char(c: char) -> bool { + // Semi-randomly chosen as a limited set of valid characters + c.is_ascii_lowercase() + || c.is_ascii_uppercase() + || c.is_ascii_digit() + || c == '.' + || c == '-' + || c == '_' +} + +/// This is a future that never returns, drop it to cancel/abort. 
+async fn run_captive_portal_service(http_listener: TcpListener) -> Result<()> { + info!("serving"); + + // If this future is cancelled, this is dropped and all tasks are aborted. + let mut tasks = JoinSet::new(); + + loop { + match http_listener.accept().await { + Ok((stream, peer_addr)) => { + debug!(%peer_addr, "Connection opened",); + let handler = CaptivePortalService; + + tasks.spawn(async move { + let stream = relay::MaybeTlsStreamServer::Plain(stream); + let stream = hyper_util::rt::TokioIo::new(stream); + if let Err(err) = hyper::server::conn::http1::Builder::new() + .serve_connection(stream, handler) + .with_upgrades() + .await + { + error!("Failed to serve connection: {err:?}"); + } + }); + } + Err(err) => { + error!( + "[CaptivePortalService] failed to accept connection: {:#?}", + err + ); + } + } + } +} + +#[derive(Clone)] +struct CaptivePortalService; + +impl hyper::service::Service> for CaptivePortalService { + type Response = Response; + type Error = HyperError; + type Future = Pin> + Send>>; + + fn call(&self, req: Request) -> Self::Future { + match (req.method(), req.uri().path()) { + // Captive Portal checker + (&Method::GET, "/generate_204") => { + Box::pin(async move { serve_no_content_handler(req, Response::builder()) }) + } + _ => { + // Return 404 not found response. + let r = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(NOTFOUND.into()) + .map_err(|err| Box::new(err) as HyperError); + Box::pin(async move { r }) + } + } + } +} + +mod metrics { + use iroh_metrics::{ + core::{Counter, Metric}, + struct_iterable::Iterable, + }; + + /// StunMetrics tracked for the DERPER + #[allow(missing_docs)] + #[derive(Debug, Clone, Iterable)] + pub struct StunMetrics { + /* + * Metrics about STUN requests over ipv6 + */ + /// Number of stun requests made + pub requests: Counter, + /// Number of successful requests over ipv4 + pub ipv4_success: Counter, + /// Number of successful requests over ipv6 + pub ipv6_success: Counter, + + /// Number of bad requests, either non-stun packets or incorrect binding request + pub bad_requests: Counter, + /// Number of failures + pub failures: Counter, + } + + impl Default for StunMetrics { + fn default() -> Self { + Self { + /* + * Metrics about STUN requests + */ + requests: Counter::new("Number of STUN requests made to the server."), + ipv4_success: Counter::new("Number of successful ipv4 STUN requests served."), + ipv6_success: Counter::new("Number of successful ipv6 STUN requests served."), + bad_requests: Counter::new("Number of bad requests made to the STUN endpoint."), + failures: Counter::new("Number of STUN requests that end in failure."), + } + } + } + + impl Metric for StunMetrics { + fn name() -> &'static str { + "stun" + } + } +} + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + use std::time::Duration; + + use bytes::Bytes; + use iroh_base::node_addr::RelayUrl; + + use crate::relay::http::ClientBuilder; + + use self::relay::ReceivedMessage; + + use super::*; + + #[tokio::test] + async fn test_no_services() { + let _guard = iroh_test::logging::setup(); + let mut server = Server::spawn(ServerConfig::<(), ()>::default()) + .await + .unwrap(); + let res = tokio::time::timeout(Duration::from_secs(5), server.task_handle()) + .await + .expect("timeout, server not finished") + .expect("server task JoinError"); + assert!(res.is_err()); + } + + #[tokio::test] + async fn test_conflicting_bind() { + let _guard = iroh_test::logging::setup(); + let mut server = Server::spawn(ServerConfig::<(), ()> { + relay: 
Some(RelayConfig { + secret_key: SecretKey::generate(), + http_bind_addr: (Ipv4Addr::LOCALHOST, 1234).into(), + tls: None, + limits: Default::default(), + }), + stun: None, + metrics_addr: Some((Ipv4Addr::LOCALHOST, 1234).into()), + }) + .await + .unwrap(); + let res = tokio::time::timeout(Duration::from_secs(5), server.task_handle()) + .await + .expect("timeout, server not finished") + .expect("server task JoinError"); + assert!(res.is_err()); // AddrInUse + } + + #[tokio::test] + async fn test_root_handler() { + let _guard = iroh_test::logging::setup(); + let server = Server::spawn(ServerConfig::<(), ()> { + relay: Some(RelayConfig { + secret_key: SecretKey::generate(), + http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), + tls: None, + limits: Default::default(), + }), + stun: None, + metrics_addr: None, + }) + .await + .unwrap(); + let url = format!("http://{}", server.http_addr().unwrap()); + + let response = reqwest::get(&url).await.unwrap(); + assert_eq!(response.status(), 200); + let body = response.text().await.unwrap(); + assert!(body.contains("iroh.computer")); + } + + #[tokio::test] + async fn test_captive_portal_service() { + let _guard = iroh_test::logging::setup(); + let server = Server::spawn(ServerConfig::<(), ()> { + relay: Some(RelayConfig { + secret_key: SecretKey::generate(), + http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), + tls: None, + limits: Default::default(), + }), + stun: None, + metrics_addr: None, + }) + .await + .unwrap(); + let url = format!("http://{}/generate_204", server.http_addr().unwrap()); + let challenge = "123az__."; + + let client = reqwest::Client::new(); + let response = client + .get(&url) + .header(NO_CONTENT_CHALLENGE_HEADER, challenge) + .send() + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); + let header = response.headers().get(NO_CONTENT_RESPONSE_HEADER).unwrap(); + assert_eq!(header.to_str().unwrap(), format!("response {challenge}")); + let body = response.text().await.unwrap(); + assert!(body.is_empty()); + } + + #[tokio::test] + async fn test_relay_clients() { + let _guard = iroh_test::logging::setup(); + let server = Server::spawn(ServerConfig::<(), ()> { + relay: Some(RelayConfig { + secret_key: SecretKey::generate(), + http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), + tls: None, + limits: Default::default(), + }), + stun: None, + metrics_addr: None, + }) + .await + .unwrap(); + let relay_url = format!("http://{}", server.http_addr().unwrap()); + let relay_url: RelayUrl = relay_url.parse().unwrap(); + + // set up client a + let a_secret_key = SecretKey::generate(); + let a_key = a_secret_key.public(); + let resolver = crate::dns::default_resolver().clone(); + let (client_a, mut client_a_receiver) = + ClientBuilder::new(relay_url.clone()).build(a_secret_key, resolver); + let connect_client = client_a.clone(); + + // give the relay server some time to accept connections + if let Err(err) = tokio::time::timeout(Duration::from_secs(10), async move { + loop { + match connect_client.connect().await { + Ok(_) => break, + Err(err) => { + warn!("client unable to connect to relay server: {err:#}"); + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + } + }) + .await + { + panic!("error connecting to relay server: {err:#}"); + } + + // set up client b + let b_secret_key = SecretKey::generate(); + let b_key = b_secret_key.public(); + let resolver = crate::dns::default_resolver().clone(); + let (client_b, mut client_b_receiver) = + ClientBuilder::new(relay_url.clone()).build(b_secret_key, resolver); + 
client_b.connect().await.unwrap(); + + // send message from a to b + let msg = Bytes::from("hello, b"); + client_a.send(b_key, msg.clone()).await.unwrap(); + + let (res, _) = client_b_receiver.recv().await.unwrap().unwrap(); + if let ReceivedMessage::ReceivedPacket { source, data } = res { + assert_eq!(a_key, source); + assert_eq!(msg, data); + } else { + panic!("client_b received unexpected message {res:?}"); + } + + // send message from b to a + let msg = Bytes::from("howdy, a"); + client_b.send(a_key, msg.clone()).await.unwrap(); + + let (res, _) = client_a_receiver.recv().await.unwrap().unwrap(); + if let ReceivedMessage::ReceivedPacket { source, data } = res { + assert_eq!(b_key, source); + assert_eq!(msg, data); + } else { + panic!("client_a received unexpected message {res:?}"); + } + } + + #[tokio::test] + async fn test_stun() { + let _guard = iroh_test::logging::setup(); + let server = Server::spawn(ServerConfig::<(), ()> { + relay: None, + stun: Some(StunConfig { + bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), + }), + metrics_addr: None, + }) + .await + .unwrap(); + + let txid = stun::TransactionId::default(); + let req = stun::request(txid); + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + socket + .send_to(&req, server.stun_addr().unwrap()) + .await + .unwrap(); + + // get response + let mut buf = vec![0u8; 64000]; + let (len, addr) = socket.recv_from(&mut buf).await.unwrap(); + assert_eq!(addr, server.stun_addr().unwrap()); + buf.truncate(len); + let (txid_back, response_addr) = stun::parse_response(&buf).unwrap(); + assert_eq!(txid, txid_back); + assert_eq!(response_addr, socket.local_addr().unwrap()); + } +} diff --git a/iroh-net/src/relay/map.rs b/iroh-net/src/relay/map.rs index ede590a7ae..721fd778a1 100644 --- a/iroh-net/src/relay/map.rs +++ b/iroh-net/src/relay/map.rs @@ -5,7 +5,7 @@ use std::{collections::BTreeMap, fmt, sync::Arc}; use anyhow::{ensure, Result}; use serde::{Deserialize, Serialize}; -use crate::defaults::DEFAULT_RELAY_STUN_PORT; +use crate::defaults::DEFAULT_STUN_PORT; use super::RelayUrl; @@ -91,7 +91,7 @@ impl RelayMap { /// This will use the default STUN port and IP addresses resolved from the URL's host name via DNS. /// relay nodes are specified at <../../../docs/relay_nodes.md> pub fn from_url(url: RelayUrl) -> Self { - Self::default_from_node(url, DEFAULT_RELAY_STUN_PORT) + Self::default_from_node(url, DEFAULT_STUN_PORT) } /// Constructs the [`RelayMap] from an iterator of [`RelayNode`]s. diff --git a/iroh-net/src/relay/server.rs b/iroh-net/src/relay/server.rs index 38493c4601..05dbc60ad7 100644 --- a/iroh-net/src/relay/server.rs +++ b/iroh-net/src/relay/server.rs @@ -113,6 +113,13 @@ impl Server { } } + /// Aborts the server. + /// + /// You should prefer to use [`Server::close`] for a graceful shutdown. + pub fn abort(&self) { + self.cancel.cancel(); + } + /// Whether or not the relay [Server] is closed. pub fn is_closed(&self) -> bool { self.closed diff --git a/iroh-net/src/stun.rs b/iroh-net/src/stun.rs index b9ff7e6dde..e0ed936782 100644 --- a/iroh-net/src/stun.rs +++ b/iroh-net/src/stun.rs @@ -72,8 +72,8 @@ const COOKIE: [u8; 4] = 0x2112_A442u32.to_be_bytes(); /// Reports whether b is a STUN message. pub fn is(b: &[u8]) -> bool { b.len() >= stun_rs::MESSAGE_HEADER_SIZE && - b[0]&0b11000000 == 0 && // top two bits must be zero - b[4..8] == COOKIE + b[0]&0b11000000 == 0 && // top two bits must be zero + b[4..8] == COOKIE } /// Parses a STUN binding request. 
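Taken together, the new `iroh_relay` module boils down to building a `ServerConfig` and calling `Server::spawn`. Below is a minimal sketch mirroring the tests above; the `iroh_net::relay::iroh_relay` path and the `metrics_addr` field assume the crate's default (metrics-enabled) features:

```rust
use std::net::Ipv4Addr;

use iroh_net::key::SecretKey;
use iroh_net::relay::iroh_relay::{RelayConfig, Server, ServerConfig, StunConfig};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Plain-HTTP relay plus STUN, both on ephemeral localhost ports.
    // The dummy `(), ()` generics stand in for the unused Let's Encrypt config.
    let config = ServerConfig::<(), ()> {
        relay: Some(RelayConfig {
            secret_key: SecretKey::generate(),
            http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(),
            tls: None,
            limits: Default::default(),
        }),
        stun: Some(StunConfig {
            bind_addr: (Ipv4Addr::LOCALHOST, 0).into(),
        }),
        metrics_addr: None,
    };
    let server = Server::spawn(config).await?;
    println!("http: {:?}, stun: {:?}", server.http_addr(), server.stun_addr());
    // Graceful shutdown; simply dropping `server` would abort it instead.
    server.shutdown().await
}
```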
@@ -149,9 +149,10 @@ pub fn parse_response(b: &[u8]) -> Result<(TransactionId, SocketAddr), Error> {
     Err(Error::MalformedAttrs)
 }
 
-#[cfg(any(test, feature = "test-utils"))]
-pub(crate) mod test {
-    use std::{net::IpAddr, sync::Arc};
+#[cfg(test)]
+pub(crate) mod tests {
+    use std::net::{IpAddr, Ipv4Addr};
+    use std::sync::Arc;
 
     use anyhow::Result;
     use tokio::{
@@ -160,30 +161,28 @@ pub(crate) mod test {
     };
     use tracing::{debug, trace};
 
-    #[cfg(test)]
     use crate::relay::{RelayMap, RelayNode, RelayUrl};
     use crate::test_utils::CleanupDropGuard;
 
     use super::*;
 
+    // TODO: make all this private
+    // (read_ipv4, read_ipv5)
     #[derive(Debug, Default, Clone)]
     pub struct StunStats(Arc<Mutex<(usize, usize)>>);
 
     impl StunStats {
-        #[cfg(test)]
         pub async fn total(&self) -> usize {
            let s = self.0.lock().await;
            s.0 + s.1
        }
    }
 
-    #[cfg(test)]
     pub fn relay_map_of(stun: impl Iterator<Item = SocketAddr>) -> RelayMap {
         relay_map_of_opts(stun.map(|addr| (addr, true)))
     }
 
-    #[cfg(test)]
     pub fn relay_map_of_opts(stun: impl Iterator<Item = (SocketAddr, bool)>) -> RelayMap {
         let nodes = stun.map(|(addr, stun_only)| {
             let host = addr.ip();
@@ -202,7 +201,6 @@ pub(crate) mod test {
     /// Sets up a simple STUN server binding to `0.0.0.0:0`.
     ///
     /// See [`serve`] for more details.
-    #[cfg(test)]
     pub(crate) async fn serve_v4() -> Result<(SocketAddr, StunStats, CleanupDropGuard)> {
         serve(std::net::Ipv4Addr::UNSPECIFIED.into()).await
     }
@@ -272,13 +270,6 @@ pub(crate) mod test {
             }
         }
     }
-}
-
-#[cfg(test)]
-mod tests {
-    use std::net::{IpAddr, Ipv4Addr};
-
-    use super::*;
 
     // Test to check if an existing stun server works
     // #[tokio::test]
diff --git a/iroh-net/src/test_utils.rs b/iroh-net/src/test_utils.rs
index 0cbf8bd857..3188a7e128 100644
--- a/iroh-net/src/test_utils.rs
+++ b/iroh-net/src/test_utils.rs
@@ -2,7 +2,6 @@
 use anyhow::Result;
 use tokio::sync::oneshot;
-use tracing::{error_span, info_span, Instrument};
 
 use crate::{
     key::SecretKey,
@@ -24,48 +23,51 @@ pub struct CleanupDropGuard(pub(crate) oneshot::Sender<()>);
 
 /// Runs a relay server with STUN enabled suitable for tests.
 ///
-/// The returned `Url` is the url of the relay server in the returned [`RelayMap`], it
-/// is always `Some` as that is how the [`Endpoint::connect`] API expects it.
+/// The returned `Url` is the URL of the relay server in the returned [`RelayMap`].
+/// When dropped, the returned [`Server`] will stop running.
/// -/// [`Endpoint::connect`]: crate::endpoint::Endpoint -pub async fn run_relay_server() -> Result<(RelayMap, RelayUrl, CleanupDropGuard)> { - let server_key = SecretKey::generate(); - let me = server_key.public().fmt_short(); - let tls_config = crate::relay::http::make_tls_config(); - let server = crate::relay::http::ServerBuilder::new("127.0.0.1:0".parse().unwrap()) - .secret_key(Some(server_key)) - .tls_config(Some(tls_config)) - .spawn() - .instrument(error_span!("relay server", %me)) - .await?; - - let https_addr = server.addr(); - println!("relay listening on {:?}", https_addr); - - let (stun_addr, _, stun_drop_guard) = crate::stun::test::serve(server.addr().ip()).await?; - let url: RelayUrl = format!("https://localhost:{}", https_addr.port()) +/// [`Server`]: crate::relay::iroh_relay::Server +pub async fn run_relay_server() -> Result<(RelayMap, RelayUrl, crate::relay::iroh_relay::Server)> { + use crate::relay::iroh_relay::{CertConfig, RelayConfig, ServerConfig, StunConfig, TlsConfig}; + use std::net::Ipv4Addr; + + let secret_key = SecretKey::generate(); + let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap(); + let rustls_cert = rustls::Certificate(cert.serialize_der().unwrap()); + let private_key = rustls::PrivateKey(cert.get_key_pair().serialize_der()); + + let config = ServerConfig { + relay: Some(RelayConfig { + http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), + secret_key, + tls: Some(TlsConfig { + cert: CertConfig::<(), ()>::Manual { + private_key, + certs: vec![rustls_cert], + }, + https_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), + }), + limits: Default::default(), + }), + stun: Some(StunConfig { + bind_addr: (Ipv4Addr::LOCALHOST, 0).into(), + }), + #[cfg(feature = "metrics")] + metrics_addr: None, + }; + let server = crate::relay::iroh_relay::Server::spawn(config) + .await + .unwrap(); + let url: RelayUrl = format!("https://localhost:{}", server.https_addr().unwrap().port()) .parse() .unwrap(); let m = RelayMap::from_nodes([RelayNode { url: url.clone(), stun_only: false, - stun_port: stun_addr.port(), + stun_port: server.stun_addr().unwrap().port(), }]) - .expect("hardcoded"); - - let (tx, rx) = oneshot::channel(); - tokio::spawn( - async move { - let _stun_cleanup = stun_drop_guard; // move into this closure - - // Wait until we're dropped or receive a message. - rx.await.ok(); - server.shutdown().await; - } - .instrument(info_span!("relay-stun-cleanup")), - ); - - Ok((m, url, CleanupDropGuard(tx))) + .unwrap(); + Ok((m, url, server)) } pub(crate) mod dns_and_pkarr_servers { diff --git a/iroh-test/Cargo.toml b/iroh-test/Cargo.toml index 319f26aaef..f8c7959de8 100644 --- a/iroh-test/Cargo.toml +++ b/iroh-test/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh-test" -version = "0.17.0" +version = "0.18.0" edition = "2021" readme = "README.md" description = "Internal utilities to support testing of iroh." diff --git a/iroh/Cargo.toml b/iroh/Cargo.toml index 07bebf54d0..15404847da 100644 --- a/iroh/Cargo.toml +++ b/iroh/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "iroh" -version = "0.17.0" +version = "0.18.0" edition = "2021" readme = "README.md" description = "Bytes. Distributed." 
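For callers, the net effect of the `run_relay_server` rewrite above is that tests keep the returned `Server` alive instead of holding a oneshot-based cleanup guard. A hypothetical test using it might look like this (the test name is illustrative, and it assumes the helper is reachable via iroh-net's test utilities):

```rust
use iroh_net::test_utils::run_relay_server;

#[tokio::test]
async fn endpoint_via_test_relay() -> anyhow::Result<()> {
    // Keep `_server` bound for the whole test: dropping it stops the relay.
    let (relay_map, relay_url, _server) = run_relay_server().await?;
    println!("test relay at {relay_url}");
    // ... build Endpoints against `relay_map` here ...
    drop(relay_map);
    Ok(())
}
```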
@@ -26,15 +26,15 @@ futures-lite = "2.3" futures-util = "0.3" genawaiter = { version = "0.99", default-features = false, features = ["futures03"] } hex = { version = "0.4.3" } -iroh-blobs = { version = "0.17.0", path = "../iroh-blobs", features = ["downloader"] } -iroh-base = { version = "0.17.0", path = "../iroh-base", features = ["key"] } +iroh-blobs = { version = "0.18.0", path = "../iroh-blobs", features = ["downloader"] } +iroh-base = { version = "0.18.0", path = "../iroh-base", features = ["key"] } iroh-io = { version = "0.6.0", features = ["stats"] } -iroh-metrics = { version = "0.17.0", path = "../iroh-metrics", optional = true } -iroh-net = { version = "0.17.0", path = "../iroh-net" } +iroh-metrics = { version = "0.18.0", path = "../iroh-metrics", optional = true } +iroh-net = { version = "0.18.0", path = "../iroh-net" } num_cpus = { version = "1.15.0" } portable-atomic = "1" -iroh-docs = { version = "0.17.0", path = "../iroh-docs" } -iroh-gossip = { version = "0.17.0", path = "../iroh-gossip" } +iroh-docs = { version = "0.18.0", path = "../iroh-docs" } +iroh-gossip = { version = "0.18.0", path = "../iroh-gossip" } parking_lot = "0.12.1" postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] } quic-rpc = { version = "0.10.0", default-features = false, features = ["flume-transport", "quinn-transport"] } @@ -53,6 +53,7 @@ walkdir = "2" # Examples clap = { version = "4", features = ["derive"], optional = true } indicatif = { version = "0.17", features = ["tokio"], optional = true } +ref-cast = "1.0.23" [features] default = ["metrics", "fs-store"] @@ -101,3 +102,7 @@ required-features = ["examples"] [[example]] name = "client" required-features = ["examples"] + +[[example]] +name = "custom-protocol" +required-features = ["examples"] diff --git a/iroh/examples/client.rs b/iroh/examples/client.rs index 3e4d018aed..0e04a91c30 100644 --- a/iroh/examples/client.rs +++ b/iroh/examples/client.rs @@ -16,8 +16,8 @@ async fn main() -> anyhow::Result<()> { // Could also use `node` directly, as it derefs to the client. let client = node.client(); - let doc = client.docs.create().await?; - let author = client.authors.default().await?; + let doc = client.docs().create().await?; + let author = client.authors().default().await?; doc.set_bytes(author, "hello", "world").await?; diff --git a/iroh/examples/collection-fetch.rs b/iroh/examples/collection-fetch.rs index e35f61ba95..c827f13cdc 100644 --- a/iroh/examples/collection-fetch.rs +++ b/iroh/examples/collection-fetch.rs @@ -59,7 +59,7 @@ async fn main() -> Result<()> { // `download` returns a stream of `DownloadProgress` events. You can iterate through these updates to get progress // on the state of your download. let download_stream = node - .blobs + .blobs() .download_hash_seq(ticket.hash(), ticket.node_addr().clone()) .await?; @@ -76,7 +76,7 @@ async fn main() -> Result<()> { // A `Collection` is a special `HashSeq`, where we preserve the names of any blobs added to the collection. (We do this by designating the first entry in the `Collection` as meta data.) // To get the content of the collection, we first get the collection from the database using the `blobs` API let collection = node - .blobs + .blobs() .get_collection(ticket.hash()) .await .context("expect hash with `BlobFormat::HashSeq` to be a collection")?; @@ -85,7 +85,7 @@ async fn main() -> Result<()> { for (name, hash) in collection.iter() { println!("\nname: {name}, hash: {hash}"); // Use the hash of the blob to get the content. 
- let content = node.blobs.read_to_bytes(*hash).await?; + let content = node.blobs().read_to_bytes(*hash).await?; let s = std::str::from_utf8(&content).context("unable to parse blob as as utf-8 string")?; println!("{s}"); } diff --git a/iroh/examples/collection-provide.rs b/iroh/examples/collection-provide.rs index 37f05da545..867b2ac5e3 100644 --- a/iroh/examples/collection-provide.rs +++ b/iroh/examples/collection-provide.rs @@ -27,8 +27,8 @@ async fn main() -> anyhow::Result<()> { let node = iroh::node::Node::memory().spawn().await?; // Add two blobs - let blob1 = node.blobs.add_bytes("the first blob of bytes").await?; - let blob2 = node.blobs.add_bytes("the second blob of bytes").await?; + let blob1 = node.blobs().add_bytes("the first blob of bytes").await?; + let blob2 = node.blobs().add_bytes("the second blob of bytes").await?; // Create blobs from the data let collection: Collection = [("blob1", blob1.hash), ("blob2", blob2.hash)] @@ -37,14 +37,14 @@ async fn main() -> anyhow::Result<()> { // Create a collection let (hash, _) = node - .blobs + .blobs() .create_collection(collection, SetTagOption::Auto, Default::default()) .await?; // create a ticket // tickets wrap all details needed to get a collection let ticket = node - .blobs + .blobs() .share(hash, BlobFormat::HashSeq, Default::default()) .await?; diff --git a/iroh/examples/custom-protocol.rs b/iroh/examples/custom-protocol.rs new file mode 100644 index 0000000000..4a12687725 --- /dev/null +++ b/iroh/examples/custom-protocol.rs @@ -0,0 +1,127 @@ +use std::sync::Arc; + +use anyhow::Result; +use clap::Parser; +use futures_lite::future::Boxed as BoxedFuture; +use iroh::{ + client::MemIroh, + net::{ + endpoint::{get_remote_node_id, Connecting}, + Endpoint, NodeId, + }, + node::ProtocolHandler, +}; +use tracing_subscriber::{prelude::*, EnvFilter}; + +#[derive(Debug, Parser)] +pub struct Cli { + #[clap(subcommand)] + command: Command, +} + +#[derive(Debug, Parser)] +pub enum Command { + Accept, + Connect { node: NodeId }, +} + +#[tokio::main] +async fn main() -> Result<()> { + setup_logging(); + let args = Cli::parse(); + // create a new node + let builder = iroh::node::Node::memory().build().await?; + let proto = ExampleProto::new(builder.client().clone(), builder.endpoint().clone()); + let node = builder + .accept(EXAMPLE_ALPN, Arc::new(proto.clone())) + .spawn() + .await?; + + // print the ticket if this is the accepting side + match args.command { + Command::Accept => { + let node_id = node.node_id(); + println!("node id: {node_id}"); + // wait until ctrl-c + tokio::signal::ctrl_c().await?; + } + Command::Connect { node: node_id } => { + proto.connect(node_id).await?; + } + } + + node.shutdown().await?; + + Ok(()) +} + +const EXAMPLE_ALPN: &[u8] = b"example-proto/0"; + +#[derive(Debug, Clone)] +struct ExampleProto { + client: MemIroh, + endpoint: Endpoint, +} + +impl ProtocolHandler for ExampleProto { + fn accept(self: Arc, connecting: Connecting) -> BoxedFuture> { + Box::pin(async move { + let connection = connecting.await?; + let peer = get_remote_node_id(&connection)?; + println!("accepted connection from {peer}"); + let mut send_stream = connection.open_uni().await?; + // Let's create a new blob for each incoming connection. 
+ // This functions as an example of using existing iroh functionality within a protocol + // (you likely don't want to create a new blob for each connection for real) + let content = format!("this blob is created for my beloved peer {peer} โ™ฅ"); + let hash = self + .client + .blobs() + .add_bytes(content.as_bytes().to_vec()) + .await?; + // Send the hash over our custom protocol. + send_stream.write_all(hash.hash.as_bytes()).await?; + send_stream.finish().await?; + println!("closing connection from {peer}"); + Ok(()) + }) + } +} + +impl ExampleProto { + pub fn new(client: MemIroh, endpoint: Endpoint) -> Self { + Self { client, endpoint } + } + + pub async fn connect(&self, remote_node_id: NodeId) -> Result<()> { + println!("our node id: {}", self.endpoint.node_id()); + println!("connecting to {remote_node_id}"); + let conn = self + .endpoint + .connect_by_node_id(&remote_node_id, EXAMPLE_ALPN) + .await?; + let mut recv_stream = conn.accept_uni().await?; + let hash_bytes = recv_stream.read_to_end(32).await?; + let hash = iroh::blobs::Hash::from_bytes(hash_bytes.try_into().unwrap()); + println!("received hash: {hash}"); + self.client + .blobs() + .download(hash, remote_node_id.into()) + .await? + .await?; + println!("blob downloaded"); + let content = self.client.blobs().read_to_bytes(hash).await?; + let message = String::from_utf8(content.to_vec())?; + println!("blob content: {message}"); + Ok(()) + } +} + +/// Set the RUST_LOG env var to one of {debug,info,warn} to see logging. +fn setup_logging() { + tracing_subscriber::registry() + .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr)) + .with(EnvFilter::from_default_env()) + .try_init() + .ok(); +} diff --git a/iroh/examples/hello-world-fetch.rs b/iroh/examples/hello-world-fetch.rs index 71672845a8..06578b71eb 100644 --- a/iroh/examples/hello-world-fetch.rs +++ b/iroh/examples/hello-world-fetch.rs @@ -59,7 +59,7 @@ async fn main() -> Result<()> { // `download` returns a stream of `DownloadProgress` events. You can iterate through these updates to get progress // on the state of your download. let download_stream = node - .blobs + .blobs() .download(ticket.hash(), ticket.node_addr().clone()) .await?; @@ -74,7 +74,7 @@ async fn main() -> Result<()> { // Get the content we have just fetched from the iroh database. - let bytes = node.blobs.read_to_bytes(ticket.hash()).await?; + let bytes = node.blobs().read_to_bytes(ticket.hash()).await?; let s = std::str::from_utf8(&bytes).context("unable to parse blob as as utf-8 string")?; println!("{s}"); diff --git a/iroh/examples/hello-world-provide.rs b/iroh/examples/hello-world-provide.rs index 14be61aef5..8fe0e9c12a 100644 --- a/iroh/examples/hello-world-provide.rs +++ b/iroh/examples/hello-world-provide.rs @@ -23,11 +23,11 @@ async fn main() -> anyhow::Result<()> { let node = iroh::node::Node::memory().spawn().await?; // add some data and remember the hash - let res = node.blobs.add_bytes("Hello, world!").await?; + let res = node.blobs().add_bytes("Hello, world!").await?; // create a ticket let ticket = node - .blobs + .blobs() .share(res.hash, res.format, Default::default()) .await?; diff --git a/iroh/src/client.rs b/iroh/src/client.rs index 4c75adcf55..8e5bd9c411 100644 --- a/iroh/src/client.rs +++ b/iroh/src/client.rs @@ -2,6 +2,7 @@ use futures_lite::{Stream, StreamExt}; use quic_rpc::{RpcClient, ServiceConnection}; +use ref_cast::RefCast; #[doc(inline)] pub use crate::rpc_protocol::RpcService; @@ -26,17 +27,6 @@ mod node; /// Iroh client. 
#[derive(Debug, Clone)] pub struct Iroh { - /// Client for blobs operations. - pub blobs: blobs::Client, - /// Client for docs operations. - pub docs: docs::Client, - /// Client for author operations. - pub authors: authors::Client, - /// Client for tags operations. - pub tags: tags::Client, - /// Client for tags operations. - pub gossip: gossip::Client, - rpc: RpcClient, } @@ -46,14 +36,32 @@ where { /// Create a new high-level client to a Iroh node from the low-level RPC client. pub fn new(rpc: RpcClient) -> Self { - Self { - blobs: blobs::Client { rpc: rpc.clone() }, - docs: docs::Client { rpc: rpc.clone() }, - authors: authors::Client { rpc: rpc.clone() }, - gossip: gossip::Client { rpc: rpc.clone() }, - tags: tags::Client { rpc: rpc.clone() }, - rpc, - } + Self { rpc } + } + + /// Blobs client + pub fn blobs(&self) -> &blobs::Client { + blobs::Client::ref_cast(&self.rpc) + } + + /// Docs client + pub fn docs(&self) -> &docs::Client { + docs::Client::ref_cast(&self.rpc) + } + + /// Authors client + pub fn authors(&self) -> &authors::Client { + authors::Client::ref_cast(&self.rpc) + } + + /// Tags client + pub fn tags(&self) -> &tags::Client { + tags::Client::ref_cast(&self.rpc) + } + + /// Gossip client + pub fn gossip(&self) -> &gossip::Client { + gossip::Client::ref_cast(&self.rpc) } } diff --git a/iroh/src/client/authors.rs b/iroh/src/client/authors.rs index b695b3da7c..bf642fc3d9 100644 --- a/iroh/src/client/authors.rs +++ b/iroh/src/client/authors.rs @@ -4,6 +4,7 @@ use anyhow::Result; use futures_lite::{stream::StreamExt, Stream}; use iroh_docs::{Author, AuthorId}; use quic_rpc::{RpcClient, ServiceConnection}; +use ref_cast::RefCast; use crate::rpc_protocol::{ AuthorCreateRequest, AuthorDeleteRequest, AuthorExportRequest, AuthorGetDefaultRequest, @@ -13,7 +14,8 @@ use crate::rpc_protocol::{ use super::flatten; /// Iroh authors client. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, RefCast)] +#[repr(transparent)] pub struct Client { pub(super) rpc: RpcClient, } @@ -101,33 +103,33 @@ mod tests { let node = Node::memory().spawn().await?; // default author always exists - let authors: Vec<_> = node.authors.list().await?.try_collect().await?; + let authors: Vec<_> = node.authors().list().await?.try_collect().await?; assert_eq!(authors.len(), 1); - let default_author = node.authors.default().await?; + let default_author = node.authors().default().await?; assert_eq!(authors, vec![default_author]); - let author_id = node.authors.create().await?; + let author_id = node.authors().create().await?; - let authors: Vec<_> = node.authors.list().await?.try_collect().await?; + let authors: Vec<_> = node.authors().list().await?.try_collect().await?; assert_eq!(authors.len(), 2); let author = node - .authors + .authors() .export(author_id) .await? .expect("should have author"); - node.authors.delete(author_id).await?; - let authors: Vec<_> = node.authors.list().await?.try_collect().await?; + node.authors().delete(author_id).await?; + let authors: Vec<_> = node.authors().list().await?.try_collect().await?; assert_eq!(authors.len(), 1); - node.authors.import(author).await?; + node.authors().import(author).await?; - let authors: Vec<_> = node.authors.list().await?.try_collect().await?; + let authors: Vec<_> = node.authors().list().await?.try_collect().await?; assert_eq!(authors.len(), 2); - assert!(node.authors.default().await? != author_id); - node.authors.set_default(author_id).await?; - assert_eq!(node.authors.default().await?, author_id); + assert!(node.authors().default().await? 
!= author_id); + node.authors().set_default(author_id).await?; + assert_eq!(node.authors().default().await?, author_id); Ok(()) } diff --git a/iroh/src/client/blobs.rs b/iroh/src/client/blobs.rs index 61d075e7fc..53245acd3d 100644 --- a/iroh/src/client/blobs.rs +++ b/iroh/src/client/blobs.rs @@ -13,10 +13,11 @@ use anyhow::{anyhow, Result}; use bytes::Bytes; use futures_lite::{Stream, StreamExt}; use futures_util::SinkExt; +use genawaiter::sync::{Co, Gen}; use iroh_base::{node_addr::AddrInfoOptions, ticket::BlobTicket}; use iroh_blobs::{ export::ExportProgress as BytesExportProgress, - format::collection::Collection, + format::collection::{Collection, SimpleStore}, get::db::DownloadProgress as BytesDownloadProgress, store::{ConsistencyCheckProgress, ExportFormat, ExportMode, ValidateProgress}, BlobFormat, Hash, Tag, @@ -24,6 +25,7 @@ use iroh_blobs::{ use iroh_net::NodeAddr; use portable_atomic::{AtomicU64, Ordering}; use quic_rpc::{client::BoxStreamSync, RpcClient, ServiceConnection}; +use ref_cast::RefCast; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; use tokio_util::io::{ReaderStream, StreamReader}; @@ -31,23 +33,23 @@ use tracing::warn; use crate::rpc_protocol::{ BlobAddPathRequest, BlobAddStreamRequest, BlobAddStreamUpdate, BlobConsistencyCheckRequest, - BlobDeleteBlobRequest, BlobDownloadRequest, BlobExportRequest, BlobGetCollectionRequest, - BlobGetCollectionResponse, BlobListCollectionsRequest, BlobListIncompleteRequest, + BlobDeleteBlobRequest, BlobDownloadRequest, BlobExportRequest, BlobListIncompleteRequest, BlobListRequest, BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, CreateCollectionResponse, NodeStatusRequest, RpcService, SetTagOption, }; -use super::{flatten, Iroh}; +use super::{flatten, tags, Iroh}; /// Iroh blobs client. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, RefCast)] +#[repr(transparent)] pub struct Client { pub(super) rpc: RpcClient, } impl<'a, C: ServiceConnection> From<&'a Iroh> for &'a RpcClient { fn from(client: &'a Iroh) -> &'a RpcClient { - &client.blobs.rpc + &client.blobs().rpc } } @@ -322,18 +324,35 @@ where /// Read the content of a collection. pub async fn get_collection(&self, hash: Hash) -> Result { - let BlobGetCollectionResponse { collection } = - self.rpc.rpc(BlobGetCollectionRequest { hash }).await??; - Ok(collection) + Collection::load(hash, self).await } /// List all collections. - pub async fn list_collections(&self) -> Result>> { - let stream = self - .rpc - .server_streaming(BlobListCollectionsRequest) - .await?; - Ok(flatten(stream)) + pub fn list_collections(&self) -> Result>> { + let this = self.clone(); + Ok(Gen::new(|co| async move { + if let Err(cause) = this.list_collections_impl(&co).await { + co.yield_(Err(cause)).await; + } + })) + } + + async fn list_collections_impl(&self, co: &Co>) -> Result<()> { + let tags = self.tags_client(); + let mut tags = tags.list_hash_seq().await?; + while let Some(tag) = tags.next().await { + let tag = tag?; + if let Ok(collection) = self.get_collection(tag.hash).await { + let info = CollectionInfo { + tag: tag.name, + hash: tag.hash, + total_blobs_count: Some(collection.len() as u64 + 1), + total_blobs_size: Some(0), + }; + co.yield_(Ok(info)).await; + } + } + Ok(()) } /// Delete a blob. 
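The client-side `list_collections` above leans on `genawaiter` to turn an async loop into a `Stream` without hand-writing a state machine. A self-contained sketch of that pattern follows (illustrative names, not iroh API):

```rust
use futures_lite::StreamExt;
use genawaiter::sync::{Co, Gen};

// The producer yields items through `co`; errors would be yielded
// in-band, just like `list_collections_impl` does above.
async fn produce(co: Co<anyhow::Result<u32>>) {
    for i in 0..3 {
        co.yield_(Ok(i)).await;
    }
}

#[tokio::main]
async fn main() {
    // `Gen` implements `Stream` via genawaiter's `futures03` feature
    // (which iroh enables), so callers can simply iterate it.
    let mut stream = Gen::new(produce);
    while let Some(item) = stream.next().await {
        println!("{item:?}");
    }
}
```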
@@ -366,6 +385,21 @@ where Ok(BlobStatus::Partial { size: reader.size }) } } + + fn tags_client(&self) -> tags::Client { + tags::Client { + rpc: self.rpc.clone(), + } + } +} + +impl SimpleStore for Client +where + C: ServiceConnection, +{ + async fn load(&self, hash: Hash) -> anyhow::Result { + self.read_to_bytes(hash).await + } } /// Whether to wrap the added data in a collection. @@ -904,7 +938,7 @@ mod tests { // import files for path in &paths { let import_outcome = client - .blobs + .blobs() .add_from_path( path.to_path_buf(), false, @@ -925,11 +959,11 @@ mod tests { } let (hash, tag) = client - .blobs + .blobs() .create_collection(collection, SetTagOption::Auto, tags) .await?; - let collections: Vec<_> = client.blobs.list_collections().await?.try_collect().await?; + let collections: Vec<_> = client.blobs().list_collections()?.try_collect().await?; assert_eq!(collections.len(), 1); { @@ -946,7 +980,7 @@ mod tests { } // check that "temp" tags have been deleted - let tags: Vec<_> = client.tags.list().await?.try_collect().await?; + let tags: Vec<_> = client.tags().list().await?.try_collect().await?; assert_eq!(tags.len(), 1); assert_eq!(tags[0].hash, hash); assert_eq!(tags[0].name, tag); @@ -981,7 +1015,7 @@ mod tests { let client = node.client(); let import_outcome = client - .blobs + .blobs() .add_from_path( path.to_path_buf(), false, @@ -997,28 +1031,28 @@ mod tests { let hash = import_outcome.hash; // Read everything - let res = client.blobs.read_to_bytes(hash).await?; + let res = client.blobs().read_to_bytes(hash).await?; assert_eq!(&res, &buf[..]); // Read at smaller than blob_get_chunk_size - let res = client.blobs.read_at_to_bytes(hash, 0, Some(100)).await?; + let res = client.blobs().read_at_to_bytes(hash, 0, Some(100)).await?; assert_eq!(res.len(), 100); assert_eq!(&res[..], &buf[0..100]); - let res = client.blobs.read_at_to_bytes(hash, 20, Some(120)).await?; + let res = client.blobs().read_at_to_bytes(hash, 20, Some(120)).await?; assert_eq!(res.len(), 120); assert_eq!(&res[..], &buf[20..140]); // Read at equal to blob_get_chunk_size let res = client - .blobs + .blobs() .read_at_to_bytes(hash, 0, Some(1024 * 64)) .await?; assert_eq!(res.len(), 1024 * 64); assert_eq!(&res[..], &buf[0..1024 * 64]); let res = client - .blobs + .blobs() .read_at_to_bytes(hash, 20, Some(1024 * 64)) .await?; assert_eq!(res.len(), 1024 * 64); @@ -1026,26 +1060,26 @@ mod tests { // Read at larger than blob_get_chunk_size let res = client - .blobs + .blobs() .read_at_to_bytes(hash, 0, Some(10 + 1024 * 64)) .await?; assert_eq!(res.len(), 10 + 1024 * 64); assert_eq!(&res[..], &buf[0..(10 + 1024 * 64)]); let res = client - .blobs + .blobs() .read_at_to_bytes(hash, 20, Some(10 + 1024 * 64)) .await?; assert_eq!(res.len(), 10 + 1024 * 64); assert_eq!(&res[..], &buf[20..(20 + 10 + 1024 * 64)]); // full length - let res = client.blobs.read_at_to_bytes(hash, 20, None).await?; + let res = client.blobs().read_at_to_bytes(hash, 20, None).await?; assert_eq!(res.len(), 1024 * 128 - 20); assert_eq!(&res[..], &buf[20..]); // size should be total - let reader = client.blobs.read_at(hash, 0, Some(20)).await?; + let reader = client.blobs().read_at(hash, 0, Some(20)).await?; assert_eq!(reader.size(), 1024 * 128); assert_eq!(reader.response_size, 20); @@ -1087,7 +1121,7 @@ mod tests { // import files for path in &paths { let import_outcome = client - .blobs + .blobs() .add_from_path( path.to_path_buf(), false, @@ -1108,11 +1142,11 @@ mod tests { } let (hash, _tag) = client - .blobs + .blobs() .create_collection(collection, 
SetTagOption::Auto, tags) .await?; - let collection = client.blobs.get_collection(hash).await?; + let collection = client.blobs().get_collection(hash).await?; // 5 blobs assert_eq!(collection.len(), 5); @@ -1146,7 +1180,7 @@ mod tests { let client = node.client(); let import_outcome = client - .blobs + .blobs() .add_from_path( path.to_path_buf(), false, @@ -1160,12 +1194,12 @@ mod tests { .context("import finish")?; let ticket = client - .blobs + .blobs() .share(import_outcome.hash, BlobFormat::Raw, Default::default()) .await?; assert_eq!(ticket.hash(), import_outcome.hash); - let status = client.blobs.status(import_outcome.hash).await?; + let status = client.blobs().status(import_outcome.hash).await?; assert_eq!(status, BlobStatus::Complete { size }); Ok(()) diff --git a/iroh/src/client/docs.rs b/iroh/src/client/docs.rs index 2a35233eba..1b900a9463 100644 --- a/iroh/src/client/docs.rs +++ b/iroh/src/client/docs.rs @@ -22,6 +22,7 @@ use iroh_docs::{ use iroh_net::NodeAddr; use portable_atomic::{AtomicBool, Ordering}; use quic_rpc::{message::RpcMsg, RpcClient, ServiceConnection}; +use ref_cast::RefCast; use serde::{Deserialize, Serialize}; use crate::rpc_protocol::{ @@ -38,7 +39,8 @@ pub use iroh_docs::engine::{Origin, SyncEvent, SyncReason}; use super::{blobs, flatten}; /// Iroh docs client. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, RefCast)] +#[repr(transparent)] pub struct Client { pub(super) rpc: RpcClient, } @@ -768,7 +770,7 @@ mod tests { let node = crate::node::Node::memory().spawn().await?; let client = node.client(); - let doc = client.docs.create().await?; + let doc = client.docs().create().await?; let res = std::thread::spawn(move || { drop(doc); @@ -809,8 +811,8 @@ mod tests { // create doc & author let client = node.client(); - let doc = client.docs.create().await.context("doc create")?; - let author = client.authors.create().await.context("author create")?; + let doc = client.docs().create().await.context("doc create")?; + let author = client.authors().create().await.context("author create")?; // import file let import_outcome = doc diff --git a/iroh/src/client/gossip.rs b/iroh/src/client/gossip.rs index fd1c7614b1..0afdfd97e0 100644 --- a/iroh/src/client/gossip.rs +++ b/iroh/src/client/gossip.rs @@ -7,13 +7,15 @@ use futures_util::{Sink, SinkExt}; use iroh_gossip::proto::TopicId; use iroh_net::NodeId; use quic_rpc::{RpcClient, ServiceConnection}; +use ref_cast::RefCast; use crate::rpc_protocol::{GossipSubscribeRequest, GossipSubscribeResponse, GossipSubscribeUpdate}; use super::RpcService; /// Iroh gossip client. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, RefCast)] +#[repr(transparent)] pub struct Client { pub(super) rpc: RpcClient, } diff --git a/iroh/src/client/tags.rs b/iroh/src/client/tags.rs index c2d4309977..9c3ef34f12 100644 --- a/iroh/src/client/tags.rs +++ b/iroh/src/client/tags.rs @@ -4,12 +4,14 @@ use anyhow::Result; use futures_lite::{Stream, StreamExt}; use iroh_blobs::{BlobFormat, Hash, Tag}; use quic_rpc::{RpcClient, ServiceConnection}; +use ref_cast::RefCast; use serde::{Deserialize, Serialize}; use crate::rpc_protocol::{DeleteTagRequest, ListTagsRequest, RpcService}; /// Iroh tags client. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, RefCast)] +#[repr(transparent)] pub struct Client { pub(super) rpc: RpcClient, } @@ -20,7 +22,16 @@ where { /// List all tags. 
pub async fn list(&self) -> Result>> { - let stream = self.rpc.server_streaming(ListTagsRequest).await?; + let stream = self.rpc.server_streaming(ListTagsRequest::all()).await?; + Ok(stream.map(|res| res.map_err(anyhow::Error::from))) + } + + /// List all tags with a hash_seq format. + pub async fn list_hash_seq(&self) -> Result>> { + let stream = self + .rpc + .server_streaming(ListTagsRequest::hash_seq()) + .await?; Ok(stream.map(|res| res.map_err(anyhow::Error::from))) } diff --git a/iroh/src/node.rs b/iroh/src/node.rs index 9bd2c61f53..692b4b034d 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -15,23 +15,26 @@ use iroh_blobs::downloader::Downloader; use iroh_blobs::store::Store as BaoStore; use iroh_docs::engine::Engine; use iroh_gossip::dispatcher::GossipDispatcher; -use iroh_net::util::AbortingJoinHandle; -use iroh_net::{endpoint::LocalEndpointsStream, key::SecretKey, Endpoint}; +use iroh_gossip::net::Gossip; +use iroh_net::key::SecretKey; +use iroh_net::Endpoint; +use iroh_net::{endpoint::DirectAddrsStream, util::SharedAbortingJoinHandle}; use quic_rpc::transport::flume::FlumeConnection; use quic_rpc::RpcClient; -use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tokio_util::task::LocalPoolHandle; use tracing::debug; -use crate::client::RpcService; +use crate::{client::RpcService, node::protocol::ProtocolMap}; mod builder; +mod protocol; mod rpc; mod rpc_status; pub use self::builder::{Builder, DiscoveryConfig, GcPolicy, StorageConfig}; pub use self::rpc_status::RpcStatus; +pub use protocol::ProtocolHandler; /// A server which implements the iroh node. /// @@ -46,24 +49,24 @@ pub use self::rpc_status::RpcStatus; #[derive(Debug, Clone)] pub struct Node { inner: Arc>, - task: Arc>, client: crate::client::MemIroh, + task: SharedAbortingJoinHandle<()>, + protocols: Arc, } #[derive(derive_more::Debug)] struct NodeInner { db: D, + sync: DocsEngine, endpoint: Endpoint, + gossip: Gossip, secret_key: SecretKey, cancel_token: CancellationToken, controller: FlumeConnection, - #[allow(dead_code)] - gc_task: Option>, #[debug("rt")] rt: LocalPoolHandle, - pub(crate) sync: DocsEngine, - gossip: GossipDispatcher, downloader: Downloader, + gossip_dispatcher: GossipDispatcher, } /// In memory node. @@ -109,7 +112,7 @@ impl Node { /// can contact the node consider using [`Node::local_endpoint_addresses`]. However the /// port will always be the concrete port. pub fn local_address(&self) -> Vec { - let (v4, v6) = self.inner.endpoint.local_addr(); + let (v4, v6) = self.inner.endpoint.bound_sockets(); let mut addrs = vec![v4]; if let Some(v6) = v6 { addrs.push(v6); @@ -118,8 +121,8 @@ impl Node { } /// Lists the local endpoint of this node. - pub fn local_endpoints(&self) -> LocalEndpointsStream { - self.inner.endpoint.local_endpoints() + pub fn local_endpoints(&self) -> DirectAddrsStream { + self.inner.endpoint.direct_addresses() } /// Convenience method to get just the addr part of [`Node::local_endpoints`]. @@ -149,23 +152,24 @@ impl Node { /// Get the relay server we are connected to. pub fn my_relay(&self) -> Option { - self.inner.endpoint.my_relay() + self.inner.endpoint.home_relay() } - /// Aborts the node. + /// Shutdown the node. /// /// This does not gracefully terminate currently: all connections are closed and - /// anything in-transit is lost. The task will stop running. - /// If this is the last copy of the `Node`, this will finish once the task is - /// fully shutdown. + /// anything in-transit is lost. 
The shutdown behaviour will become more graceful + /// in the future. /// - /// The shutdown behaviour will become more graceful in the future. + /// Returns a future that completes once all tasks terminated and all resources are closed. + /// The future resolves to an error if the main task panicked. pub async fn shutdown(self) -> Result<()> { + // Trigger shutdown of the main run task by activating the cancel token. self.inner.cancel_token.cancel(); - if let Ok(task) = Arc::try_unwrap(self.task) { - task.await?; - } + // Wait for the main task to terminate. + self.task.await.map_err(|err| anyhow!(err))?; + Ok(()) } @@ -173,6 +177,14 @@ impl Node { pub fn cancel_token(&self) -> CancellationToken { self.inner.cancel_token.clone() } + + /// Returns a protocol handler for an ALPN. + /// + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` + /// does not match the passed type. + pub fn get_protocol(&self, alpn: &[u8]) -> Option> { + self.protocols.get_typed(alpn) + } } impl std::ops::Deref for Node { @@ -187,7 +199,7 @@ impl NodeInner { async fn local_endpoint_addresses(&self) -> Result> { let endpoints = self .endpoint - .local_endpoints() + .direct_addresses() .next() .await .ok_or(anyhow!("no endpoints found"))?; @@ -230,7 +242,7 @@ mod tests { let node = Node::memory().spawn().await.unwrap(); let hash = node .client() - .blobs + .blobs() .add_bytes(Bytes::from_static(b"hello")) .await .unwrap() @@ -238,7 +250,7 @@ mod tests { let _drop_guard = node.cancel_token().drop_guard(); let ticket = node - .blobs + .blobs() .share(hash, BlobFormat::Raw, AddrInfoOptions::RelayAndAddresses) .await .unwrap(); @@ -257,10 +269,13 @@ mod tests { let client = node.client(); let input = vec![2u8; 1024 * 256]; // 265kb so actually streaming, chunk size is 64kb let reader = Cursor::new(input.clone()); - let progress = client.blobs.add_reader(reader, SetTagOption::Auto).await?; + let progress = client + .blobs() + .add_reader(reader, SetTagOption::Auto) + .await?; let outcome = progress.finish().await?; let hash = outcome.hash; - let output = client.blobs.read_to_bytes(hash).await?; + let output = client.blobs().read_to_bytes(hash).await?; assert_eq!(input, output.to_vec()); Ok(()) } @@ -314,13 +329,13 @@ mod tests { let iroh_root = tempfile::TempDir::new()?; { let iroh = Node::persistent(iroh_root.path()).await?.spawn().await?; - let doc = iroh.docs.create().await?; + let doc = iroh.docs().create().await?; drop(doc); iroh.shutdown().await?; } let iroh = Node::persistent(iroh_root.path()).await?.spawn().await?; - let _doc = iroh.docs.create().await?; + let _doc = iroh.docs().create().await?; Ok(()) } @@ -342,14 +357,14 @@ mod tests { .insecure_skip_relay_cert_verify(true) .spawn() .await?; - let AddOutcome { hash, .. } = node1.blobs.add_bytes(b"foo".to_vec()).await?; + let AddOutcome { hash, .. } = node1.blobs().add_bytes(b"foo".to_vec()).await?; // create a node addr with only a relay URL, no direct addresses let addr = NodeAddr::new(node1.node_id()).with_relay_url(relay_url); - node2.blobs.download(hash, addr).await?.await?; + node2.blobs().download(hash, addr).await?.await?; assert_eq!( node2 - .blobs + .blobs() .read_to_bytes(hash) .await .context("get")? 
@@ -385,14 +400,14 @@ mod tests { .node_discovery(dns_pkarr_server.discovery(secret2).into()) .spawn() .await?; - let hash = node1.blobs.add_bytes(b"foo".to_vec()).await?.hash; + let hash = node1.blobs().add_bytes(b"foo".to_vec()).await?.hash; // create a node addr with node id only let addr = NodeAddr::new(node1.node_id()); - node2.blobs.download(hash, addr).await?.await?; + node2.blobs().download(hash, addr).await?.await?; assert_eq!( node2 - .blobs + .blobs() .read_to_bytes(hash) .await .context("get")? @@ -405,13 +420,14 @@ mod tests { #[tokio::test] async fn test_default_author_memory() -> Result<()> { let iroh = Node::memory().spawn().await?; - let author = iroh.authors.default().await?; - assert!(iroh.authors.export(author).await?.is_some()); - assert!(iroh.authors.delete(author).await.is_err()); + let author = iroh.authors().default().await?; + assert!(iroh.authors().export(author).await?.is_some()); + assert!(iroh.authors().delete(author).await.is_err()); Ok(()) } #[cfg(feature = "fs-store")] + #[ignore = "flaky"] #[tokio::test] async fn test_default_author_persist() -> Result<()> { use crate::util::path::IrohPaths; @@ -429,9 +445,9 @@ mod tests { .spawn() .await .unwrap(); - let author = iroh.authors.default().await.unwrap(); - assert!(iroh.authors.export(author).await.unwrap().is_some()); - assert!(iroh.authors.delete(author).await.is_err()); + let author = iroh.authors().default().await.unwrap(); + assert!(iroh.authors().export(author).await.unwrap().is_some()); + assert!(iroh.authors().delete(author).await.is_err()); iroh.shutdown().await.unwrap(); author }; @@ -444,10 +460,10 @@ mod tests { .spawn() .await .unwrap(); - let author = iroh.authors.default().await.unwrap(); + let author = iroh.authors().default().await.unwrap(); assert_eq!(author, default_author); - assert!(iroh.authors.export(author).await.unwrap().is_some()); - assert!(iroh.authors.delete(author).await.is_err()); + assert!(iroh.authors().export(author).await.unwrap().is_some()); + assert!(iroh.authors().delete(author).await.is_err()); iroh.shutdown().await.unwrap(); }; @@ -463,10 +479,10 @@ mod tests { .spawn() .await .unwrap(); - let author = iroh.authors.default().await.unwrap(); + let author = iroh.authors().default().await.unwrap(); assert!(author != default_author); - assert!(iroh.authors.export(author).await.unwrap().is_some()); - assert!(iroh.authors.delete(author).await.is_err()); + assert!(iroh.authors().export(author).await.unwrap().is_some()); + assert!(iroh.authors().delete(author).await.is_err()); iroh.shutdown().await.unwrap(); author }; @@ -506,9 +522,9 @@ mod tests { .spawn() .await .unwrap(); - let author = iroh.authors.create().await.unwrap(); - iroh.authors.set_default(author).await.unwrap(); - assert_eq!(iroh.authors.default().await.unwrap(), author); + let author = iroh.authors().create().await.unwrap(); + iroh.authors().set_default(author).await.unwrap(); + assert_eq!(iroh.authors().default().await.unwrap(), author); iroh.shutdown().await.unwrap(); author }; @@ -519,7 +535,7 @@ mod tests { .spawn() .await .unwrap(); - assert_eq!(iroh.authors.default().await.unwrap(), default_author); + assert_eq!(iroh.authors().default().await.unwrap(), default_author); iroh.shutdown().await.unwrap(); } diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index 15be1313f4..b884ec0ef6 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -6,7 +6,7 @@ use std::{ time::Duration, }; -use anyhow::{bail, Context, Result}; +use anyhow::{Context, Result}; use futures_lite::StreamExt; 
use iroh_base::key::SecretKey; use iroh_blobs::{ @@ -27,23 +27,28 @@ use iroh_net::{ Endpoint, }; use quic_rpc::{ - transport::{misc::DummyServerEndpoint, quinn::QuinnServerEndpoint}, + transport::{ + flume::FlumeServerEndpoint, misc::DummyServerEndpoint, quinn::QuinnServerEndpoint, + }, RpcServer, ServiceEndpoint, }; use serde::{Deserialize, Serialize}; +use tokio::task::JoinSet; use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; use tracing::{debug, error, error_span, info, trace, warn, Instrument}; use crate::{ client::RPC_ALPN, + node::{ + protocol::{BlobsProtocol, ProtocolMap}, + ProtocolHandler, + }, rpc_protocol::RpcService, util::{fs::load_secret_key, path::IrohPaths}, }; use super::{rpc, rpc_status::RpcStatus, DocsEngine, Node, NodeInner}; -pub const PROTOCOLS: [&[u8]; 3] = [iroh_blobs::protocol::ALPN, GOSSIP_ALPN, DOCS_ALPN]; - /// Default bind address for the node. /// 11204 is "iroh" in leetspeak pub const DEFAULT_BIND_PORT: u16 = 11204; @@ -86,7 +91,7 @@ where gc_policy: GcPolicy, dns_resolver: Option<DnsResolver>, node_discovery: DiscoveryConfig, - docs_store: iroh_docs::store::fs::Store, + docs_store: iroh_docs::store::Store, #[cfg(any(test, feature = "test-utils"))] insecure_skip_relay_cert_verify: bool, /// Callback to register when a gc loop is done @@ -186,7 +191,9 @@ where tokio::fs::create_dir_all(&blob_dir).await?; let blobs_store = iroh_blobs::store::fs::Store::load(&blob_dir) .await - .with_context(|| format!("Failed to load iroh database from {}", blob_dir.display()))?; + .with_context(|| { + format!("Failed to load blobs database from {}", blob_dir.display()) + })?; let docs_store = iroh_docs::store::fs::Store::persistent(IrohPaths::DocsDatabase.with_root(root))?; @@ -369,20 +376,28 @@ where /// connections. The returned [`Node`] can be used to control the task as well as /// get information about it. pub async fn spawn(self) -> Result<Node<D>> { - // We clone the blob store to shut it down in case the node fails to spawn. + let unspawned_node = self.build().await?; + unspawned_node.spawn().await + } + + /// Build a node without spawning it. + /// + /// Returns a [`ProtocolBuilder`], on which custom protocols can be registered with + /// [`ProtocolBuilder::accept`]. To spawn the node, call [`ProtocolBuilder::spawn`]. + pub async fn build(self) -> Result<ProtocolBuilder<D, E>> { + // Clone the blob store so we can shut it down in case of error. 
let blobs_store = self.blobs_store.clone(); - match self.spawn_inner().await { + match self.build_inner().await { Ok(node) => Ok(node), Err(err) => { - debug!("failed to spawn node, shutting down"); blobs_store.shutdown().await; Err(err) } } } - async fn spawn_inner(mut self) -> Result<Node<D>> { - trace!("spawning node"); + async fn build_inner(self) -> Result<ProtocolBuilder<D, E>> { + trace!("building node"); let lp = LocalPoolHandle::new(num_cpus::get()); let mut transport_config = quinn::TransportConfig::default(); @@ -407,7 +422,6 @@ where let endpoint = Endpoint::builder() .secret_key(self.secret_key.clone()) .proxy_from_env() - .alpns(PROTOCOLS.iter().map(|p| p.to_vec()).collect()) .keylog(self.keylog) .transport_config(transport_config) .concurrent_connections(MAX_CONNECTIONS) @@ -438,9 +452,7 @@ where let cancel_token = CancellationToken::new(); - debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); - - let addr = endpoint.my_addr().await?; + let addr = endpoint.node_addr().await?; // initialize the gossip protocol let gossip = Gossip::from_endpoint(endpoint.clone(), Default::default(), &addr.info); @@ -468,96 +480,56 @@ where ) .await?; let gossip_dispatcher = GossipDispatcher::new(gossip.clone()); - let sync_db = sync.sync.clone(); let sync = DocsEngine(sync); - let gc_task = if let GcPolicy::Interval(gc_period) = self.gc_policy { - tracing::info!("Starting GC task with interval {:?}", gc_period); - let db = self.blobs_store.clone(); - let gc_done_callback = self.gc_done_callback.take(); - - let task = - lp.spawn_pinned(move || Self::gc_loop(db, sync_db, gc_period, gc_done_callback)); - Some(task.into()) - } else { - None - }; + // Initialize the internal RPC connection. let (internal_rpc, controller) = quic_rpc::transport::flume::connection(1); let client = crate::client::Iroh::new(quic_rpc::RpcClient::new(controller.clone())); + debug!("rpc listening on: {:?}", self.rpc_endpoint.local_addr()); let inner = Arc::new(NodeInner { - db: self.blobs_store, + db: self.blobs_store.clone(), + sync, endpoint: endpoint.clone(), secret_key: self.secret_key, controller, cancel_token, - gc_task, - rt: lp.clone(), - sync, - gossip: gossip_dispatcher, + rt: lp, downloader, + gossip, + gossip_dispatcher, }); - let task = { - let gossip = gossip.clone(); - let handler = rpc::Handler { - inner: inner.clone(), - }; - let me = endpoint.node_id().fmt_short(); - let ep = endpoint.clone(); - tokio::task::spawn( - async move { - Self::run( - ep, - handler, - self.rpc_endpoint, - internal_rpc, - gossip, - ) - .await - } - .instrument(error_span!("node", %me)), - ) - }; - let node = Node { + let node = ProtocolBuilder { inner, - task: Arc::new(task), client, + protocols: Default::default(), + internal_rpc, + gc_policy: self.gc_policy, + gc_done_callback: self.gc_done_callback, + rpc_endpoint: self.rpc_endpoint, }; - // spawn a task that updates the gossip endpoints. - // TODO: track task - let mut stream = endpoint.local_endpoints(); - tokio::task::spawn(async move { - while let Some(eps) = stream.next().await { - if let Err(err) = gossip.update_endpoints(&eps) { - warn!("Failed to update gossip endpoints: {err:?}"); - } - } - warn!("failed to retrieve local endpoints"); - }); - - // Wait for a single endpoint update, to make sure - // we found some endpoints - tokio::time::timeout(ENDPOINT_WAIT, endpoint.local_endpoints().next()) - .await - .context("waiting for endpoint")? 
- .context("no endpoints")?; + let node = node.register_iroh_protocols(); Ok(node) } - #[allow(clippy::too_many_arguments)] async fn run( - server: Endpoint, - handler: rpc::Handler, + inner: Arc>, rpc: E, internal_rpc: impl ServiceEndpoint, - gossip: Gossip, + protocols: Arc, + mut join_set: JoinSet>, ) { + let endpoint = inner.endpoint.clone(); + + let handler = rpc::Handler { + inner: inner.clone(), + }; let rpc = RpcServer::new(rpc); let internal_rpc = RpcServer::new(internal_rpc); - let (ipv4, ipv6) = server.local_addr(); + let (ipv4, ipv6) = endpoint.bound_sockets(); debug!( "listening at: {}{}", ipv4, @@ -566,24 +538,19 @@ where let cancel_token = handler.inner.cancel_token.clone(); - // forward our initial endpoints to the gossip protocol + // forward the initial endpoints to the gossip protocol. // it may happen the the first endpoint update callback is missed because the gossip cell // is only initialized once the endpoint is fully bound - if let Some(local_endpoints) = server.local_endpoints().next().await { - debug!(me = ?server.node_id(), "gossip initial update: {local_endpoints:?}"); - gossip.update_endpoints(&local_endpoints).ok(); + if let Some(direct_addresses) = endpoint.direct_addresses().next().await { + debug!(me = ?endpoint.node_id(), "gossip initial update: {direct_addresses:?}"); + inner.gossip.update_direct_addresses(&direct_addresses).ok(); } + loop { tokio::select! { biased; _ = cancel_token.cancelled() => { - // clean shutdown of the blobs db to close the write transaction - handler.inner.db.shutdown().await; - - if let Err(err) = handler.inner.sync.shutdown().await { - warn!("sync shutdown error: {:?}", err); - } - break + break; }, // handle rpc requests. This will do nothing if rpc is not configured, since // accept is just a pending future. @@ -608,42 +575,49 @@ where } } }, - // handle incoming p2p connections - Some(mut connecting) = server.accept() => { - let alpn = match connecting.alpn().await { - Ok(alpn) => alpn, - Err(err) => { - error!("invalid handshake: {:?}", err); - continue; - } - }; - let gossip = gossip.clone(); - let inner = handler.inner.clone(); - let sync = handler.inner.sync.clone(); - tokio::task::spawn(async move { - if let Err(err) = handle_connection(connecting, alpn, inner, gossip, sync).await { - warn!("Handling incoming connection ended with error: {err}"); - } + // handle incoming p2p connections. + Some(connecting) = endpoint.accept() => { + let protocols = protocols.clone(); + join_set.spawn(async move { + handle_connection(connecting, protocols).await; + Ok(()) }); }, + // handle task terminations and quit on panics. + res = join_set.join_next(), if !join_set.is_empty() => { + if let Some(Err(err)) = res { + error!("Task failed: {err:?}"); + break; + } + }, else => break, } } - // Closing the Endpoint is the equivalent of calling Connection::close on all - // connections: Operations will immediately fail with - // ConnectionError::LocallyClosed. All streams are interrupted, this is not - // graceful. + // Shutdown the different parts of the node concurrently. let error_code = Closed::ProviderTerminating; - server - .close(error_code.into(), error_code.reason()) - .await - .ok(); + // We ignore all errors during shutdown. + let _ = tokio::join!( + // Close the endpoint. + // Closing the Endpoint is the equivalent of calling Connection::close on all + // connections: Operations will immediately fail with ConnectionError::LocallyClosed. + // All streams are interrupted, this is not graceful. 
+ endpoint.close(error_code.into(), error_code.reason()), + // Shut down the sync engine. + inner.sync.shutdown(), + // Shut down the blobs store engine. + inner.db.shutdown(), + // Shut down the protocol handlers. + protocols.shutdown(), + ); + + // Abort remaining tasks. + join_set.shutdown().await; } async fn gc_loop( db: D, - ds: iroh_docs::actor::SyncHandle, + ds: DocsEngine, gc_period: Duration, done_cb: Option<Box<dyn Fn() + Send>>, ) { @@ -660,7 +634,8 @@ where tokio::time::sleep(gc_period).await; tracing::debug!("Starting GC"); live.clear(); - let doc_hashes = match ds.content_hashes().await { + + let doc_hashes = match ds.sync.content_hashes().await { Ok(hashes) => hashes, Err(err) => { tracing::warn!("Error getting doc hashes: {}", err); @@ -720,6 +695,237 @@ where } } +/// A node that is initialized but not yet spawned. +/// +/// This is returned from [`Builder::build`] and may be used to register custom protocols with +/// [`Self::accept`]. It provides access to the services which are already started, the node's +/// endpoint and a client to the node. +/// +/// Note that RPC calls performed with the client returned from [`Self::client`] will not complete +/// until the node is spawned. +#[derive(derive_more::Debug)] +pub struct ProtocolBuilder<D, E> { + inner: Arc<NodeInner<D>>, + client: crate::client::MemIroh, + internal_rpc: FlumeServerEndpoint<RpcService>, + rpc_endpoint: E, + protocols: ProtocolMap, + #[debug("callback")] + gc_done_callback: Option<Box<dyn Fn() + Send>>, + gc_policy: GcPolicy, +} + +impl<D: BaoStore, E: ServiceEndpoint<RpcService>> ProtocolBuilder<D, E> { + /// Register a protocol handler for incoming connections. + /// + /// Use this to register custom protocols onto the iroh node. Whenever a new connection for + /// `alpn` comes in, it is passed to this protocol handler. + /// + /// See the [`ProtocolHandler`] trait for details. + /// + /// Example usage: + /// + /// ```rust + /// # use std::sync::Arc; + /// # use anyhow::Result; + /// # use futures_lite::future::Boxed as BoxedFuture; + /// # use iroh::{node::{Node, ProtocolHandler}, net::endpoint::Connecting, client::MemIroh}; + /// # + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// + /// const MY_ALPN: &[u8] = b"my-protocol/1"; + /// + /// #[derive(Debug)] + /// struct MyProtocol { + ///     client: MemIroh + /// } + /// + /// impl ProtocolHandler for MyProtocol { + ///     fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> { + ///         todo!(); + ///     } + /// } + /// + /// let unspawned_node = Node::memory() + ///     .build() + ///     .await?; + /// + /// let client = unspawned_node.client().clone(); + /// let handler = MyProtocol { client }; + /// + /// let node = unspawned_node + ///     .accept(MY_ALPN, Arc::new(handler)) + ///     .spawn() + ///     .await?; + /// # node.shutdown().await?; + /// # Ok(()) + /// # } + /// ``` + pub fn accept(mut self, alpn: &'static [u8], handler: Arc<dyn ProtocolHandler>) -> Self { + self.protocols.insert(alpn, handler); + self + } + + /// Return a client to control this node over an in-memory channel. + /// + /// Note that RPC calls performed with the client will not complete until the node is + /// spawned. + pub fn client(&self) -> &crate::client::MemIroh { + &self.client + } + + /// Returns the [`Endpoint`] of the node. + pub fn endpoint(&self) -> &Endpoint { + &self.inner.endpoint + } + + /// Returns the [`crate::blobs::store::Store`] used by the node. + pub fn blobs_db(&self) -> &D { + &self.inner.db + } + + /// Returns a reference to the used [`LocalPoolHandle`]. + pub fn local_pool_handle(&self) -> &LocalPoolHandle { + &self.inner.rt + } + + /// Returns a reference to the [`Downloader`] used by the node. 
+ pub fn downloader(&self) -> &Downloader { + &self.inner.downloader + } + + /// Returns a reference to the [`Gossip`] handle used by the node. + pub fn gossip(&self) -> &Gossip { + &self.inner.gossip + } + + /// Returns a protocol handler for an ALPN. + /// + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` + /// does not match the passed type. + pub fn get_protocol<P: ProtocolHandler>(&self, alpn: &[u8]) -> Option<Arc<P>> { + self.protocols.get_typed(alpn) + } + + /// Register the core iroh protocols (blobs, gossip, docs). + fn register_iroh_protocols(mut self) -> Self { + // Register blobs. + let blobs_proto = + BlobsProtocol::new(self.blobs_db().clone(), self.local_pool_handle().clone()); + self = self.accept(iroh_blobs::protocol::ALPN, Arc::new(blobs_proto)); + + // Register gossip. + let gossip = self.gossip().clone(); + self = self.accept(GOSSIP_ALPN, Arc::new(gossip)); + + // Register docs. + let docs = self.inner.sync.clone(); + self = self.accept(DOCS_ALPN, Arc::new(docs)); + + self + } + + /// Spawn the node and start accepting connections. + pub async fn spawn(self) -> Result<Node<D>> { + let Self { + inner, + client, + internal_rpc, + rpc_endpoint, + protocols, + gc_done_callback, + gc_policy, + } = self; + let protocols = Arc::new(protocols); + let protocols_clone = protocols.clone(); + + // Create the actual spawn future in an async block so that we can shut down the protocols in case of + // error. + let node_fut = async move { + let mut join_set = JoinSet::new(); + + // Spawn a task for the garbage collection. + if let GcPolicy::Interval(gc_period) = gc_policy { + tracing::info!("Starting GC task with interval {:?}", gc_period); + let lp = inner.rt.clone(); + let docs = inner.sync.clone(); + let blobs_store = inner.db.clone(); + let handle = lp.spawn_pinned(move || { + Builder::<D, E>::gc_loop(blobs_store, docs, gc_period, gc_done_callback) + }); + // We cannot spawn tasks that run on the local pool directly into the join set, + // so instead we create a new task that supervises the local task. + join_set.spawn(async move { + if let Err(err) = handle.await { + return Err(anyhow::Error::from(err)); + } + Ok(()) + }); + } + + // Spawn a task that updates the gossip endpoints. + let mut stream = inner.endpoint.direct_addresses(); + let gossip = inner.gossip.clone(); + join_set.spawn(async move { + while let Some(eps) = stream.next().await { + if let Err(err) = gossip.update_direct_addresses(&eps) { + warn!("Failed to update direct addresses for gossip: {err:?}"); + } + } + warn!("failed to retrieve local endpoints"); + Ok(()) + }); + + // Update the endpoint with our ALPNs. + let alpns = protocols + .alpns() + .map(|alpn| alpn.to_vec()) + .collect::<Vec<_>>(); + inner.endpoint.set_alpns(alpns)?; + + // Spawn the main task and store it in the node for structured termination in shutdown. + let task = tokio::task::spawn( + Builder::run( + inner.clone(), + rpc_endpoint, + internal_rpc, + protocols.clone(), + join_set, + ) + .instrument(error_span!("node", me=%inner.endpoint.node_id().fmt_short())), + ); + + let node = Node { + inner, + client, + protocols, + task: task.into(), + }; + + // Wait for a single endpoint update, to make sure + // we found some endpoints + tokio::time::timeout(ENDPOINT_WAIT, node.endpoint().direct_addresses().next()) + .await + .context("waiting for endpoint")? + .context("no endpoints")?; + + Ok(node) + }; + + match node_fut.await { + Ok(node) => Ok(node), + Err(err) => { + // Shut down the protocols in case of error. 
+ protocols_clone.shutdown().await; + Err(err) + } + } + } +} + /// Policy for garbage collection. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum GcPolicy { @@ -735,31 +941,24 @@ impl Default for GcPolicy { } } -// TODO: Restructure this code to not take all these arguments. -#[allow(clippy::too_many_arguments)] -async fn handle_connection( - connecting: iroh_net::endpoint::Connecting, - alpn: String, - node: Arc<NodeInner<D>>, - gossip: Gossip, - sync: DocsEngine, -) -> Result<()> { - match alpn.as_bytes() { - GOSSIP_ALPN => gossip.handle_connection(connecting.await?).await?, - DOCS_ALPN => sync.handle_connection(connecting).await?, - alpn if alpn == iroh_blobs::protocol::ALPN => { - let connection = connecting.await?; - iroh_blobs::provider::handle_connection( - connection, - node.db.clone(), - MockEventSender, - node.rt.clone(), - ) - .await +async fn handle_connection( + mut connecting: iroh_net::endpoint::Connecting, + protocols: Arc<ProtocolMap>, +) { + let alpn = match connecting.alpn().await { + Ok(alpn) => alpn, + Err(err) => { + warn!("Ignoring connection: invalid handshake: {:?}", err); + return; } - _ => bail!("ignoring connection: unsupported ALPN protocol"), + }; + let Some(handler) = protocols.get(&alpn) else { + warn!("Ignoring connection: unsupported ALPN protocol"); + return; + }; + if let Err(err) = handler.accept(connecting).await { + warn!("Handling incoming connection ended with error: {err}"); } - Ok(()) } const DEFAULT_RPC_PORT: u16 = 0x1337; @@ -779,7 +978,7 @@ fn make_rpc_endpoint( let mut server_config = iroh_net::endpoint::make_server_config( secret_key, vec![RPC_ALPN.to_vec()], - Some(transport_config), + Arc::new(transport_config), false, )?; server_config.concurrent_connections(MAX_RPC_CONNECTIONS); @@ -809,12 +1008,3 @@ fn make_rpc_endpoint( Ok((rpc_endpoint, actual_rpc_port)) } - -#[derive(Debug, Clone)] -struct MockEventSender; - -impl iroh_blobs::provider::EventSender for MockEventSender { - fn send(&self, _event: iroh_blobs::provider::Event) -> futures_lite::future::Boxed<()> { - Box::pin(std::future::ready(())) - } -} diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs new file mode 100644 index 0000000000..25106e7c38 --- /dev/null +++ b/iroh/src/node/protocol.rs @@ -0,0 +1,127 @@ +use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; + +use anyhow::Result; +use futures_lite::future::Boxed as BoxedFuture; +use futures_util::future::join_all; +use iroh_net::endpoint::Connecting; + +use crate::node::DocsEngine; + +/// Handler for incoming connections. +/// +/// An iroh node can accept connections for arbitrary ALPN protocols. By default, the iroh node +/// only accepts connections for the ALPNs of the core iroh protocols (blobs, gossip, docs). +/// +/// With this trait, you can handle incoming connections for custom protocols. +/// +/// Implement this trait on a struct that should handle incoming connections. +/// The protocol handler must then be registered on the node for an ALPN protocol with +/// [`crate::node::builder::ProtocolBuilder::accept`]. +pub trait ProtocolHandler: Send + Sync + IntoArcAny + fmt::Debug + 'static { + /// Handle an incoming connection. + /// + /// This runs on a freshly spawned tokio task, so it can be long-running. + fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>>; + + /// Called when the node shuts down. + fn shutdown(self: Arc<Self>) -> BoxedFuture<()> { + Box::pin(async move {}) + } +} + +/// Helper trait to facilitate casting from `Arc<Self>` to `Arc<dyn Any>`. 
+/// +/// This trait has a blanket implementation so there is no need to implement this yourself. +pub trait IntoArcAny { + fn into_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>; +} + +impl<T: Send + Sync + 'static> IntoArcAny for T { + fn into_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> { + self + } +} + +#[derive(Debug, Clone, Default)] +pub(super) struct ProtocolMap(BTreeMap<&'static [u8], Arc<dyn ProtocolHandler>>); + +impl ProtocolMap { + /// Returns the registered protocol handler for an ALPN as a concrete type. + pub fn get_typed<P: ProtocolHandler>(&self, alpn: &[u8]) -> Option<Arc<P>> { + let protocol: Arc<dyn ProtocolHandler> = self.0.get(alpn)?.clone(); + let protocol_any: Arc<dyn Any + Send + Sync> = protocol.into_arc_any(); + let protocol_ref = Arc::downcast(protocol_any).ok()?; + Some(protocol_ref) + } + + /// Returns the registered protocol handler for an ALPN as an [`Arc<dyn ProtocolHandler>`]. + pub fn get(&self, alpn: &[u8]) -> Option<Arc<dyn ProtocolHandler>> { + self.0.get(alpn).cloned() + } + + /// Insert a protocol handler. + pub fn insert(&mut self, alpn: &'static [u8], handler: Arc<dyn ProtocolHandler>) { + self.0.insert(alpn, handler); + } + + /// Returns an iterator of all registered ALPN protocol identifiers. + pub fn alpns(&self) -> impl Iterator<Item = &&'static [u8]> { + self.0.keys() + } + + /// Shut down all protocol handlers. + /// + /// Calls and awaits [`ProtocolHandler::shutdown`] for all registered handlers concurrently. + pub async fn shutdown(&self) { + let handlers = self.0.values().cloned().map(ProtocolHandler::shutdown); + join_all(handlers).await; + } +} + +#[derive(Debug)] +pub(crate) struct BlobsProtocol<S> { + rt: tokio_util::task::LocalPoolHandle, + store: S, +} + +impl<S: iroh_blobs::store::Store> BlobsProtocol<S> { + pub fn new(store: S, rt: tokio_util::task::LocalPoolHandle) -> Self { + Self { rt, store } + } +} + +impl<S: iroh_blobs::store::Store> ProtocolHandler for BlobsProtocol<S> { + fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> { + Box::pin(async move { + iroh_blobs::provider::handle_connection( + conn.await?, + self.store.clone(), + MockEventSender, + self.rt.clone(), + ) + .await; + Ok(()) + }) + } +} + +#[derive(Debug, Clone)] +struct MockEventSender; + +impl iroh_blobs::provider::EventSender for MockEventSender { + fn send(&self, _event: iroh_blobs::provider::Event) -> futures_lite::future::Boxed<()> { + Box::pin(std::future::ready(())) + } +} + +impl ProtocolHandler for iroh_gossip::net::Gossip { + fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> { + Box::pin(async move { self.handle_connection(conn.await?).await }) + } +} + +impl ProtocolHandler for DocsEngine { + fn accept(self: Arc<Self>, conn: Connecting) -> BoxedFuture<Result<()>> { + Box::pin(async move { self.handle_connection(conn).await }) + } +} diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index 2dc56078ce..51b096c37f 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -17,7 +17,6 @@ use iroh_blobs::store::{ConsistencyCheckProgress, ExportFormat, ImportProgress, use iroh_blobs::util::progress::ProgressSender; use iroh_blobs::BlobFormat; use iroh_blobs::{ - hashseq::parse_hash_seq, provider::AddProgress, store::{Store as BaoStore, ValidateProgress}, util::progress::FlumeProgressSender, @@ -33,16 +32,13 @@ use quic_rpc::{ use tokio_util::task::LocalPoolHandle; use tracing::{debug, info}; -use crate::client::blobs::{ - BlobInfo, CollectionInfo, DownloadMode, IncompleteBlobInfo, WrapOption, -}; +use crate::client::blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}; use crate::client::tags::TagInfo; use crate::client::NodeStatus; use crate::rpc_protocol::{ BlobAddPathRequest, BlobAddPathResponse, BlobAddStreamRequest, BlobAddStreamResponse, BlobAddStreamUpdate, BlobConsistencyCheckRequest, BlobDeleteBlobRequest, BlobDownloadRequest, - 
BlobDownloadResponse, BlobExportRequest, BlobExportResponse, BlobGetCollectionRequest, - BlobGetCollectionResponse, BlobListCollectionsRequest, BlobListIncompleteRequest, + BlobDownloadResponse, BlobExportRequest, BlobExportResponse, BlobListIncompleteRequest, BlobListRequest, BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest, DocExportFileResponse, DocImportFileRequest, DocImportFileResponse, DocSetHashRequest, @@ -61,6 +57,8 @@ const HEALTH_POLL_WAIT: Duration = Duration::from_secs(1); const RPC_BLOB_GET_CHUNK_SIZE: usize = 1024 * 64; /// Channel cap for getting blobs over RPC const RPC_BLOB_GET_CHANNEL_CAP: usize = 2; +/// Name used for logging when new node addresses are added from gossip. +const BLOB_DOWNLOAD_SOURCE_NAME: &str = "blob_download"; #[derive(Debug, Clone)] pub(crate) struct Handler { @@ -95,12 +93,7 @@ impl Handler { chan.server_streaming(msg, handler, Self::blob_list_incomplete) .await } - BlobListCollections(msg) => { - chan.server_streaming(msg, handler, Self::blob_list_collections) - .await - } CreateCollection(msg) => chan.rpc(msg, handler, Self::create_collection).await, - BlobGetCollection(msg) => chan.rpc(msg, handler, Self::blob_get_collection).await, ListTags(msg) => { chan.server_streaming(msg, handler, Self::blob_list_tags) .await @@ -300,7 +293,7 @@ impl Handler { } GossipSubscribe(msg) => { chan.bidi_streaming(msg, handler, |handler, req, updates| { - handler.inner.gossip.subscribe_with_opts( + handler.inner.gossip_dispatcher.subscribe_with_opts( req.topic, iroh_gossip::dispatcher::SubscribeOptions { bootstrap: req.bootstrap, @@ -362,39 +355,6 @@ impl Handler { Ok(()) } - async fn blob_list_collections_impl( - self, - co: &Co>, - ) -> anyhow::Result<()> { - let db = self.inner.db.clone(); - let local = self.inner.rt.clone(); - let tags = db.tags().await.unwrap(); - for item in tags { - let (name, HashAndFormat { hash, format }) = item?; - if !format.is_hash_seq() { - continue; - } - let Some(entry) = db.get(&hash).await? 
else { - continue; - }; - let count = local - .spawn_pinned(|| async move { - let reader = entry.data_reader().await?; - let (_collection, count) = parse_hash_seq(reader).await?; - anyhow::Ok(count) - }) - .await??; - co.yield_(Ok(CollectionInfo { - tag: name, - hash, - total_blobs_count: Some(count), - total_blobs_size: None, - })) - .await; - } - Ok(()) - } - fn blob_list( self, _msg: BlobListRequest, @@ -417,17 +377,6 @@ impl Handler { }) } - fn blob_list_collections( - self, - _msg: BlobListCollectionsRequest, - ) -> impl Stream> + Send + 'static { - Gen::new(move |co| async move { - if let Err(e) = self.blob_list_collections_impl(&co).await { - co.yield_(Err(e.into())).await; - } - }) - } - async fn blob_delete_tag(self, msg: DeleteTagRequest) -> RpcResult<()> { self.inner.db.set_tag(msg.name, None).await?; Ok(()) @@ -438,15 +387,16 @@ impl Handler { Ok(()) } - fn blob_list_tags(self, _msg: ListTagsRequest) -> impl Stream + Send + 'static { + fn blob_list_tags(self, msg: ListTagsRequest) -> impl Stream + Send + 'static { tracing::info!("blob_list_tags"); Gen::new(|co| async move { let tags = self.inner.db.tags().await.unwrap(); #[allow(clippy::manual_flatten)] for item in tags { if let Ok((name, HashAndFormat { hash, format })) = item { - tracing::info!("{:?} {} {:?}", name, hash, format); - co.yield_(TagInfo { name, hash, format }).await; + if (format.is_raw() && msg.raw) || (format.is_hash_seq() && msg.hash_seq) { + co.yield_(TagInfo { name, hash, format }).await; + } } } }) @@ -807,7 +757,7 @@ impl Handler { async fn node_status(self, _: NodeStatusRequest) -> RpcResult { Ok(NodeStatus { - addr: self.inner.endpoint.my_addr().await?, + addr: self.inner.endpoint.node_addr().await?, listen_addrs: self .inner .local_endpoint_addresses() @@ -823,13 +773,13 @@ impl Handler { } async fn node_addr(self, _: NodeAddrRequest) -> RpcResult { - let addr = self.inner.endpoint.my_addr().await?; + let addr = self.inner.endpoint.node_addr().await?; Ok(addr) } #[allow(clippy::unused_async)] async fn node_relay(self, _: NodeRelayRequest) -> RpcResult> { - Ok(self.inner.endpoint.my_relay()) + Ok(self.inner.endpoint.home_relay()) } #[allow(clippy::unused_async)] @@ -1058,21 +1008,6 @@ impl Handler { Ok(CreateCollectionResponse { hash, tag }) } - - async fn blob_get_collection( - self, - req: BlobGetCollectionRequest, - ) -> RpcResult { - let hash = req.hash; - let db = self.inner.db.clone(); - let collection = self - .rt() - .spawn_pinned(move || async move { Collection::load(&db, &hash).await }) - .await - .map_err(|_| anyhow!("join failed"))??; - - Ok(BlobGetCollectionResponse { collection }) - } } async fn download( @@ -1093,6 +1028,7 @@ where mode, } = req; let hash_and_format = HashAndFormat { hash, format }; + let temp_tag = db.temp_tag(hash_and_format); let stats = match mode { DownloadMode::Queued => { download_queued( @@ -1100,18 +1036,26 @@ where downloader, hash_and_format, nodes, - tag, progress.clone(), ) .await? } DownloadMode::Direct => { - download_direct_from_nodes(db, endpoint, hash_and_format, nodes, tag, progress.clone()) + download_direct_from_nodes(db, endpoint, hash_and_format, nodes, progress.clone()) .await? 
} }; progress.send(DownloadProgress::AllDone(stats)).await.ok(); + match tag { + SetTagOption::Named(tag) => { + db.set_tag(tag, Some(hash_and_format)).await?; + } + SetTagOption::Auto => { + db.create_tag(hash_and_format).await?; + } + } + drop(temp_tag); Ok(()) } @@ -1121,17 +1065,20 @@ async fn download_queued( downloader: &Downloader, hash_and_format: HashAndFormat, nodes: Vec, - tag: SetTagOption, progress: FlumeProgressSender, ) -> Result { let mut node_ids = Vec::with_capacity(nodes.len()); + let mut any_added = false; for node in nodes { node_ids.push(node.node_id); - endpoint.add_node_addr(node)?; + if !node.info.is_empty() { + endpoint.add_node_addr_with_source(node, BLOB_DOWNLOAD_SOURCE_NAME)?; + any_added = true; + } } - let req = DownloadRequest::new(hash_and_format, node_ids) - .progress_sender(progress) - .tag(tag); + let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some()); + anyhow::ensure!(can_download, "no way to reach a node for download"); + let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress); let handle = downloader.queue(req).await; let stats = handle.await?; Ok(stats) @@ -1142,7 +1089,6 @@ async fn download_direct_from_nodes( endpoint: Endpoint, hash_and_format: HashAndFormat, nodes: Vec, - tag: SetTagOption, progress: FlumeProgressSender, ) -> Result where @@ -1157,7 +1103,6 @@ where endpoint.clone(), hash_and_format, node, - tag.clone(), progress.clone(), ) .await @@ -1177,13 +1122,11 @@ async fn download_direct( endpoint: Endpoint, hash_and_format: HashAndFormat, node: NodeAddr, - tag: SetTagOption, progress: FlumeProgressSender, ) -> Result where D: BaoStore, { - let temp_pin = db.temp_tag(hash_and_format); let get_conn = { let progress = progress.clone(); move || async move { @@ -1195,18 +1138,5 @@ where let res = iroh_blobs::get::db::get_to_db(db, get_conn, &hash_and_format, progress).await; - if res.is_ok() { - match tag { - SetTagOption::Named(tag) => { - db.set_tag(tag, Some(hash_and_format)).await?; - } - SetTagOption::Auto => { - db.create_tag(hash_and_format).await?; - } - } - } - - drop(temp_pin); - res.map_err(Into::into) } diff --git a/iroh/src/node/rpc/docs.rs b/iroh/src/node/rpc/docs.rs index a0433a803e..4fbabf64ff 100644 --- a/iroh/src/node/rpc/docs.rs +++ b/iroh/src/node/rpc/docs.rs @@ -146,7 +146,7 @@ impl DocsEngine { mode, addr_options, } = req; - let mut me = self.endpoint.my_addr().await?; + let mut me = self.endpoint.node_addr().await?; me.apply_options(addr_options); let capability = match mode { diff --git a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs index b6604c68ea..a11aff72be 100644 --- a/iroh/src/rpc_protocol.rs +++ b/iroh/src/rpc_protocol.rs @@ -48,7 +48,7 @@ pub use iroh_blobs::{provider::AddProgress, store::ValidateProgress}; use iroh_docs::engine::LiveEvent; use crate::client::{ - blobs::{BlobInfo, CollectionInfo, DownloadMode, IncompleteBlobInfo, WrapOption}, + blobs::{BlobInfo, DownloadMode, IncompleteBlobInfo, WrapOption}, docs::{ImportProgress, ShareMode}, tags::TagInfo, NodeStatus, @@ -211,22 +211,39 @@ impl ServerStreamingMsg for BlobListIncompleteRequest { /// /// Lists all collections that have been explicitly added to the database. 
#[derive(Debug, Serialize, Deserialize)] -pub struct BlobListCollectionsRequest; - -impl Msg for BlobListCollectionsRequest { - type Pattern = ServerStreaming; +pub struct ListTagsRequest { + /// List raw tags + pub raw: bool, + /// List hash seq tags + pub hash_seq: bool, +} + +impl ListTagsRequest { + /// List all tags + pub fn all() -> Self { + Self { + raw: true, + hash_seq: true, + } + } + + /// List raw tags + pub fn raw() -> Self { + Self { + raw: true, + hash_seq: false, + } + } + + /// List hash seq tags + pub fn hash_seq() -> Self { + Self { + raw: false, + hash_seq: true, + } + } } -impl ServerStreamingMsg for BlobListCollectionsRequest { - type Response = RpcResult; -} - -/// List all collections -/// -/// Lists all collections that have been explicitly added to the database. -#[derive(Debug, Serialize, Deserialize)] -pub struct ListTagsRequest; - impl Msg for ListTagsRequest { type Pattern = ServerStreaming; } @@ -256,25 +273,6 @@ pub struct DeleteTagRequest { impl RpcMsg for DeleteTagRequest { type Response = RpcResult<()>; } - -/// Get a collection -#[derive(Debug, Serialize, Deserialize)] -pub struct BlobGetCollectionRequest { - /// Hash of the collection - pub hash: Hash, -} - -impl RpcMsg for BlobGetCollectionRequest { - type Response = RpcResult; -} - -/// The response for a `BlobGetCollectionRequest`. -#[derive(Debug, Serialize, Deserialize)] -pub struct BlobGetCollectionResponse { - /// The collection. - pub collection: Collection, -} - /// Create a collection. #[derive(Debug, Serialize, Deserialize)] pub struct CreateCollectionRequest { @@ -1091,12 +1089,10 @@ pub enum Request { BlobExport(BlobExportRequest), BlobList(BlobListRequest), BlobListIncomplete(BlobListIncompleteRequest), - BlobListCollections(BlobListCollectionsRequest), BlobDeleteBlob(BlobDeleteBlobRequest), BlobValidate(BlobValidateRequest), BlobFsck(BlobConsistencyCheckRequest), CreateCollection(CreateCollectionRequest), - BlobGetCollection(BlobGetCollectionRequest), DeleteTag(DeleteTagRequest), ListTags(ListTagsRequest), @@ -1154,13 +1150,11 @@ pub enum Response { BlobAddPath(BlobAddPathResponse), BlobList(RpcResult), BlobListIncomplete(RpcResult), - BlobListCollections(RpcResult), BlobDownload(BlobDownloadResponse), BlobFsck(ConsistencyCheckProgress), BlobExport(BlobExportResponse), BlobValidate(ValidateProgress), CreateCollection(RpcResult), - BlobGetCollection(RpcResult), ListTags(TagInfo), DeleteTag(RpcResult<()>), diff --git a/iroh/tests/gc.rs b/iroh/tests/gc.rs index 4c3c3fc26f..dcca0893b5 100644 --- a/iroh/tests/gc.rs +++ b/iroh/tests/gc.rs @@ -232,8 +232,8 @@ mod file { let bao_store = iroh_blobs::store::fs::Store::load(dir.join("store")).await?; let (node, _) = wrap_in_node(bao_store.clone(), Duration::from_secs(10)).await; let client = node.client(); - let doc = client.docs.create().await?; - let author = client.authors.create().await?; + let doc = client.docs().create().await?; + let author = client.authors().create().await?; let temp_path = dir.join("temp"); tokio::fs::create_dir_all(&temp_path).await?; let mut to_import = Vec::new(); diff --git a/iroh/tests/provide.rs b/iroh/tests/provide.rs index a4f005fe58..13376273dd 100644 --- a/iroh/tests/provide.rs +++ b/iroh/tests/provide.rs @@ -391,7 +391,7 @@ async fn test_run_ticket() { let _drop_guard = node.cancel_token().drop_guard(); let ticket = node - .blobs + .blobs() .share( hash, BlobFormat::HashSeq, diff --git a/iroh/tests/sync.rs b/iroh/tests/sync.rs index 556f5829a7..a5e9b8a463 100644 --- a/iroh/tests/sync.rs +++ b/iroh/tests/sync.rs 
@@ -85,8 +85,8 @@ async fn sync_simple() -> Result<()> { // create doc on node0 let peer0 = nodes[0].node_id(); - let author0 = clients[0].authors.create().await?; - let doc0 = clients[0].docs.create().await?; + let author0 = clients[0].authors().create().await?; + let doc0 = clients[0].docs().create().await?; let hash0 = doc0 .set_bytes(author0, b"k1".to_vec(), b"v1".to_vec()) .await?; @@ -99,7 +99,7 @@ async fn sync_simple() -> Result<()> { info!("node1: join"); let peer1 = nodes[1].node_id(); - let doc1 = clients[1].docs.import(ticket.clone()).await?; + let doc1 = clients[1].docs().import(ticket.clone()).await?; let mut events1 = doc1.subscribe().await?; info!("node1: assert 4 events"); assert_next_unordered( @@ -140,9 +140,9 @@ async fn sync_subscribe_no_sync() -> Result<()> { setup_logging(); let node = spawn_node(0, &mut rng).await?; let client = node.client(); - let doc = client.docs.create().await?; + let doc = client.docs().create().await?; let mut sub = doc.subscribe().await?; - let author = client.authors.create().await?; + let author = client.authors().create().await?; doc.set_bytes(author, b"k".to_vec(), b"v".to_vec()).await?; let event = tokio::time::timeout(Duration::from_millis(100), sub.next()).await?; assert!( @@ -165,15 +165,15 @@ async fn sync_gossip_bulk() -> Result<()> { let clients = nodes.iter().map(|node| node.client()).collect::>(); let _peer0 = nodes[0].node_id(); - let author0 = clients[0].authors.create().await?; - let doc0 = clients[0].docs.create().await?; + let author0 = clients[0].authors().create().await?; + let doc0 = clients[0].docs().create().await?; let mut ticket = doc0 .share(ShareMode::Write, AddrInfoOptions::RelayAndAddresses) .await?; // unset peers to not yet start sync let peers = ticket.nodes.clone(); ticket.nodes = vec![]; - let doc1 = clients[1].docs.import(ticket).await?; + let doc1 = clients[1].docs().import(ticket).await?; let mut events = doc1.subscribe().await?; // create entries for initial sync. 
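
The test updates in this file are one mechanical migration: subsystem clients are now reached through accessor methods instead of public fields. A self-contained sketch of the new call shape (the helper function is hypothetical; `MemIroh` and the accessors are as used in these tests):

```rust
use anyhow::Result;
use iroh::client::MemIroh;

// Hypothetical helper illustrating the accessor-based client API.
async fn create_doc_with_entry(client: &MemIroh) -> Result<()> {
    // Formerly `client.docs` / `client.authors` (public fields).
    let doc = client.docs().create().await?;
    let author = client.authors().create().await?;
    doc.set_bytes(author, b"k".to_vec(), b"v".to_vec()).await?;
    Ok(())
}
```
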
@@ -255,8 +255,8 @@ async fn sync_full_basic() -> Result<()> { // peer0: create doc and ticket let peer0 = nodes[0].node_id(); - let author0 = clients[0].authors.create().await?; - let doc0 = clients[0].docs.create().await?; + let author0 = clients[0].authors().create().await?; + let doc0 = clients[0].docs().create().await?; let mut events0 = doc0.subscribe().await?; let key0 = b"k1"; let value0 = b"v1"; @@ -277,9 +277,9 @@ async fn sync_full_basic() -> Result<()> { info!("peer1: spawn"); let peer1 = nodes[1].node_id(); - let author1 = clients[1].authors.create().await?; + let author1 = clients[1].authors().create().await?; info!("peer1: join doc"); - let doc1 = clients[1].docs.import(ticket.clone()).await?; + let doc1 = clients[1].docs().import(ticket.clone()).await?; info!("peer1: wait for 4 events (for sync and join with peer0)"); let mut events1 = doc1.subscribe().await?; @@ -345,7 +345,7 @@ async fn sync_full_basic() -> Result<()> { info!("peer2: spawn"); nodes.push(spawn_node(nodes.len(), &mut rng).await?); clients.push(nodes.last().unwrap().client().clone()); - let doc2 = clients[2].docs.import(ticket).await?; + let doc2 = clients[2].docs().import(ticket).await?; let peer2 = nodes[2].node_id(); let mut events2 = doc2.subscribe().await?; @@ -428,11 +428,11 @@ async fn sync_open_close() -> Result<()> { let node = spawn_node(0, &mut rng).await?; let client = node.client(); - let doc = client.docs.create().await?; + let doc = client.docs().create().await?; let status = doc.status().await?; assert_eq!(status.handles, 1); - let doc2 = client.docs.open(doc.id()).await?.unwrap(); + let doc2 = client.docs().open(doc.id()).await?.unwrap(); let status = doc2.status().await?; assert_eq!(status.handles, 2); @@ -452,8 +452,8 @@ async fn sync_subscribe_stop_close() -> Result<()> { let node = spawn_node(0, &mut rng).await?; let client = node.client(); - let doc = client.docs.create().await?; - let author = client.authors.create().await?; + let doc = client.docs().create().await?; + let author = client.authors().create().await?; let status = doc.status().await?; assert_eq!(status.subscribers, 0); @@ -504,8 +504,8 @@ async fn test_sync_via_relay() -> Result<()> { .spawn() .await?; - let doc1 = node1.docs.create().await?; - let author1 = node1.authors.create().await?; + let doc1 = node1.docs().create().await?; + let author1 = node1.authors().create().await?; let inserted_hash = doc1 .set_bytes(author1, b"foo".to_vec(), b"bar".to_vec()) .await?; @@ -517,7 +517,7 @@ async fn test_sync_via_relay() -> Result<()> { ticket.nodes[0].info.direct_addresses = Default::default(); // join - let doc2 = node2.docs.import(ticket).await?; + let doc2 = node2.docs().import(ticket).await?; let mut events = doc2.subscribe().await?; assert_next_unordered_with_optionals( @@ -598,7 +598,7 @@ async fn sync_restart_node() -> Result<()> { let id1 = node1.node_id(); // create doc & ticket on node1 - let doc1 = node1.docs.create().await?; + let doc1 = node1.docs().create().await?; let mut events1 = doc1.subscribe().await?; let ticket = doc1 .share(ShareMode::Write, AddrInfoOptions::RelayAndAddresses) @@ -615,8 +615,8 @@ async fn sync_restart_node() -> Result<()> { .spawn() .await?; let id2 = node2.node_id(); - let author2 = node2.authors.create().await?; - let doc2 = node2.docs.import(ticket.clone()).await?; + let author2 = node2.authors().create().await?; + let doc2 = node2.docs().import(ticket.clone()).await?; info!("node2 set a"); let hash_a = doc2.set_bytes(author2, "n2/a", "a").await?; @@ -662,7 +662,7 @@ async fn 
sync_restart_node() -> Result<()> { .await?; assert_eq!(id1, node1.node_id()); - let doc1 = node1.docs.open(doc1.id()).await?.expect("doc to exist"); + let doc1 = node1.docs().open(doc1.id()).await?.expect("doc to exist"); let mut events1 = doc1.subscribe().await?; assert_latest(&doc1, b"n2/a", b"a").await; @@ -748,14 +748,14 @@ async fn test_download_policies() -> Result<()> { let nodes = spawn_nodes(2, &mut rng).await?; let clients = nodes.iter().map(|node| node.client()).collect::>(); - let doc_a = clients[0].docs.create().await?; - let author_a = clients[0].authors.create().await?; + let doc_a = clients[0].docs().create().await?; + let author_a = clients[0].authors().create().await?; let ticket = doc_a .share(ShareMode::Write, AddrInfoOptions::RelayAndAddresses) .await?; - let doc_b = clients[1].docs.import(ticket).await?; - let author_b = clients[1].authors.create().await?; + let doc_b = clients[1].docs().import(ticket).await?; + let author_b = clients[1].authors().create().await?; doc_a.set_download_policy(policy_a).await?; doc_b.set_download_policy(policy_b).await?; @@ -871,9 +871,9 @@ async fn sync_big() -> Result<()> { let nodes = spawn_nodes(n_nodes, &mut rng).await?; let node_ids = nodes.iter().map(|node| node.node_id()).collect::>(); let clients = nodes.iter().map(|node| node.client()).collect::>(); - let authors = collect_futures(clients.iter().map(|c| c.authors.create())).await?; + let authors = collect_futures(clients.iter().map(|c| c.authors().create())).await?; - let doc0 = clients[0].docs.create().await?; + let doc0 = clients[0].docs().create().await?; let mut ticket = doc0 .share(ShareMode::Write, AddrInfoOptions::RelayAndAddresses) .await?; @@ -888,7 +888,7 @@ async fn sync_big() -> Result<()> { clients .iter() .skip(1) - .map(|c| c.docs.import(ticket.clone())), + .map(|c| c.docs().import(ticket.clone())), ) .await?, ); @@ -973,6 +973,44 @@ async fn sync_big() -> Result<()> { Ok(()) } +#[tokio::test] +#[cfg(feature = "test-utils")] +async fn test_list_docs_stream() -> Result<()> { + let node = Node::memory() + .node_discovery(iroh::node::DiscoveryConfig::None) + .relay_mode(iroh::net::relay::RelayMode::Disabled) + .spawn() + .await?; + let count = 200; + + // create docs + for _i in 0..count { + let doc = node.docs().create().await?; + doc.close().await?; + } + + // create doc stream + let mut stream = node.docs().list().await?; + + // process each doc and call into the docs actor. + // this makes sure that we don't deadlock the docs actor. + let mut i = 0; + let fut = async { + while let Some((id, _)) = stream.try_next().await.unwrap() { + let _doc = node.docs().open(id).await.unwrap().unwrap(); + i += 1; + } + }; + + tokio::time::timeout(Duration::from_secs(2), fut) + .await + .expect("not to timeout"); + + assert_eq!(i, count); + + Ok(()) +} + /// Get all entries of a document. 
async fn get_all(doc: &MemDoc) -> anyhow::Result> { let entries = doc.get_many(Query::all()).await?; @@ -1113,8 +1151,8 @@ async fn doc_delete() -> Result<()> { .spawn() .await?; let client = node.client(); - let doc = client.docs.create().await?; - let author = client.authors.create().await?; + let doc = client.docs().create().await?; + let author = client.authors().create().await?; let hash = doc .set_bytes(author, b"foo".to_vec(), b"hi".to_vec()) .await?; @@ -1128,7 +1166,7 @@ async fn doc_delete() -> Result<()> { // wait for gc // TODO: allow to manually trigger gc tokio::time::sleep(Duration::from_millis(200)).await; - let bytes = client.blobs.read_to_bytes(hash).await; + let bytes = client.blobs().read_to_bytes(hash).await; assert!(bytes.is_err()); node.shutdown().await?; Ok(()) @@ -1141,8 +1179,8 @@ async fn sync_drop_doc() -> Result<()> { let node = spawn_node(0, &mut rng).await?; let client = node.client(); - let doc = client.docs.create().await?; - let author = client.authors.create().await?; + let doc = client.docs().create().await?; + let author = client.authors().create().await?; let mut sub = doc.subscribe().await?; doc.set_bytes(author, b"foo".to_vec(), b"bar".to_vec()) @@ -1150,14 +1188,14 @@ async fn sync_drop_doc() -> Result<()> { let ev = sub.next().await; assert!(matches!(ev, Some(Ok(LiveEvent::InsertLocal { .. })))); - client.docs.drop_doc(doc.id()).await?; + client.docs().drop_doc(doc.id()).await?; let res = doc.get_exact(author, b"foo".to_vec(), true).await; assert!(res.is_err()); let res = doc .set_bytes(author, b"foo".to_vec(), b"bar".to_vec()) .await; assert!(res.is_err()); - let res = client.docs.open(doc.id()).await; + let res = client.docs().open(doc.id()).await; assert!(res.is_err()); let ev = sub.next().await; assert!(ev.is_none());
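
The `download` refactor in `iroh/src/node/rpc.rs` above moves tagging out of the per-attempt path: a temp tag pins the content for the whole transfer, and the permanent tag is only written once the download has succeeded. A sketch of that ordering, using the store calls shown in this diff (the `SetTagOption` import path is assumed; the transfer step itself is elided):

```rust
use anyhow::Result;
use iroh_blobs::{store::Store as BaoStore, util::SetTagOption, HashAndFormat};

async fn download_and_tag<D: BaoStore>(
    db: &D,
    hash_and_format: HashAndFormat,
    tag: SetTagOption,
) -> Result<()> {
    // Pin the (possibly partial) content against GC for the whole transfer.
    let temp_tag = db.temp_tag(hash_and_format);

    // ... run the queued or direct transfer here, e.g. via
    // iroh_blobs::get::db::get_to_db (elided) ...

    // Only tag permanently once the data is complete in the store.
    match tag {
        SetTagOption::Named(name) => db.set_tag(name, Some(hash_and_format)).await?,
        SetTagOption::Auto => {
            db.create_tag(hash_and_format).await?;
        }
    }

    // Release the pin only after the permanent tag exists.
    drop(temp_tag);
    Ok(())
}
```
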