feat(iroh-net): allow the underlying UdpSockets to be rebound (#2946)
## Description

In order to handle suspension and exits on mobile, we need to rebind our
UDP sockets when they break.

This PR adds the ability to rebind the socket on errors, and does so
automatically on known suspension errors for iOS.

When reviewing this, please pay particular attention to how long locks are
held, as this is the most sensitive part of this code.
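
As a rough illustration only (this is not the actual netwatch implementation, and it assumes the suspension error surfaces as `ENOTCONN`, per the references below): the inner socket lives behind a read-write lock, sends take the read lock only for the duration of the syscall, and a rebind takes the write lock only long enough to swap in a fresh socket bound to the same local address.

```rust
// Minimal sketch of the rebind-on-error idea; NOT the actual netwatch code.
// Assumes plain std sockets and that the suspension error is `ENOTCONN`.
use std::io;
use std::net::{SocketAddr, UdpSocket};
use std::sync::RwLock;

struct RebindingSocket {
    addr: SocketAddr,
    // Readers hold this only for a single syscall; the writer holds it only
    // while swapping in the fresh socket.
    inner: RwLock<UdpSocket>,
}

impl RebindingSocket {
    fn bind(addr: SocketAddr) -> io::Result<Self> {
        let sock = UdpSocket::bind(addr)?;
        // Remember the concrete local address so a rebind reuses the same port.
        let addr = sock.local_addr()?;
        Ok(Self {
            addr,
            inner: RwLock::new(sock),
        })
    }

    fn send_to(&self, buf: &[u8], dst: SocketAddr) -> io::Result<usize> {
        // Take the read lock only for the send itself; the guard is dropped
        // before the result is inspected.
        let res = self.inner.read().unwrap().send_to(buf, dst);
        match res {
            Err(ref err) if err.kind() == io::ErrorKind::NotConnected => {
                // Known suspension error: rebind, then retry once.
                self.rebind()?;
                self.inner.read().unwrap().send_to(buf, dst)
            }
            other => other,
        }
    }

    fn rebind(&self) -> io::Result<()> {
        // Short write lock: held only while the new socket replaces the old one.
        let mut guard = self.inner.write().unwrap();
        *guard = UdpSocket::bind(self.addr)?;
        Ok(())
    }
}
```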


Some references for these errors:

- libevent/libevent#1031
- #2939

### TODOs

- [x] code cleanup
- [x] testing on actual iOS apps to see if this actually fixes the issues
- [ ] potentially handle the port still being in use? This needs some more
thought.

Closes #2939

## Breaking Changes

The overall API for `netwatch::UdpSocket` has changed entirely; everything
else is the same.
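
For orientation, here is a rough sketch of how the reworked socket is used elsewhere in this diff. The `poll_writable` and `rebind` calls are taken from the hunks below; their exact signatures live in `netwatch` and may differ, and the helper functions here are purely illustrative.

```rust
// Sketch of the call surface as it appears in this diff; exact signatures are
// defined in the netwatch crate and may differ.
use std::io;
use std::sync::Arc;
use std::task::{Context, Poll};

use netwatch::UdpSocket;

/// Quinn-style writability check, as used by the magicsock `IoPoller` below.
fn poll_send_ready(sock: &UdpSocket, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
    sock.poll_writable(cx)
}

/// Rebind the socket in place, e.g. after a major link change; existing
/// `Arc<UdpSocket>` handles keep working afterwards.
fn recover(sock: &Arc<UdpSocket>) {
    if let Err(err) = sock.rebind() {
        tracing::warn!("failed to rebind UDP socket: {:?}", err);
    }
}
```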

## Notes & open questions

- I have tried putting this logic higher in the stack, but unfortunately
that did not work out.
- We might not want to rebind a socket indefinitely if the same error
happens over and over again; it is unclear how to handle this (a hypothetical
shape is sketched below).
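
To make that open question concrete, one purely hypothetical way to bound rebinds is sketched here; nothing like this is implemented in this PR. It caps consecutive rebind attempts and enforces a backoff between them.

```rust
// Hypothetical sketch only; not part of this PR.
use std::io;
use std::time::{Duration, Instant};

struct RebindBudget {
    max_attempts: u32,
    attempts: u32,
    backoff: Duration,
    last_attempt: Option<Instant>,
}

impl RebindBudget {
    fn new(max_attempts: u32, backoff: Duration) -> Self {
        Self {
            max_attempts,
            attempts: 0,
            backoff,
            last_attempt: None,
        }
    }

    /// Called before each rebind attempt; a successful send afterwards would
    /// call `reset` so the budget starts over.
    fn allow_rebind(&mut self) -> io::Result<()> {
        if let Some(last) = self.last_attempt {
            if last.elapsed() < self.backoff {
                return Err(io::Error::new(
                    io::ErrorKind::WouldBlock,
                    "rebind backoff still active",
                ));
            }
        }
        if self.attempts >= self.max_attempts {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "giving up after repeated rebind attempts",
            ));
        }
        self.attempts += 1;
        self.last_attempt = Some(Instant::now());
        Ok(())
    }

    fn reset(&mut self) {
        self.attempts = 0;
        self.last_attempt = None;
    }
}
```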


## Change checklist

- [ ] Self-review.
- [ ] Documentation updates following the [style
guide](https://rust-lang.github.io/rfcs/1574-more-api-documentation-conventions.html#appendix-a-full-conventions-text),
if relevant.
- [ ] Tests if relevant.
- [ ] All breaking changes documented.

---------

Co-authored-by: Philipp Krüger <[email protected]>
dignifiedquire and matheus23 authored Nov 26, 2024
1 parent 4abfd61 commit cc9e4e6
Showing 8 changed files with 953 additions and 187 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock


2 changes: 1 addition & 1 deletion iroh-net-report/src/reportgen/hairpin.rs
@@ -121,7 +121,7 @@ impl Actor {
.context("net_report actor gone")?;
msg_response_rx.await.context("net_report actor died")?;

if let Err(err) = socket.send_to(&stun::request(txn), dst).await {
if let Err(err) = socket.send_to(&stun::request(txn), dst.into()).await {
warn!(%dst, "failed to send hairpin check");
return Err(err.into());
}
90 changes: 64 additions & 26 deletions iroh-net/src/magicsock.rs
@@ -36,7 +36,7 @@ use futures_util::stream::BoxStream;
use iroh_base::key::NodeId;
use iroh_metrics::{inc, inc_by};
use iroh_relay::protos::stun;
use netwatch::{interfaces, ip::LocalAddresses, netmon};
use netwatch::{interfaces, ip::LocalAddresses, netmon, UdpSocket};
use quinn::AsyncUdpSocket;
use rand::{seq::SliceRandom, Rng, SeedableRng};
use smallvec::{smallvec, SmallVec};
@@ -441,11 +441,8 @@ impl MagicSock {
// Right now however we have one single poller behaving the same for each
// connection. It checks all paths and returns Poll::Ready as soon as any path is
// ready.
let ipv4_poller = Arc::new(self.pconn4.clone()).create_io_poller();
let ipv6_poller = self
.pconn6
.as_ref()
.map(|sock| Arc::new(sock.clone()).create_io_poller());
let ipv4_poller = self.pconn4.create_io_poller();
let ipv6_poller = self.pconn6.as_ref().map(|sock| sock.create_io_poller());
let relay_sender = self.relay_actor_sender.clone();
Box::pin(IoPoller {
ipv4_poller,
@@ -1091,10 +1088,9 @@ impl MagicSock {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
// This is the socket .try_send_disco_message_udp used.
let sock = self.conn_for_addr(dst)?;
let sock = Arc::new(sock.clone());
let mut poller = sock.create_io_poller();
match poller.as_mut().poll_writable(cx)? {
Poll::Ready(()) => continue,
match sock.as_socket_ref().poll_writable(cx) {
Poll::Ready(Ok(())) => continue,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
}
@@ -1408,6 +1404,9 @@ impl Handle {
let net_reporter =
net_report::Client::new(Some(port_mapper.clone()), dns_resolver.clone())?;

let pconn4_sock = pconn4.as_socket();
let pconn6_sock = pconn6.as_ref().map(|p| p.as_socket());

let (actor_sender, actor_receiver) = mpsc::channel(256);
let (relay_actor_sender, relay_actor_receiver) = mpsc::channel(256);
let (udp_disco_sender, mut udp_disco_receiver) = mpsc::channel(256);
@@ -1431,9 +1430,9 @@
ipv6_reported: Arc::new(AtomicBool::new(false)),
relay_map,
my_relay: Default::default(),
pconn4: pconn4.clone(),
pconn6: pconn6.clone(),
net_reporter: net_reporter.addr(),
pconn4,
pconn6,
disco_secrets: DiscoSecrets::default(),
node_map,
relay_actor_sender: relay_actor_sender.clone(),
@@ -1481,8 +1480,8 @@ impl Handle {
periodic_re_stun_timer: new_re_stun_timer(false),
net_info_last: None,
port_mapper,
pconn4,
pconn6,
pconn4: pconn4_sock,
pconn6: pconn6_sock,
no_v4_send: false,
net_reporter,
network_monitor,
@@ -1720,8 +1719,8 @@ struct Actor {
net_info_last: Option<NetInfo>,

// The underlying UDP sockets used to send/rcv packets.
pconn4: UdpConn,
pconn6: Option<UdpConn>,
pconn4: Arc<UdpSocket>,
pconn6: Option<Arc<UdpSocket>>,

/// The NAT-PMP/PCP/UPnP prober/client, for requesting port mappings from NAT devices.
port_mapper: portmapper::Client,
@@ -1861,6 +1860,14 @@ impl Actor {
debug!("link change detected: major? {}", is_major);

if is_major {
if let Err(err) = self.pconn4.rebind() {
warn!("failed to rebind Udp IPv4 socket: {:?}", err);
};
if let Some(ref pconn6) = self.pconn6 {
if let Err(err) = pconn6.rebind() {
warn!("failed to rebind Udp IPv6 socket: {:?}", err);
};
}
self.msock.dns_resolver.clear_cache();
self.msock.re_stun("link-change-major");
self.close_stale_relay_connections().await;
@@ -1893,14 +1900,6 @@
self.port_mapper.deactivate();
self.relay_actor_cancel_token.cancel();

// Ignore errors from pconnN
// They will frequently have been closed already by a call to connBind.Close.
debug!("stopping connections");
if let Some(ref conn) = self.pconn6 {
conn.close().await.ok();
}
self.pconn4.close().await.ok();

debug!("shutdown complete");
return true;
}
@@ -2206,8 +2205,8 @@ impl Actor {
}

let relay_map = self.msock.relay_map.clone();
let pconn4 = Some(self.pconn4.as_socket());
let pconn6 = self.pconn6.as_ref().map(|p| p.as_socket());
let pconn4 = Some(self.pconn4.clone());
let pconn6 = self.pconn6.clone();

debug!("requesting net_report report");
match self
@@ -3099,6 +3098,45 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_regression_network_change_rebind_wakes_connection_driver(
) -> testresult::TestResult {
let _ = iroh_test::logging::setup();
let m1 = MagicStack::new(RelayMode::Disabled).await?;
let m2 = MagicStack::new(RelayMode::Disabled).await?;

println!("Net change");
m1.endpoint.magic_sock().force_network_change(true).await;
tokio::time::sleep(Duration::from_secs(1)).await; // wait for socket rebinding

let _guard = mesh_stacks(vec![m1.clone(), m2.clone()]).await?;

let _handle = AbortOnDropHandle::new(tokio::spawn({
let endpoint = m2.endpoint.clone();
async move {
while let Some(incoming) = endpoint.accept().await {
println!("Incoming first conn!");
let conn = incoming.await?;
conn.closed().await;
}

testresult::TestResult::Ok(())
}
}));

println!("first conn!");
let conn = m1
.endpoint
.connect(m2.endpoint.node_addr().await?, ALPN)
.await?;
println!("Closing first conn");
conn.close(0u32.into(), b"bye lolz");
conn.closed().await;
println!("Closed first conn");

Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn test_two_devices_roundtrip_network_change() -> Result<()> {
time::timeout(
114 changes: 25 additions & 89 deletions iroh-net/src/magicsock/udp_conn.rs
@@ -1,69 +1,57 @@
use std::{
fmt::Debug,
future::Future,
io,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
task::{Context, Poll},
};

use anyhow::{bail, Context as _};
use netwatch::UdpSocket;
use quinn::AsyncUdpSocket;
use quinn_udp::{Transmit, UdpSockRef};
use tokio::io::Interest;
use tracing::{debug, trace};
use quinn_udp::Transmit;
use tracing::debug;

/// A UDP socket implementing Quinn's [`AsyncUdpSocket`].
#[derive(Clone, Debug)]
#[derive(Debug, Clone)]
pub struct UdpConn {
io: Arc<UdpSocket>,
inner: Arc<quinn_udp::UdpSocketState>,
}

impl UdpConn {
pub(super) fn as_socket(&self) -> Arc<UdpSocket> {
self.io.clone()
}

pub(super) fn as_socket_ref(&self) -> &UdpSocket {
&self.io
}

pub(super) fn bind(addr: SocketAddr) -> anyhow::Result<Self> {
let sock = bind(addr)?;
let state = quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&sock))?;
Ok(Self {
io: Arc::new(sock),
inner: Arc::new(state),
})

Ok(Self { io: Arc::new(sock) })
}

pub fn port(&self) -> u16 {
self.local_addr().map(|p| p.port()).unwrap_or_default()
}

#[allow(clippy::unused_async)]
pub async fn close(&self) -> Result<(), io::Error> {
// Nothing to do atm
Ok(())
pub(super) fn create_io_poller(&self) -> Pin<Box<dyn quinn::UdpPoller>> {
Box::pin(IoPoller {
io: self.io.clone(),
})
}
}

impl AsyncUdpSocket for UdpConn {
fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn quinn::UdpPoller>> {
let sock = self.io.clone();
Box::pin(IoPoller {
next_waiter: move || {
let sock = sock.clone();
async move { sock.writable().await }
},
waiter: None,
})
(*self).create_io_poller()
}

fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> {
self.io.try_io(Interest::WRITABLE, || {
let sock_ref = UdpSockRef::from(&self.io);
self.inner.send(sock_ref, transmit)
})
self.io.try_send_quinn(transmit)
}

fn poll_recv(
@@ -72,40 +60,23 @@ impl AsyncUdpSocket for UdpConn {
bufs: &mut [io::IoSliceMut<'_>],
meta: &mut [quinn_udp::RecvMeta],
) -> Poll<io::Result<usize>> {
loop {
ready!(self.io.poll_recv_ready(cx))?;
if let Ok(res) = self.io.try_io(Interest::READABLE, || {
self.inner.recv(Arc::as_ref(&self.io).into(), bufs, meta)
}) {
for meta in meta.iter().take(res) {
trace!(
src = %meta.addr,
len = meta.len,
count = meta.len / meta.stride,
dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(),
"UDP recv"
);
}

return Poll::Ready(Ok(res));
}
}
self.io.poll_recv_quinn(cx, bufs, meta)
}

fn local_addr(&self) -> io::Result<SocketAddr> {
self.io.local_addr()
}

fn may_fragment(&self) -> bool {
self.inner.may_fragment()
self.io.may_fragment()
}

fn max_transmit_segments(&self) -> usize {
self.inner.max_gso_segments()
self.io.max_gso_segments()
}

fn max_receive_segments(&self) -> usize {
self.inner.gro_segments()
self.io.gro_segments()
}
}

@@ -147,49 +118,14 @@ fn bind(mut addr: SocketAddr) -> anyhow::Result<UdpSocket> {
}

/// Poller for when the socket is writable.
///
/// The tricky part is that we only have `tokio::net::UdpSocket::writable()` to create the
/// waiter we need, which does not return a named future type. In order to be able to store
/// this waiter in a struct without boxing we need to specify the future itself as a type
/// parameter, which we can only do if we introduce a second type parameter which returns
/// the future. So we end up with a function which we do not need, but it makes the types
/// work.
#[derive(derive_more::Debug)]
#[pin_project::pin_project]
struct IoPoller<F, Fut>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = io::Result<()>> + Send + Sync + 'static,
{
/// Function which can create a new waiter if there is none.
#[debug("next_waiter")]
next_waiter: F,
/// The waiter which tells us when the socket is writable.
#[debug("waiter")]
#[pin]
waiter: Option<Fut>,
#[derive(Debug)]
struct IoPoller {
io: Arc<UdpSocket>,
}

impl<F, Fut> quinn::UdpPoller for IoPoller<F, Fut>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = io::Result<()>> + Send + Sync + 'static,
{
impl quinn::UdpPoller for IoPoller {
fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let mut this = self.project();
if this.waiter.is_none() {
this.waiter.set(Some((this.next_waiter)()));
}
let result = this
.waiter
.as_mut()
.as_pin_mut()
.expect("just set")
.poll(cx);
if result.is_ready() {
this.waiter.set(None);
}
result
self.io.poll_writable(cx)
}
}
