Skip to content

Commit

Permalink
feat: add connection manager and use in iroh-gossip
Browse files Browse the repository at this point in the history
  • Loading branch information
Frando committed May 22, 2024
1 parent d635d93 commit c8265ee
Show file tree
Hide file tree
Showing 7 changed files with 614 additions and 82 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion iroh-gossip/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ iroh-base = { version = "0.16.0", path = "../iroh-base" }

# net dependencies (optional)
futures-lite = { version = "2.3", optional = true }
futures-util = { version = "0.3.30", optional = true }
iroh-net = { path = "../iroh-net", version = "0.16.0", optional = true, default-features = false, features = ["test-utils"] }
tokio = { version = "1", optional = true, features = ["io-util", "sync", "rt", "macros", "net", "fs"] }
tokio-util = { version = "0.7.8", optional = true, features = ["codec"] }
Expand All @@ -46,7 +47,7 @@ url = "2.4.0"

[features]
default = ["net"]
net = ["dep:futures-lite", "dep:iroh-net", "dep:tokio", "dep:tokio-util"]
net = ["dep:futures-lite", "dep:futures-util", "dep:iroh-net", "dep:tokio", "dep:tokio-util"]

[[example]]
name = "chat"
Expand Down
170 changes: 93 additions & 77 deletions iroh-gossip/src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut};
use futures_lite::stream::Stream;
use futures_lite::{stream::Stream, StreamExt};
use futures_util::future::FutureExt;
use genawaiter::sync::{Co, Gen};
use iroh_net::{
dialer::Dialer,
endpoint::{get_remote_node_id, Connection},
conn_manager::{ConnDirection, ConnInfo, ConnManager},
endpoint::Connection,
key::PublicKey,
AddrInfo, Endpoint, NodeAddr,
};
Expand All @@ -15,7 +16,7 @@ use rand_core::SeedableRng;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc, task::Poll, time::Instant};
use tokio::{
sync::{broadcast, mpsc, oneshot},
task::JoinHandle,
task::{JoinHandle, JoinSet},
};
use tracing::{debug, error_span, trace, warn, Instrument};

