Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] revisiting shutdown #2205

120 changes: 103 additions & 17 deletions iroh-net/src/magic_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use anyhow::{anyhow, bail, ensure, Context, Result};
use derive_more::Debug;
use futures::StreamExt;
use quinn_proto::VarInt;
use tokio::sync::watch;
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
use tracing::{debug, trace};

Expand All @@ -15,7 +16,7 @@ use crate::{
discovery::{Discovery, DiscoveryTask},
dns::{default_resolver, DnsResolver},
key::{PublicKey, SecretKey},
magicsock::{self, ConnectionTypeStream, MagicSock},
magicsock::{self, ConnectionTypeStream},
relay::{RelayMap, RelayMode, RelayUrl},
tls, NodeId,
};
Expand Down Expand Up @@ -226,14 +227,15 @@ pub fn make_server_config(
Ok(server_config)
}

/// An endpoint that leverages a [quinn::Endpoint] backed by a [magicsock::MagicSock].
/// An endpoint that leverages a [`quinn::Endpoint`] backed by a [`magicsock::MagicSock`].
#[derive(Clone, Debug)]
pub struct MagicEndpoint {
secret_key: Arc<SecretKey>,
msock: MagicSock,
msock: Arc<magicsock::MagicSockState>,
endpoint: quinn::Endpoint,
keylog: bool,
cancel_token: CancellationToken,
on_close_watcher: watch::Receiver<Result<()>>,
}

impl MagicEndpoint {
Expand All @@ -252,7 +254,7 @@ impl MagicEndpoint {
keylog: bool,
) -> Result<Self> {
let secret_key = msock_opts.secret_key.clone();
let msock = magicsock::MagicSock::new(msock_opts).await?;
let (quinn_sock, on_close_watcher) = QuinnSock::new(msock_opts).await?;
trace!("created magicsock");

let mut endpoint_config = quinn::EndpointConfig::default();
Expand All @@ -263,10 +265,12 @@ impl MagicEndpoint {
// the packet if grease_quic_bit is set to false.
endpoint_config.grease_quic_bit(false);

let msock = quinn_sock.msock.state();

let endpoint = quinn::Endpoint::new_with_abstract_socket(
endpoint_config,
server_config,
msock.clone(),
quinn_sock,
Arc::new(quinn::TokioRuntime),
)?;
trace!("created quinn endpoint");
Expand All @@ -277,6 +281,7 @@ impl MagicEndpoint {
endpoint,
keylog,
cancel_token: CancellationToken::new(),
on_close_watcher,
})
}

Expand Down Expand Up @@ -304,7 +309,7 @@ impl MagicEndpoint {
///
/// Returns a tuple of the IPv4 and the optional IPv6 address.
pub fn local_addr(&self) -> Result<(SocketAddr, Option<SocketAddr>)> {
self.msock.local_addr()
Ok(self.msock.local_addr())
}

/// Returns the local endpoints as a stream.
Expand All @@ -313,11 +318,11 @@ impl MagicEndpoint {
/// addresses it can listen on, for changes. Whenever changes are detected this stream
/// will yield a new list of endpoints.
///
/// Upon the first creation on the [`MagicSock`] it may not yet have completed a first
/// local endpoint discovery, in this case the first item of the stream will not be
/// immediately available. Once this first set of local endpoints are discovered the
/// stream will always return the first set of endpoints immediately, which are the most
/// recently discovered endpoints.
/// Upon the first creation on the [`magicsock::MagicSock`] it may not yet have completed a
/// first local endpoint discovery, in this case the first item of the stream will not be
/// immediately available. Once this first set of local endpoints are discovered the stream
/// will always return the first set of endpoints immediately, which are the most recently
/// discovered endpoints.
///
/// The list of endpoints yielded contains both the locally-bound addresses and the
/// endpoint's publicly-reachable addresses, if they could be discovered through STUN or
Expand Down Expand Up @@ -560,10 +565,28 @@ impl MagicEndpoint {
///
/// Returns an error if closing the magic socket failed.
/// TODO: Document error cases.
pub async fn close(&self, error_code: VarInt, reason: &[u8]) -> Result<()> {
self.cancel_token.cancel();
self.endpoint.close(error_code, reason);
self.msock.close().await?;
pub async fn close(self, error_code: VarInt, reason: &[u8]) -> Result<()> {
let Self {
endpoint,
cancel_token,
mut on_close_watcher,
..
} = self;
cancel_token.cancel();
endpoint.close(error_code, reason);
// this is necessary to make the quinn::Endpoint drop the underlying socket (MagicSock)
drop(endpoint);
on_close_watcher.changed().await.map_err(|_| {
// `changed` docs state:
// > This method returns an error if and only if the Sender is dropped.
// This would somehow mean drop was performed without performing close operations.
// Better to bubble up in case we see it
anyhow!("close triggered but result unknown")
})?;
on_close_watcher
.borrow()
.as_ref()
.map_err(|e| anyhow!("Closing MagicSock failed: {e}"))?;
Ok(())
}

Expand All @@ -582,15 +605,78 @@ impl MagicEndpoint {
}

#[cfg(test)]
pub(crate) fn magic_sock(&self) -> &MagicSock {
&self.msock
pub(crate) fn magic_sock_state(&self) -> Arc<magicsock::MagicSockState> {
self.msock.clone()
}
#[cfg(test)]
pub(crate) fn endpoint(&self) -> &quinn::Endpoint {
&self.endpoint
}
}

/// A [`magicsock::MagicSock`] handle for [`quinn::Endpoint`]
#[derive(Debug)]
struct QuinnSock {
msock: magicsock::MagicSock,
on_close_sender: Option<watch::Sender<Result<()>>>,
}

impl QuinnSock {
async fn new(
msock_opts: magicsock::Options,
) -> Result<(Self, watch::Receiver<anyhow::Result<()>>)> {
let msock = magicsock::MagicSock::new(msock_opts).await?;
let (tx, rx) = watch::channel(Ok(()));
let qs = QuinnSock {
msock,
on_close_sender: Some(tx),
};
Ok((qs, rx))
}
}

impl Drop for QuinnSock {
fn drop(&mut self) {
let on_close_sender = self
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could use a CancelToken for this, avoiding the Option dance

.on_close_sender
.take()
.expect("only taken on drop; not yet dropped");
if let Ok(rt) = tokio::runtime::Handle::try_current() {
let msock = self.msock.clone();
rt.spawn(async move {
let close_result = msock.close().await;
let _ = on_close_sender.send(close_result);
});
} else {
tracing::warn!("dropping Magisock outside an active tokio runtime");
}
}
}

impl quinn::AsyncUdpSocket for QuinnSock {
fn poll_send(
&self,
state: &quinn_udp::UdpState,
cx: &mut std::task::Context,
transmits: &[quinn_udp::Transmit],
) -> std::task::Poll<std::prelude::v1::Result<usize, std::io::Error>> {
self.msock.poll_send(state, cx, transmits)
}

fn poll_recv(
&self,
cx: &mut std::task::Context,
bufs: &mut [std::io::IoSliceMut<'_>],
meta: &mut [quinn_udp::RecvMeta],
) -> std::task::Poll<std::io::Result<usize>> {
self.msock.poll_recv(cx, bufs, meta)
}

fn local_addr(&self) -> std::io::Result<SocketAddr> {
quinn::AsyncUdpSocket::local_addr(&self.msock)
}
}

/// Accept an incoming connection and extract the client-provided [`PublicKey`] and ALPN protocol.
pub async fn accept_conn(
mut conn: quinn::Connecting,
Expand Down
Loading
Loading