diff --git a/iroh-gossip/src/net.rs b/iroh-gossip/src/net.rs index 29b43ae32f4..fc81efd9288 100644 --- a/iroh-gossip/src/net.rs +++ b/iroh-gossip/src/net.rs @@ -431,7 +431,7 @@ impl Actor { trace!("handle to_actor {msg:?}"); match msg { ToActor::ConnIncoming(conn) => { - if let Err(err) = self.conn_manager.accept(conn) { + if let Err(err) = self.conn_manager.handle_connection(conn) { warn!(?err, "failed to accept connection"); } } diff --git a/iroh-net/src/conn_manager.rs b/iroh-net/src/conn_manager.rs index 4b9828558ba..02fd93647ee 100644 --- a/iroh-net/src/conn_manager.rs +++ b/iroh-net/src/conn_manager.rs @@ -12,7 +12,7 @@ use tokio::{ sync::mpsc, task::{AbortHandle, JoinSet}, }; -use tracing::{debug, error}; +use tracing::{error, warn}; use crate::{ endpoint::{get_remote_node_id, Connection}, @@ -22,17 +22,75 @@ use crate::{ const DUPLICATE_REASON: &[u8] = b"abort_duplicate"; const DUPLICATE_CODE: u32 = 123; -/// A connection manager. +/// Whether we accepted the connection or initiated it. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum ConnDirection { + /// We accepted this connection from the other peer. + Accept, + /// We initiated this connection by connecting to the other peer. + Dial, +} + +/// A new connection as emitted from [`ConnManager`]. +#[derive(Debug, Clone, derive_more::Deref)] +pub struct ConnInfo { + /// The QUIC connection. + #[deref] + pub conn: Connection, + /// The node id of the other peer. + pub node_id: NodeId, + /// Whether we accepted or initiated this connection. + pub direction: ConnDirection, +} + +/// A sender to push new connections into a [`ConnManager`]. +/// +/// See [`ConnManager::accept_sender`] for details. +#[derive(Debug, Clone)] +pub struct HandleConnectionSender { + tx: mpsc::Sender, +} + +impl HandleConnectionSender { + /// Send a new connection to the [`ConnManager`]. + pub async fn send(&self, conn: Connection) -> anyhow::Result<()> { + self.tx.send(conn).await?; + Ok(()) + } +} + +/// The error returned from [`ConnManager::poll_next`]. +#[derive(thiserror::Error, Debug)] +#[error("Connection to node {} direction {:?} failed: {:?}", self.node_id, self.direction, self.reason)] +pub struct ConnectError { + /// The node id of the peer to which the connection failed. + pub node_id: NodeId, + /// The direction of the connection. + pub direction: ConnDirection, + /// The actual error that ocurred. + #[source] + pub reason: anyhow::Error, +} + +/// A connection manager that ensures that only a single connection between two peers prevails. +/// +/// You can start to dial peers by calling [`ConnManager::dial`]. Note that the method only takes a +/// node id; if you have more addressing info, add it to the endpoint directly with +/// [`Endpoint::add_node_addr`] before calling `dial`; /// /// The [`ConnManager`] does not accept connections from the endpoint by itself. Instead, you /// should run an accept loop yourself, and push connections with a matching ALPN into the manager -/// with [`ConnManager::accept`]. The connection will be dropped if we already have a connection to -/// that node. If we are currently dialing the node, the connection will only be accepted if the -/// peer's node id sorts lower than our node id. Through this, it is ensured that we will not get -/// double connections with a node if both we and them dial each other at the same time. +/// with [`ConnManager::handle_connection`] or [`ConnManager::handle_connection_sender`]. /// -/// The [`ConnManager`] implements [`Stream`]. It will yield new connections, both from dialing and -/// accepting. +/// The [`ConnManager`] is a [`Stream`] that yields all connections from both accepting and dialing. +/// +/// Before accepting incoming connections, the [`ConnManager`] makes sure that, if we are dialing +/// the same node, only one of the connections will prevail. In this case, the accepting side +/// rejects the connection if the peer's node id sorts higher than their own node id. +/// +/// To make this reliable even if the dials happen exactly at the same time, a single unidirectional +/// stream is opened, on which a single byte is sent. This additional rountrip ensures that no +/// double connections can prevail. #[derive(Debug)] pub struct ConnManager { endpoint: Endpoint, @@ -80,36 +138,6 @@ impl ConnManager { } } - fn spawn( - &mut self, - node_id: NodeId, - direction: ConnDirection, - fut: impl Future> + Send + 'static, - ) { - let abort_handle = self.tasks.spawn(fut.map(move |res| (node_id, res))); - let pending_state = PendingState { - direction, - abort_handle, - }; - self.pending.insert(node_id, pending_state); - if let Some(waker) = self.waker.take() { - waker.wake(); - } - } - - /// Get a sender to push new connections towards the [`ConnManager`] - /// - /// This does not check the connection's ALPN, so you should make sure that the ALPN matches - /// the [`ConnManager`]'s execpected ALPN before passing the connection to the sender. - /// - /// If we are currently dialing the node, the connection will be dropped if the peer's node id - /// sorty higher than our node id. Otherwise, the connection will be yielded from the manager - /// stream. - pub fn accept_sender(&self) -> AcceptSender { - let tx = self.accept_tx.clone(); - AcceptSender { tx } - } - /// Accept a connection. /// /// This does not check the connection's ALPN, so you should make sure that the ALPN matches @@ -117,7 +145,7 @@ impl ConnManager { /// /// If we are currently dialing the node, the connection will be dropped if the peer's node id /// sorty higher than our node id. Otherwise, the connection will be returned. - pub fn accept(&mut self, conn: quinn::Connection) -> anyhow::Result<()> { + pub fn handle_connection(&mut self, conn: quinn::Connection) -> anyhow::Result<()> { let node_id = get_remote_node_id(&conn)?; // We are already connected: drop the connection, keep using the existing conn. if self.is_connected(&node_id) { @@ -128,7 +156,7 @@ impl ConnManager { // We are currently dialing the node, but the incoming conn "wins": accept and abort // our dial. Some(state) - if state.direction == ConnDirection::Dial && node_id > self.our_node_id() => + if state.direction == ConnDirection::Dial && node_id > self.endpoint.node_id() => { state.abort_handle.abort(); true @@ -147,6 +175,19 @@ impl ConnManager { Ok(()) } + /// Get a sender to push new connections towards the [`ConnManager`] + /// + /// This does not check the connection's ALPN, so you should make sure that the ALPN matches + /// the [`ConnManager`]'s execpected ALPN before passing the connection to the sender. + /// + /// If we are currently dialing the node, the connection will be dropped if the peer's node id + /// sorty higher than our node id. Otherwise, the connection will be yielded from the manager + /// stream. + pub fn handle_connection_sender(&self) -> HandleConnectionSender { + let tx = self.accept_tx.clone(); + HandleConnectionSender { tx } + } + /// Remove the connection to a node. /// /// Also aborts pending dials to the node, if existing. @@ -174,8 +215,21 @@ impl ConnManager { self.active.contains_key(node_id) } - fn our_node_id(&self) -> NodeId { - self.endpoint.node_id() + fn spawn( + &mut self, + node_id: NodeId, + direction: ConnDirection, + fut: impl Future> + Send + 'static, + ) { + let abort_handle = self.tasks.spawn(fut.map(move |res| (node_id, res))); + let pending_state = PendingState { + direction, + abort_handle, + }; + self.pending.insert(node_id, pending_state); + if let Some(waker) = self.waker.take() { + waker.wake(); + } } } @@ -183,28 +237,27 @@ impl Stream for ConnManager { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - tracing::debug!("poll_next in"); // Create new tasks for incoming connections. while let Poll::Ready(Some(conn)) = Pin::new(&mut self.accept_rx).poll_recv(cx) { - // self.accept(conn) - debug!("accept - polled"); - if let Err(error) = self.accept(conn) { - tracing::warn!(?error, "skipping invalid connection attempt"); + if let Err(error) = self.handle_connection(conn) { + warn!(?error, "skipping invalid connection attempt"); } } // Poll for finished tasks, loop { let join_res = ready!(self.tasks.poll_join_next(cx)); - debug!(?join_res, "join res"); let (node_id, res) = match join_res { None => { self.waker = Some(cx.waker().to_owned()); return Poll::Pending; } Some(Err(err)) if err.is_cancelled() => continue, - // we are merely forwarding a panic here, which should never occur. - Some(Err(err)) => panic!("connection manager task paniced with {err:?}"), + Some(Err(err)) => { + // TODO: unreachable? + warn!("connection manager task paniced with {err:?}"); + continue; + } Some(Ok(res)) => res, }; match res { @@ -212,7 +265,7 @@ impl Stream for ConnManager { Err(InitError::Other(reason)) => { let Some(PendingState { direction, .. }) = self.pending.remove(&node_id) else { // TODO: unreachable? - tracing::warn!(node_id=%node_id.fmt_short(), "missing pending state, dropping connection"); + warn!(node_id=%node_id.fmt_short(), "missing pending state, dropping connection"); continue; }; let err = ConnectError { @@ -225,7 +278,7 @@ impl Stream for ConnManager { Ok(conn) => { let Some(PendingState { direction, .. }) = self.pending.remove(&node_id) else { // TODO: unreachable? - tracing::warn!(node_id=%node_id.fmt_short(), "missing pending state, dropping connection"); + warn!(node_id=%node_id.fmt_short(), "missing pending state, dropping connection"); continue; }; let info = ConnInfo { @@ -265,35 +318,6 @@ struct PendingState { abort_handle: AbortHandle, } -/// A sender to push new connections into a [`ConnManager`]. -/// -/// See [`ConnManager::accept_sender`] for details. -#[derive(Debug, Clone)] -pub struct AcceptSender { - tx: mpsc::Sender, -} - -impl AcceptSender { - /// Send a new connection to the [`ConnManager`]. - pub async fn send(&self, conn: Connection) -> anyhow::Result<()> { - self.tx.send(conn).await?; - Ok(()) - } -} - -/// The error returned from [`ConnManager::poll_next`]. -#[derive(thiserror::Error, Debug)] -#[error("Connection to node {} direction {:?} failed: {:?}", self.node_id, self.direction, self.reason)] -pub struct ConnectError { - /// The node id of the peer to which the connection failed. - pub node_id: NodeId, - /// The direction of the connection. - pub direction: ConnDirection, - /// The actual error that ocurred. - #[source] - pub reason: anyhow::Error, -} - #[derive(Debug)] enum InitError { IsDuplicate, @@ -338,27 +362,6 @@ impl From for InitError { } } -/// Whether we accepted the connection or initiated it. -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -pub enum ConnDirection { - /// We accepted this connection from the other peer. - Accept, - /// We initiated this connection by connecting to the other peer. - Dial, -} - -/// A new connection as emitted from [`ConnManager`]. -#[derive(Debug, Clone, derive_more::Deref)] -pub struct ConnInfo { - /// The QUIC connection. - #[deref] - pub conn: Connection, - /// The node id of the other peer. - pub node_id: NodeId, - /// Whether we accepted or initiated this connection. - pub direction: ConnDirection, -} - #[cfg(test)] mod tests { use std::time::Duration; @@ -368,11 +371,14 @@ mod tests { use crate::test_utils::TestEndpointFactory; - use super::{AcceptSender, ConnManager}; + use super::{ConnManager, HandleConnectionSender}; const TEST_ALPN: &[u8] = b"test"; - async fn accept_loop(ep: crate::Endpoint, accept_sender: AcceptSender) -> anyhow::Result<()> { + async fn accept_loop( + ep: crate::Endpoint, + accept_sender: HandleConnectionSender, + ) -> anyhow::Result<()> { while let Some(conn) = ep.accept().await { let conn = conn.await?; tracing::debug!(me=%ep.node_id().fmt_short(), "conn incoming"); @@ -398,8 +404,8 @@ mod tests { let mut conn_manager1 = ConnManager::new(ep1.clone(), TEST_ALPN); let mut conn_manager2 = ConnManager::new(ep2.clone(), TEST_ALPN); - let accept1 = conn_manager1.accept_sender(); - let accept2 = conn_manager2.accept_sender(); + let accept1 = conn_manager1.handle_connection_sender(); + let accept2 = conn_manager2.handle_connection_sender(); let mut tasks = JoinSet::new(); tasks.spawn(accept_loop(ep1, accept1)); tasks.spawn(accept_loop(ep2, accept2));