From e7be7eccf07adf6870c3a0427ee764e352719b70 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Fri, 3 Feb 2023 12:27:40 +0100 Subject: [PATCH] refactor: fully shutdown nested tasks --- Cargo.lock | 22 ++++++++++ Cargo.toml | 2 + src/provider.rs | 104 ++++++++++++++++++++++++++++++------------------ 3 files changed, 89 insertions(+), 39 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bcd232d..4cbbe8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -451,6 +451,17 @@ dependencies = [ "syn", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "digest" version = "0.9.0" @@ -1747,6 +1758,7 @@ dependencies = [ "clap", "console", "der", + "derivative", "ed25519-dalek", "futures", "hex", @@ -1765,6 +1777,7 @@ dependencies = [ "testdir", "thiserror", "tokio", + "tokio-context", "tokio-util", "tracing", "tracing-subscriber", @@ -2103,6 +2116,15 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "tokio-context" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaf0b8394dd5ca9a1b726c629390154c19222dfd7467a4b56f1ced90adee3958" +dependencies = [ + "tokio", +] + [[package]] name = "tokio-macros" version = "1.8.2" diff --git a/Cargo.toml b/Cargo.toml index 966f72e..8614f79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ bytes = "1" clap = { version = "4", features = ["derive"], optional = true } console = { version = "0.15.5", optional = true } der = { version = "0.6", features = ["alloc", "derive"] } +derivative = "2.2.0" ed25519-dalek = { version = "1.0.1", features = ["serde"] } futures = "0.3.25" indicatif = { version = "0.17", features = ["tokio"], optional = true } @@ -33,6 +34,7 @@ ssh-key = { version = "0.5.1", features = ["ed25519", "std", "rand_core"] } tempfile = "3" thiserror = "1" tokio = { version = "1", features = ["full"] } +tokio-context = "0.1.3" tokio-util = { version = "0.7", features = ["io-util", "io"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/provider.rs b/src/provider.rs index 40ed762..5aa438c 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -5,15 +5,17 @@ use std::path::PathBuf; use std::str::FromStr; use std::{collections::HashMap, sync::Arc}; -use anyhow::{anyhow, bail, ensure, Context, Result}; +use anyhow::{anyhow, bail, ensure, Context as _, Result}; use bao::encode::SliceExtractor; use bytes::{Bytes, BytesMut}; +use derivative::Derivative; use s2n_quic::stream::BidirectionalStream; -use s2n_quic::Server as QuicServer; +use s2n_quic::{Connection, Server as QuicServer}; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use tokio::sync::{broadcast, oneshot}; +use tokio::sync::broadcast; use tokio::task::{JoinError, JoinHandle}; +use tokio_context::context::{Context, Handle as ContextHandle, RefContext}; use tokio_util::io::SyncIoBridge; use tracing::{debug, warn}; @@ -103,16 +105,9 @@ impl Builder { let db2 = self.db.clone(); let (events_sender, _events_receiver) = broadcast::channel(8); let events = events_sender.clone(); - let (shutdown_sender, shutdown_receiver) = oneshot::channel(); - let task = tokio::spawn(async move { - Self::run( - server, - db2, - self.auth_token, - events_sender, - shutdown_receiver, - ) - .await + let (ctx, handle) = RefContext::new(); + let task = tokio::task::spawn(async move { + Self::run(ctx, server, db2, self.auth_token, events_sender).await }); Ok(Provider { @@ -121,52 +116,81 @@ impl Builder { auth_token: self.auth_token, task, events, - shutdown: shutdown_sender, + handle, }) } async fn run( + ctx: RefContext, mut server: s2n_quic::server::Server, db: Database, token: AuthToken, events: broadcast::Sender, - mut shutdown: oneshot::Receiver<()>, ) { debug!("\nlistening at: {:#?}", server.local_addr().unwrap()); + let (mut current_ctx, _handle) = Context::with_parent(&ctx, None); loop { tokio::select! { - biased; - - _ = &mut shutdown => { - break; + _ = current_ctx.done() => { + return; } - Some(mut connection) = server.accept() => { + Some(connection) = server.accept() => { let db = db.clone(); let events = events.clone(); - tokio::spawn(async move { - debug!("connection accepted from {:?}", connection.remote_addr()); - while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await { - let _ = events.send(Event::ClientConnected { - connection_id: connection.id(), - }); - let db = db.clone(); - let events = events.clone(); - tokio::spawn(async move { - if let Err(err) = handle_stream(db, token, stream, events).await { - warn!("error: {:#?}", err); - } - debug!("disconnected"); - }); - } - }); + let (current_ctx, _handle) = RefContext::with_parent(&ctx, None); + tokio::spawn(async move { handle_connection(current_ctx, connection, db, token, events).await }); } } } } } +async fn handle_connection( + ctx: RefContext, + mut connection: Connection, + db: Database, + token: AuthToken, + events: broadcast::Sender, +) { + debug!("connection accepted from {:?}", connection.remote_addr()); + let (mut current_ctx, _handle) = Context::with_parent(&ctx, None); + loop { + tokio::select! { + biased; + + _ = current_ctx.done() => { + break; + } + Ok(Some(stream)) = connection.accept_bidirectional_stream() => { + let _ = events.send(Event::ClientConnected { + connection_id: connection.id(), + }); + let db = db.clone(); + let events = events.clone(); + let (mut current_ctx, _handle) = Context::with_parent(&ctx, None); + + tokio::spawn(async move { + tokio::select! { + biased; + + _ = current_ctx.done() => { + return; + } + res = handle_stream(db, token, stream, events) => { + if let Err(err) = res { + warn!("error: {:#?}", err); + } + } + } + debug!("disconnected"); + }); + } + } + } +} + /// A server which implements the sendme provider. /// /// Clients can connect to this server and requests hashes from it. @@ -175,14 +199,16 @@ impl Builder { /// is a shorthand to create a suitable [`Builder`]. /// /// This runs a tokio task which can be aborted and joined if desired. -#[derive(Debug)] +#[derive(Derivative)] +#[derivative(Debug)] pub struct Provider { listen_addr: SocketAddr, keypair: Keypair, auth_token: AuthToken, task: JoinHandle<()>, events: broadcast::Sender, - shutdown: oneshot::Sender<()>, + #[derivative(Debug = "ignore")] + handle: ContextHandle, } /// Events emitted by the [`Provider`] informing about the current status. @@ -250,7 +276,7 @@ impl Provider { /// Gracefully shuts down the provider. pub async fn shutdown(self) -> Result<(), JoinError> { - let _ = self.shutdown.send(()); + self.handle.cancel(); self.task.await } }