Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: attempt to call discovery inside the magicsock #2256

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 137 additions & 37 deletions iroh-net/src/discovery.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//! Trait and utils for the node discovery mechanism.

use std::time::Duration;
use std::{collections::BTreeMap, time::Duration};

use anyhow::{anyhow, ensure, Result};
use futures_lite::stream::{Boxed as BoxStream, StreamExt};
use iroh_base::node_addr::NodeAddr;
use tokio::{sync::oneshot, task::JoinHandle};
use tracing::{debug, error_span, warn, Instrument};
use tracing::{debug, error_span, trace, warn, Instrument};

use crate::{AddrInfo, MagicEndpoint, NodeId};

Expand Down Expand Up @@ -123,44 +123,26 @@ impl Discovery for ConcurrentDiscovery {
const MAX_AGE: Duration = Duration::from_secs(10);

/// A wrapper around a tokio task which runs a node discovery.
#[derive(derive_more::Debug)]
pub(super) struct DiscoveryTask {
on_first_rx: oneshot::Receiver<Result<()>>,
task: JoinHandle<()>,
}

impl DiscoveryTask {
/// Start a discovery task.
pub fn start(ep: MagicEndpoint, node_id: NodeId) -> Result<Self> {
ensure!(ep.discovery().is_some(), "No discovery services configured");
let (on_first_tx, on_first_rx) = oneshot::channel();
let me = ep.node_id();
let task = tokio::task::spawn(
async move { Self::run(ep, node_id, on_first_tx).await }.instrument(
error_span!("discovery", me = %me.fmt_short(), node = %node_id.fmt_short()),
),
);
Ok(Self { task, on_first_rx })
}

/// Start a discovery task after a delay and only if no path to the node was recently active.
/// Start a discovery task after a delay
///
/// This returns `None` if we received data or control messages from the remote endpoint
/// recently enough. If not it returns a [`DiscoveryTask`].
///
/// If `delay` is set, the [`DiscoveryTask`] will first wait for `delay` and then check again
/// if we recently received messages from remote endpoint. If true, the task will abort.
/// Otherwise, or if no `delay` is set, the discovery will be started.
pub fn maybe_start_after_delay(
pub fn start_after_delay(
ep: &MagicEndpoint,
node_id: NodeId,
delay: Option<Duration>,
) -> Result<Option<Self>> {
// If discovery is not needed, don't even spawn a task.
if !Self::needs_discovery(ep, node_id) {
return Ok(None);
}
ensure!(ep.discovery().is_some(), "No discovery services configured");
let (on_first_tx, on_first_rx) = oneshot::channel();
on_first_tx: Option<oneshot::Sender<Result<()>>>,
) -> Option<Self> {
let ep = ep.clone();
let me = ep.node_id();
let task = tokio::task::spawn(
Expand All @@ -170,7 +152,7 @@ impl DiscoveryTask {
tokio::time::sleep(delay).await;
if !Self::needs_discovery(&ep, node_id) {
debug!("no discovery needed, abort");
on_first_tx.send(Ok(())).ok();
on_first_tx.map(|tx| tx.send(Ok(())).ok());
return;
}
}
Expand All @@ -180,14 +162,7 @@ impl DiscoveryTask {
error_span!("discovery", me = %me.fmt_short(), node = %node_id.fmt_short()),
),
);
Ok(Some(Self { task, on_first_rx }))
}

/// Wait until the discovery task produced at least one result.
pub async fn first_arrived(&mut self) -> Result<()> {
let fut = &mut self.on_first_rx;
fut.await??;
Ok(())
Some(Self { task })
}

/// Cancel the discovery task.
Expand Down Expand Up @@ -229,15 +204,18 @@ impl DiscoveryTask {
}
}

async fn run(ep: MagicEndpoint, node_id: NodeId, on_first_tx: oneshot::Sender<Result<()>>) {
async fn run(
ep: MagicEndpoint,
node_id: NodeId,
mut on_first_tx: Option<oneshot::Sender<Result<()>>>,
) {
let mut stream = match Self::create_stream(&ep, node_id) {
Ok(stream) => stream,
Err(err) => {
on_first_tx.send(Err(err)).ok();
on_first_tx.map(|s| s.send(Err(err)).ok());
return;
}
};
let mut on_first_tx = Some(on_first_tx);
debug!("discovery: start");
loop {
let next = tokio::select! {
Expand Down Expand Up @@ -280,6 +258,128 @@ impl Drop for DiscoveryTask {
}
}

use flume::{Receiver, Sender};
use std::sync::Arc;

/// Responsible for starting and cancelling Discovery requests from
/// the magicsock.
///
/// This is a cloneable handle to the worker task spawned by `DiscoveryTasks::new`:
/// all clones share the same task handle and the same message channel.
#[derive(derive_more::Debug, Clone)]
pub(super) struct DiscoveryTasks {
/// Join handle of the spawned worker task, shared between all clones.
handle: Arc<JoinHandle<()>>,
/// Sending half of the channel on which the worker receives [`DiscoveryTaskMessage`]s.
sender: Sender<DiscoveryTaskMessage>,
}

impl Drop for DiscoveryTasks {
fn drop(&mut self) {
self.handle.abort();
}
}

/// The sending and receiving halves of the channel used to deliver
/// [`DiscoveryTaskMessage`]s to the [`DiscoveryTasks`] worker.
pub(super) type DiscoveryTasksChans =
(Sender<DiscoveryTaskMessage>, Receiver<DiscoveryTaskMessage>);

impl DiscoveryTasks {
/// Create a new `DiscoveryTasks` worker.
///
/// There should only ever be one `DiscoveryTask` running for each attempted connection,
pub(crate) fn new(ep: MagicEndpoint, chans: DiscoveryTasksChans) -> Result<Self> {
ensure!(
ep.discovery().is_some(),
"No discovery enabled, cannot start discovery tasks"
);
let (sender, recv) = chans;
let handle = tokio::spawn(async move {
let mut tasks = BTreeMap::default();
loop {
let msg = tokio::select! {
_ = ep.cancelled() => break,
msg = recv.recv_async() => {
match msg {
Err(e) => {
debug!("{e:?}");
break;
},
Ok(msg) => msg,
}
}
};
match msg {
DiscoveryTaskMessage::Start{node_id, delay, on_first_tx} => {
if !DiscoveryTask::needs_discovery(&ep, node_id) {
trace!("Discovery for {node_id} requested, but the node does not need discovery.");
continue;
}
if let Some(new_task) = DiscoveryTask::start_after_delay(&ep, node_id, delay, on_first_tx) {
if let Some(old_task) = tasks.insert(node_id, new_task) {
old_task.cancel();
}
}
}
DiscoveryTaskMessage::Cancel(node_id) => {
match tasks.remove(&node_id) {
None => trace!("Cancelled Discovery for {node_id}, but no Discovery for that id is currently running."),
Some(task) => task.cancel()
}
}
}
}
});
Ok(DiscoveryTasks {
handle: Arc::new(handle),
sender,
})
}

/// Cancel a [`DiscoveryTask`]
///
/// If the receiver is full, it drops the request. There will only ever be
/// one [`DiscoveryTask`] per node dialed, so if this happens, there is
/// something very wrong.
pub fn cancel(&self, node_id: NodeId) {
self.sender.send(DiscoveryTaskMessage::Cancel(node_id)).ok();
}

/// Start a [`DiscoveryTask`], if necessary.
///
/// You can start the task on a delay by providing an optional [`Duration`].
///
/// If the receiver is full, it drops the request. There will only ever be
/// one [`DiscoveryTask`] per node dialed, so if this happens, there is
/// something very wrong.
pub fn start(
&self,
node_id: NodeId,
delay: Option<Duration>,
on_first_tx: Option<oneshot::Sender<Result<()>>>,
) {
self.sender
.send(DiscoveryTaskMessage::Start {
node_id,
delay,
on_first_tx,
})
.ok();
}
}

/// Messages used by the [`DiscoveryTasks`] struct to manage [`DiscoveryTask`]s.
#[derive(Debug)]
pub(super) enum DiscoveryTaskMessage {
/// Launch discovery for the given [`NodeId`].
///
/// The worker skips the request if the node does not currently need discovery.
/// A newly started task replaces (and cancels) any task already tracked for the
/// same node.
Start {
/// The node ID for the node we are trying to discover.
node_id: NodeId,
/// When `None`, start discovery immediately.
/// When `Some`, start discovery after the given delay.
delay: Option<Duration>,
/// If present, used to notify the requester once the first discovery
/// result arrives, or to report an error if discovery could not be started.
on_first_tx: Option<oneshot::Sender<Result<()>>>,
},
/// Cancel any discovery for the given [`NodeId`].
Cancel(NodeId),
}

#[cfg(test)]
mod tests {
use std::{
Expand Down
62 changes: 43 additions & 19 deletions iroh-net/src/magic_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ use anyhow::{anyhow, bail, ensure, Context, Result};
use derive_more::Debug;
use futures_lite::StreamExt;
use quinn_proto::VarInt;
use tokio::sync::oneshot;
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
use tracing::{debug, trace};

use crate::{
config,
defaults::default_relay_map,
discovery::{Discovery, DiscoveryTask},
discovery::{Discovery, DiscoveryTasks, DiscoveryTasksChans},
dns::{default_resolver, DnsResolver},
key::{PublicKey, SecretKey},
magicsock::{self, Handle},
Expand Down Expand Up @@ -204,17 +205,31 @@ impl MagicEndpointBuilder {
.dns_resolver
.unwrap_or_else(|| default_resolver().clone());

// Discovery should not happen that often, and only happens for
// nodes we are trying to connect to or are already connected to.
// TODO: possibly make this configurable? The channel bound should only be an
// issue if you are connected to many nodes and all of them suddenly need
// discovery at the same time.
let discovery_tasks_chans = self.discovery.as_ref().map(|_| flume::bounded(64));

let msock_opts = magicsock::Options {
port: bind_port,
secret_key,
relay_map,
nodes_path: self.peers_path,
discovery: self.discovery,
discovery_tasks_sender: discovery_tasks_chans.as_ref().map(|(s, _)| s.clone()),
dns_resolver,
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify,
};
MagicEndpoint::bind(Some(server_config), msock_opts, self.keylog).await
MagicEndpoint::bind(
Some(server_config),
msock_opts,
self.keylog,
discovery_tasks_chans,
)
.await
}
}

Expand Down Expand Up @@ -248,6 +263,7 @@ pub struct MagicEndpoint {
endpoint: quinn::Endpoint,
keylog: bool,
cancel_token: CancellationToken,
discovery_tasks: Option<DiscoveryTasks>,
}

impl MagicEndpoint {
Expand All @@ -264,6 +280,7 @@ impl MagicEndpoint {
server_config: Option<quinn::ServerConfig>,
msock_opts: magicsock::Options,
keylog: bool,
discovery_tasks: Option<DiscoveryTasksChans>,
) -> Result<Self> {
let secret_key = msock_opts.secret_key.clone();
let msock = magicsock::MagicSock::spawn(msock_opts).await?;
Expand All @@ -285,13 +302,18 @@ impl MagicEndpoint {
)?;
trace!("created quinn endpoint");

Ok(Self {
let mut ep = Self {
secret_key: Arc::new(secret_key),
msock,
endpoint,
keylog,
cancel_token: CancellationToken::new(),
})
discovery_tasks: None,
};
let discovery_tasks =
discovery_tasks.map(|chans| DiscoveryTasks::new(ep.clone(), chans).expect("checked"));
ep.discovery_tasks = discovery_tasks;
Ok(ep)
}

/// Accept an incoming connection on the socket.
Expand Down Expand Up @@ -459,7 +481,7 @@ impl MagicEndpoint {
// Get the mapped IPv6 address from the magic socket. Quinn will connect to this address.
// Start discovery for this node if it's enabled and we have no valid or verified
// address information for this node.
let (addr, discovery) = self
let addr = self
.get_mapping_addr_and_maybe_start_discovery(node_addr)
.await?;

Expand All @@ -473,8 +495,8 @@ impl MagicEndpoint {
let conn = self.connect_quinn(&node_id, alpn, addr).await;

// Cancel the node discovery task (if still running).
if let Some(discovery) = discovery {
discovery.cancel();
if let Some(ref discovery_tasks) = self.discovery_tasks {
discovery_tasks.cancel(node_id);
}

conn
Expand Down Expand Up @@ -525,7 +547,7 @@ impl MagicEndpoint {
async fn get_mapping_addr_and_maybe_start_discovery(
&self,
node_addr: NodeAddr,
) -> Result<(SocketAddr, Option<DiscoveryTask>)> {
) -> Result<SocketAddr> {
let node_id = node_addr.node_id;

// Only return a mapped addr if we have some way of dialing this node, in other
Expand All @@ -545,25 +567,27 @@ impl MagicEndpoint {
// followed by a recheck before starting the discovery, to give the magicsocket a
// chance to test the newly provided addresses.
let delay = (!node_addr.info.is_empty()).then_some(DISCOVERY_WAIT_PERIOD);
let discovery = DiscoveryTask::maybe_start_after_delay(self, node_id, delay)
.ok()
.flatten();
Ok((addr, discovery))
if let Some(ref discovery_tasks) = self.discovery_tasks {
discovery_tasks.start(node_id, delay, None);
}
Ok(addr)
}

None => {
// We have no known addresses or relay URLs for this node.
// So, we start a discovery task and wait for the first result to arrive, and
// only then continue, because otherwise we wouldn't have any
// path to the remote endpoint.
let mut discovery = DiscoveryTask::start(self.clone(), node_id)?;
discovery.first_arrived().await?;
if self.msock.has_send_address(node_id) {
let addr = self.msock.get_mapping_addr(&node_id).expect("checked");
Ok((addr, Some(discovery)))
} else {
bail!("Failed to retrieve the mapped address from the magic socket. Unable to dial node {node_id:?}");
if let Some(ref discovery_tasks) = self.discovery_tasks {
let (first_arrived_tx, first_arrived_rx) = oneshot::channel();
discovery_tasks.start(node_id, None, Some(first_arrived_tx));
let _ = first_arrived_rx.await;
if self.msock.has_send_address(node_id) {
let addr = self.msock.get_mapping_addr(&node_id).expect("checked");
return Ok(addr);
}
}
bail!("Failed to retrieve the mapped address from the magic socket. Unable to dial node {node_id:?}");
}
}
}
Expand Down
Loading
Loading