From e97ee780d7482c5b8bdc981d622bb105ab245a38 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Thu, 2 Feb 2023 18:52:08 +0100 Subject: [PATCH 1/3] feat(provider): graceful shutdown This stops accepting new connections, but does not currently abort transfers. Closes #77 --- Cargo.lock | 29 ++++++++++++++++++++++ Cargo.toml | 1 + src/lib.rs | 5 ++-- src/main.rs | 9 +++++-- src/provider.rs | 66 ++++++++++++++++++++++++++++--------------------- 5 files changed, 77 insertions(+), 33 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bcd232d..bce8f20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -383,6 +383,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctrlc" +version = "3.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1631ca6e3c59112501a9d87fd86f21591ff77acd31331e8a73f8d80a65bbdd71" +dependencies = [ + "nix", + "windows-sys", +] + [[package]] name = "cuckoofilter" version = "0.5.0" @@ -945,6 +955,18 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "nix" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "static_assertions", +] + [[package]] name = "nom" version = "7.1.3" @@ -1746,6 +1768,7 @@ dependencies = [ "bytes", "clap", "console", + "ctrlc", "der", "ed25519-dalek", "futures", @@ -1941,6 +1964,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strsim" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index 966f72e..90b98b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ blake3 = "1.3.3" bytes = "1" clap = { version = "4", features = ["derive"], optional = true } console = { version = "0.15.5", optional = true } +ctrlc = "3.2.4" der = { version = "0.6", features = ["alloc", "derive"] } ed25519-dalek = { version = "1.0.1", features = ["serde"] } futures = "0.3.25" diff --git a/src/lib.rs b/src/lib.rs index e847b27..8c310c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -240,10 +240,9 @@ mod tests { ) .await?; - provider.abort(); - let _ = provider.join().await; + provider.shutdown().await?; - let events = events_task.await.unwrap(); + let events = events_task.await?; assert_eq!(events.len(), 3); Ok(()) diff --git a/src/main.rs b/src/main.rs index 428b3f8..e043881 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ use indicatif::{ use sendme::protocol::AuthToken; use sendme::provider::Ticket; use tokio::io::AsyncWriteExt; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, Mutex}; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use sendme::{get, provider, util, Keypair, PeerId}; @@ -204,7 +204,12 @@ async fn main() -> Result<()> { out_writer .println(format!("All-in-one ticket: {}", provider.ticket(hash))) .await; - provider.join().await?; + + let (s, mut r) = mpsc::channel(1); + ctrlc::set_handler(move || s.try_send(()).expect("failed to send shutdown signal"))?; + r.recv().await; + out_writer.println("Shutting down...").await; + provider.shutdown().await?; // Drop tempath to signal it can be destroyed drop(tmp_path); diff --git a/src/provider.rs b/src/provider.rs index 868ec22..3034c28 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -12,7 +12,7 @@ use s2n_quic::stream::BidirectionalStream; use s2n_quic::Server as QuicServer; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, oneshot}; use tokio::task::{JoinError, JoinHandle}; use tokio_util::io::SyncIoBridge; use tracing::{debug, warn}; @@ -103,10 +103,17 @@ impl Builder { let db2 = self.db.clone(); let (events_sender, _events_receiver) = broadcast::channel(8); let events = events_sender.clone(); - let task = - tokio::spawn( - async move { Self::run(server, db2, self.auth_token, events_sender).await }, - ); + let (shutdown_sender, shutdown_receiver) = oneshot::channel(); + let task = tokio::spawn(async move { + Self::run( + server, + db2, + self.auth_token, + events_sender, + shutdown_receiver, + ) + .await + }); Ok(Provider { listen_addr, @@ -114,6 +121,7 @@ impl Builder { auth_token: self.auth_token, task, events, + shutdown: shutdown_sender, }) } @@ -122,28 +130,36 @@ impl Builder { db: Database, token: AuthToken, events: broadcast::Sender, + mut shutdown: oneshot::Receiver<()>, ) { debug!("\nlistening at: {:#?}", server.local_addr().unwrap()); - while let Some(mut connection) = server.accept().await { - let db = db.clone(); - let events = events.clone(); - tokio::spawn(async move { - debug!("connection accepted from {:?}", connection.remote_addr()); - while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await { - let _ = events.send(Event::ClientConnected { - connection_id: connection.id(), - }); + loop { + tokio::select! { + Some(mut connection) = server.accept() => { let db = db.clone(); let events = events.clone(); tokio::spawn(async move { - if let Err(err) = handle_stream(db, token, stream, events).await { - warn!("error: {:#?}", err); + debug!("connection accepted from {:?}", connection.remote_addr()); + while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await { + let _ = events.send(Event::ClientConnected { + connection_id: connection.id(), + }); + let db = db.clone(); + let events = events.clone(); + tokio::spawn(async move { + if let Err(err) = handle_stream(db, token, stream, events).await { + warn!("error: {:#?}", err); + } + debug!("disconnected"); + }); } - debug!("disconnected"); }); } - }); + _ = &mut shutdown => { + break; + } + } } } } @@ -163,6 +179,7 @@ pub struct Provider { auth_token: AuthToken, task: JoinHandle<()>, events: broadcast::Sender, + shutdown: oneshot::Sender<()>, } /// Events emitted by the [`Provider`] informing about the current status. @@ -228,18 +245,11 @@ impl Provider { } } - /// Blocks until the provider task completes. - // TODO: Maybe implement Future directly? - pub async fn join(self) -> Result<(), JoinError> { + /// Gracefully shuts down the provider. + pub async fn shutdown(self) -> Result<(), JoinError> { + let _ = self.shutdown.send(()); self.task.await } - - /// Aborts the provider. - /// - /// TODO: temporary, do graceful shutdown instead. - pub fn abort(&self) { - self.task.abort(); - } } async fn handle_stream( From 06963a8a2861926c2e7d6560a5812d660805dd9f Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Fri, 3 Feb 2023 11:44:18 +0100 Subject: [PATCH 2/3] apply some CR --- Cargo.lock | 29 ----------------------------- Cargo.toml | 1 - src/main.rs | 6 ++---- src/provider.rs | 9 ++++++--- 4 files changed, 8 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bce8f20..bcd232d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -383,16 +383,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "ctrlc" -version = "3.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1631ca6e3c59112501a9d87fd86f21591ff77acd31331e8a73f8d80a65bbdd71" -dependencies = [ - "nix", - "windows-sys", -] - [[package]] name = "cuckoofilter" version = "0.5.0" @@ -955,18 +945,6 @@ dependencies = [ "windows-sys 0.42.0", ] -[[package]] -name = "nix" -version = "0.26.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" -dependencies = [ - "bitflags", - "cfg-if", - "libc", - "static_assertions", -] - [[package]] name = "nom" version = "7.1.3" @@ -1768,7 +1746,6 @@ dependencies = [ "bytes", "clap", "console", - "ctrlc", "der", "ed25519-dalek", "futures", @@ -1964,12 +1941,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "strsim" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index 90b98b3..966f72e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,6 @@ blake3 = "1.3.3" bytes = "1" clap = { version = "4", features = ["derive"], optional = true } console = { version = "0.15.5", optional = true } -ctrlc = "3.2.4" der = { version = "0.6", features = ["alloc", "derive"] } ed25519-dalek = { version = "1.0.1", features = ["serde"] } futures = "0.3.25" diff --git a/src/main.rs b/src/main.rs index e043881..50203cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ use indicatif::{ use sendme::protocol::AuthToken; use sendme::provider::Ticket; use tokio::io::AsyncWriteExt; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::Mutex; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use sendme::{get, provider, util, Keypair, PeerId}; @@ -205,9 +205,7 @@ async fn main() -> Result<()> { .println(format!("All-in-one ticket: {}", provider.ticket(hash))) .await; - let (s, mut r) = mpsc::channel(1); - ctrlc::set_handler(move || s.try_send(()).expect("failed to send shutdown signal"))?; - r.recv().await; + tokio::signal::ctrl_c().await?; out_writer.println("Shutting down...").await; provider.shutdown().await?; diff --git a/src/provider.rs b/src/provider.rs index 3034c28..40ed762 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -136,6 +136,12 @@ impl Builder { loop { tokio::select! { + biased; + + _ = &mut shutdown => { + break; + } + Some(mut connection) = server.accept() => { let db = db.clone(); let events = events.clone(); @@ -156,9 +162,6 @@ impl Builder { } }); } - _ = &mut shutdown => { - break; - } } } } From e7be7eccf07adf6870c3a0427ee764e352719b70 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Fri, 3 Feb 2023 12:27:40 +0100 Subject: [PATCH 3/3] refactor: fully shutdown nested tasks --- Cargo.lock | 22 ++++++++++ Cargo.toml | 2 + src/provider.rs | 104 ++++++++++++++++++++++++++++++------------------ 3 files changed, 89 insertions(+), 39 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bcd232d..4cbbe8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -451,6 +451,17 @@ dependencies = [ "syn", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "digest" version = "0.9.0" @@ -1747,6 +1758,7 @@ dependencies = [ "clap", "console", "der", + "derivative", "ed25519-dalek", "futures", "hex", @@ -1765,6 +1777,7 @@ dependencies = [ "testdir", "thiserror", "tokio", + "tokio-context", "tokio-util", "tracing", "tracing-subscriber", @@ -2103,6 +2116,15 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "tokio-context" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaf0b8394dd5ca9a1b726c629390154c19222dfd7467a4b56f1ced90adee3958" +dependencies = [ + "tokio", +] + [[package]] name = "tokio-macros" version = "1.8.2" diff --git a/Cargo.toml b/Cargo.toml index 966f72e..8614f79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ bytes = "1" clap = { version = "4", features = ["derive"], optional = true } console = { version = "0.15.5", optional = true } der = { version = "0.6", features = ["alloc", "derive"] } +derivative = "2.2.0" ed25519-dalek = { version = "1.0.1", features = ["serde"] } futures = "0.3.25" indicatif = { version = "0.17", features = ["tokio"], optional = true } @@ -33,6 +34,7 @@ ssh-key = { version = "0.5.1", features = ["ed25519", "std", "rand_core"] } tempfile = "3" thiserror = "1" tokio = { version = "1", features = ["full"] } +tokio-context = "0.1.3" tokio-util = { version = "0.7", features = ["io-util", "io"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/provider.rs b/src/provider.rs index 40ed762..5aa438c 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -5,15 +5,17 @@ use std::path::PathBuf; use std::str::FromStr; use std::{collections::HashMap, sync::Arc}; -use anyhow::{anyhow, bail, ensure, Context, Result}; +use anyhow::{anyhow, bail, ensure, Context as _, Result}; use bao::encode::SliceExtractor; use bytes::{Bytes, BytesMut}; +use derivative::Derivative; use s2n_quic::stream::BidirectionalStream; -use s2n_quic::Server as QuicServer; +use s2n_quic::{Connection, Server as QuicServer}; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use tokio::sync::{broadcast, oneshot}; +use tokio::sync::broadcast; use tokio::task::{JoinError, JoinHandle}; +use tokio_context::context::{Context, Handle as ContextHandle, RefContext}; use tokio_util::io::SyncIoBridge; use tracing::{debug, warn}; @@ -103,16 +105,9 @@ impl Builder { let db2 = self.db.clone(); let (events_sender, _events_receiver) = broadcast::channel(8); let events = events_sender.clone(); - let (shutdown_sender, shutdown_receiver) = oneshot::channel(); - let task = tokio::spawn(async move { - Self::run( - server, - db2, - self.auth_token, - events_sender, - shutdown_receiver, - ) - .await + let (ctx, handle) = RefContext::new(); + let task = tokio::task::spawn(async move { + Self::run(ctx, server, db2, self.auth_token, events_sender).await }); Ok(Provider { @@ -121,52 +116,81 @@ impl Builder { auth_token: self.auth_token, task, events, - shutdown: shutdown_sender, + handle, }) } async fn run( + ctx: RefContext, mut server: s2n_quic::server::Server, db: Database, token: AuthToken, events: broadcast::Sender, - mut shutdown: oneshot::Receiver<()>, ) { debug!("\nlistening at: {:#?}", server.local_addr().unwrap()); + let (mut current_ctx, _handle) = Context::with_parent(&ctx, None); loop { tokio::select! { - biased; - - _ = &mut shutdown => { - break; + _ = current_ctx.done() => { + return; } - Some(mut connection) = server.accept() => { + Some(connection) = server.accept() => { let db = db.clone(); let events = events.clone(); - tokio::spawn(async move { - debug!("connection accepted from {:?}", connection.remote_addr()); - while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await { - let _ = events.send(Event::ClientConnected { - connection_id: connection.id(), - }); - let db = db.clone(); - let events = events.clone(); - tokio::spawn(async move { - if let Err(err) = handle_stream(db, token, stream, events).await { - warn!("error: {:#?}", err); - } - debug!("disconnected"); - }); - } - }); + let (current_ctx, _handle) = RefContext::with_parent(&ctx, None); + tokio::spawn(async move { handle_connection(current_ctx, connection, db, token, events).await }); } } } } } +async fn handle_connection( + ctx: RefContext, + mut connection: Connection, + db: Database, + token: AuthToken, + events: broadcast::Sender, +) { + debug!("connection accepted from {:?}", connection.remote_addr()); + let (mut current_ctx, _handle) = Context::with_parent(&ctx, None); + loop { + tokio::select! { + biased; + + _ = current_ctx.done() => { + break; + } + Ok(Some(stream)) = connection.accept_bidirectional_stream() => { + let _ = events.send(Event::ClientConnected { + connection_id: connection.id(), + }); + let db = db.clone(); + let events = events.clone(); + let (mut current_ctx, _handle) = Context::with_parent(&ctx, None); + + tokio::spawn(async move { + tokio::select! { + biased; + + _ = current_ctx.done() => { + return; + } + res = handle_stream(db, token, stream, events) => { + if let Err(err) = res { + warn!("error: {:#?}", err); + } + } + } + debug!("disconnected"); + }); + } + } + } +} + /// A server which implements the sendme provider. /// /// Clients can connect to this server and requests hashes from it. @@ -175,14 +199,16 @@ impl Builder { /// is a shorthand to create a suitable [`Builder`]. /// /// This runs a tokio task which can be aborted and joined if desired. -#[derive(Debug)] +#[derive(Derivative)] +#[derivative(Debug)] pub struct Provider { listen_addr: SocketAddr, keypair: Keypair, auth_token: AuthToken, task: JoinHandle<()>, events: broadcast::Sender, - shutdown: oneshot::Sender<()>, + #[derivative(Debug = "ignore")] + handle: ContextHandle, } /// Events emitted by the [`Provider`] informing about the current status. @@ -250,7 +276,7 @@ impl Provider { /// Gracefully shuts down the provider. pub async fn shutdown(self) -> Result<(), JoinError> { - let _ = self.shutdown.send(()); + self.handle.cancel(); self.task.await } }