Expand Down Expand Up @@ -82,7 +83,7 @@ impl Gossip {
/// Spawn a gossip actor and get a handle for it
pub fn from_endpoint(endpoint: Endpoint, config: proto::Config, my_addr: &AddrInfo) -> Self {
let peer_id = endpoint.node_id();
let dialer = Dialer::new(endpoint.clone());
let conn_manager = ConnManager::new(endpoint.clone(), GOSSIP_ALPN);
let state = proto::State::new(
peer_id,
encode_peer_data(my_addr).unwrap(),
Expand All @@ -97,12 +98,12 @@ impl Gossip {
let actor = Actor {
endpoint,
state,
dialer,
conn_manager,
conn_tasks: Default::default(),
to_actor_rx,
in_event_rx,
in_event_tx,
on_endpoints_rx,
conns: Default::default(),
conn_send_tx: Default::default(),
pending_sends: Default::default(),
timers: Timers::new(),
Expand Down Expand Up @@ -231,9 +232,7 @@ impl Gossip {
///
/// Make sure to check the ALPN protocol yourself before passing the connection.
pub async fn handle_connection(&self, conn: Connection) -> anyhow::Result<()> {
let peer_id = get_remote_node_id(&conn)?;
self.send(ToActor::ConnIncoming(peer_id, ConnOrigin::Accept, conn))
.await?;
self.send(ToActor::ConnIncoming(conn)).await?;
Ok(())
}

Expand Down Expand Up @@ -283,19 +282,11 @@ impl Future for JoinTopicFut {
}
}

/// Whether a connection is initiated by us (Dial) or by the remote peer (Accept)
#[derive(Debug)]
enum ConnOrigin {
/// The remote peer initiated the connection; the connection loop waits for the
/// peer to open the bidirectional stream (`accept_bi`).
Accept,
/// We initiated the connection; the connection loop opens the bidirectional
/// stream itself (`open_bi`).
Dial,
}

/// Input messages for the gossip [`Actor`].
#[derive(derive_more::Debug)]
enum ToActor {
/// Handle a new QUIC connection, either from accept (external to the actor) or from connect
/// (happens internally in the actor).
ConnIncoming(PublicKey, ConnOrigin, #[debug(skip)] Connection),
/// Handle a new incoming QUIC connection.
ConnIncoming(iroh_net::endpoint::Connection),
/// Join a topic with a list of peers. Reply with oneshot once at least one peer joined.
Join(
TopicId,
Expand Down Expand Up @@ -329,8 +320,8 @@ struct Actor {
/// Protocol state
state: proto::State<PublicKey, StdRng>,
endpoint: Endpoint,
/// Dial machine to connect to peers
dialer: Dialer,
/// Connection manager to dial and accept connections.
conn_manager: ConnManager,
/// Input messages to the actor
to_actor_rx: mpsc::Receiver<ToActor>,
/// Sender for the state input (cloned into the connection loops)
Expand All @@ -341,10 +332,10 @@ struct Actor {
on_endpoints_rx: mpsc::Receiver<Vec<iroh_net::config::Endpoint>>,
/// Queued timers
timers: Timers<Timer>,
/// Currently opened quinn connections to peers
conns: HashMap<PublicKey, Connection>,
/// Channels to send outbound messages into the connection loops
conn_send_tx: HashMap<PublicKey, mpsc::Sender<ProtoMessage>>,
/// Connection loop tasks
conn_tasks: JoinSet<(PublicKey, anyhow::Result<()>)>,
/// Queued messages that were to be sent before a dial completed
pending_sends: HashMap<PublicKey, Vec<ProtoMessage>>,
/// Broadcast senders for active topic subscriptions from the application
Expand All @@ -353,6 +344,12 @@ struct Actor {
subscribers_all: Option<broadcast::Sender<(TopicId, Event)>>,
}

/// Aborts every spawned connection loop task when the actor is dropped, so no
/// connection loop outlives the actor that owns its channels.
///
/// NOTE(review): tokio documents that dropping a `JoinSet` aborts all of its
/// tasks, so this impl looks redundant with dropping the `conn_tasks` field;
/// it does make the shutdown behavior explicit. Confirm before removing.
impl Drop for Actor {
fn drop(&mut self) {
self.conn_tasks.abort_all();
}
}

impl Actor {
pub async fn run(mut self) -> anyhow::Result<()> {
let mut i = 0;
Expand Down Expand Up @@ -384,15 +381,27 @@ impl Actor {
}
}
}
(peer_id, res) = self.dialer.next_conn() => {
trace!(?i, "tick: dialer");
Some(res) = self.conn_manager.next() => {
trace!(?i, "tick: conn_manager");
match res {
Ok(conn) => {
debug!(peer = ?peer_id, "dial successful");
self.handle_to_actor_msg(ToActor::ConnIncoming(peer_id, ConnOrigin::Dial, conn), Instant::now()).await.context("dialer.next -> conn -> handle_to_actor_msg")?;
}
Ok(conn) => self.handle_new_connection(conn).await,
Err(err) => {
warn!(peer = ?peer_id, "dial failed: {err}");
self.handle_in_event(InEvent::PeerDisconnected(err.node_id), Instant::now()).await?;
}
}
}
Some(res) = self.conn_tasks.join_next(), if !self.conn_tasks.is_empty() => {
match res {
Err(err) if !err.is_cancelled() => warn!(?err, "connection loop panicked"),
Err(_err) => {},
Ok((node_id, result)) => {
self.conn_manager.remove(&node_id);
self.conn_send_tx.remove(&node_id);
self.handle_in_event(InEvent::PeerDisconnected(node_id), Instant::now()).await?;
match result {
Ok(()) => debug!(peer=%node_id.fmt_short(), "connection closed without error"),
Err(err) => debug!(peer=%node_id.fmt_short(), "connection closed with error {err:?}"),
}
}
}
}
Expand Down Expand Up @@ -421,38 +430,9 @@ impl Actor {
async fn handle_to_actor_msg(&mut self, msg: ToActor, now: Instant) -> anyhow::Result<()> {
trace!("handle to_actor {msg:?}");
match msg {
ToActor::ConnIncoming(peer_id, origin, conn) => {
self.conns.insert(peer_id, conn.clone());
self.dialer.abort_dial(&peer_id);
let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP);
self.conn_send_tx.insert(peer_id, send_tx.clone());

// Spawn a task for this connection
let in_event_tx = self.in_event_tx.clone();
tokio::spawn(
async move {
debug!("connection established");
match connection_loop(peer_id, conn, origin, send_rx, &in_event_tx).await {
Ok(()) => {
debug!("connection closed without error")
}
Err(err) => {
debug!("connection closed with error {err:?}")
}
}
in_event_tx
.send(InEvent::PeerDisconnected(peer_id))
.await
.ok();
}
.instrument(error_span!("gossip_conn", peer = %peer_id.fmt_short())),
);

// Forward queued pending sends
if let Some(send_queue) = self.pending_sends.remove(&peer_id) {
for msg in send_queue {
send_tx.send(msg).await?;
}
ToActor::ConnIncoming(conn) => {
if let Err(err) = self.conn_manager.accept(conn) {
warn!(?err, "failed to accept connection");
}
}
ToActor::Join(topic_id, peers, reply) => {
Expand Down Expand Up @@ -502,9 +482,6 @@ impl Actor {
} else {
debug!("handle in_event {event:?}");
};
if let InEvent::PeerDisconnected(peer) = &event {
self.conn_send_tx.remove(peer);
}
let out = self.state.handle(event, now);
for event in out {
if matches!(event, OutEvent::ScheduleTimer(_, _)) {
Expand All @@ -518,10 +495,13 @@ impl Actor {
if let Err(_err) = send.send(message).await {
warn!("conn receiver for {peer_id:?} dropped");
self.conn_send_tx.remove(&peer_id);
self.conn_manager.remove(&peer_id);
}
} else {
debug!(peer = ?peer_id, "dial");
self.dialer.queue_dial(peer_id, GOSSIP_ALPN);
if !self.conn_manager.is_pending(&peer_id) {
debug!(peer = ?peer_id, "dial");
self.conn_manager.dial(peer_id);
}
// TODO: Enforce max length
self.pending_sends.entry(peer_id).or_default().push(message);
}
Expand All @@ -544,12 +524,11 @@ impl Actor {
self.timers.insert(now + delay, timer);
}
OutEvent::DisconnectPeer(peer) => {
if let Some(conn) = self.conns.remove(&peer) {
conn.close(0u8.into(), b"close from disconnect");
}
self.conn_send_tx.remove(&peer);
self.pending_sends.remove(&peer);
self.dialer.abort_dial(&peer);
if let Some(conn) = self.conn_manager.remove(&peer) {
conn.close(0u8.into(), b"close from disconnect");
}
}
OutEvent::PeerData(node_id, data) => match decode_peer_data(&data) {
Err(err) => warn!("Failed to decode {data:?} from {node_id}: {err}"),
Expand All @@ -566,6 +545,33 @@ impl Actor {
Ok(())
}

/// Registers a connection handed out by the connection manager and starts its
/// connection loop.
///
/// Steps performed:
/// 1. Creates the bounded outbound message channel for the peer and stores the
///    sender in `conn_send_tx`, replacing any sender previously stored for the
///    same node id.
/// 2. Takes the messages queued in `pending_sends` for this peer (if any) and
///    hands them to the connection loop, which writes them before entering its
///    main loop.
/// 3. Spawns `connection_loop` onto `conn_tasks`. The task output is tagged
///    with the node id so that the actor can tell which peer a finished task
///    belonged to.
///
/// The function contains no await points; it stays `async` so existing
/// `.await` call sites keep compiling.
async fn handle_new_connection(&mut self, new_conn: ConnInfo) {
    let ConnInfo {
        conn,
        node_id,
        direction,
    } = new_conn;
    let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP);
    // The sender is only ever used through the map, so move it in rather than
    // cloning it and dropping the original at the end of the function.
    self.conn_send_tx.insert(node_id, send_tx);

    // Spawn a task for this connection
    let pending_sends = self.pending_sends.remove(&node_id);
    let in_event_tx = self.in_event_tx.clone();
    debug!(peer=%node_id.fmt_short(), ?direction, "connection established");
    self.conn_tasks.spawn(
        connection_loop(
            node_id,
            conn,
            direction,
            send_rx,
            in_event_tx,
            pending_sends,
        )
        // Pair the loop's result with the node id for the join handler.
        .map(move |r| (node_id, r))
        .instrument(error_span!("gossip_conn", peer = %node_id.fmt_short())),
    );
}

fn subscribe_all(&mut self) -> broadcast::Receiver<(TopicId, Event)> {
if let Some(tx) = self.subscribers_all.as_mut() {
tx.subscribe()
Expand Down Expand Up @@ -602,16 +608,26 @@ async fn wait_for_neighbor_up(mut sub: broadcast::Receiver<Event>) -> anyhow::Re
async fn connection_loop(
from: PublicKey,
conn: Connection,
origin: ConnOrigin,
direction: ConnDirection,
mut send_rx: mpsc::Receiver<ProtoMessage>,
in_event_tx: &mpsc::Sender<InEvent>,
in_event_tx: mpsc::Sender<InEvent>,
mut pending_sends: Option<Vec<ProtoMessage>>,
) -> anyhow::Result<()> {
let (mut send, mut recv) = match origin {
ConnOrigin::Accept => conn.accept_bi().await?,
ConnOrigin::Dial => conn.open_bi().await?,
let (mut send, mut recv) = match direction {
ConnDirection::Accept => conn.accept_bi().await?,
ConnDirection::Dial => conn.open_bi().await?,
};
let mut send_buf = BytesMut::new();
let mut recv_buf = BytesMut::new();

// Forward queued pending sends
if let Some(mut send_queue) = pending_sends.take() {
for msg in send_queue.drain(..) {
write_message(&mut send, &mut send_buf, &msg).await?;
}
}

// loop over sending and receiving messages
loop {
tokio::select! {
biased;
Expand Down
5 changes: 3 additions & 2 deletions iroh-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ quinn = { package = "iroh-quinn", version = "0.10.4" }
quinn-proto = { package = "iroh-quinn-proto", version = "0.10.7" }
quinn-udp = { package = "iroh-quinn-udp", version = "0.4" }
rand = "0.8"
rand_chacha = { version = "0.3.1", optional = true }
rand_core = "0.6.4"
rcgen = "0.11"
reqwest = { version = "0.12.4", default-features = false, features = ["rustls-tls"] }
reqwest = { version = "0.11.19", default-features = false, features = ["rustls-tls"] }
ring = "0.17"
rustls = { version = "0.21.11", default-features = false, features = ["dangerous_configuration"] }
serde = { version = "1", features = ["derive", "rc"] }
Expand Down Expand Up @@ -125,7 +126,7 @@ duct = "0.13.6"
default = ["metrics"]
iroh-relay = ["clap", "toml", "rustls-pemfile", "regex", "serde_with", "tracing-subscriber"]
metrics = ["iroh-metrics/metrics"]
test-utils = ["axum"]
test-utils = ["axum", "rand_chacha"]

[[bin]]
name = "iroh-relay"
Expand Down
Loading

0 comments on commit c8265ee

Please sign in to comment.