diff --git a/Cargo.lock b/Cargo.lock index bcd232d..bce8f20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -383,6 +383,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctrlc" +version = "3.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1631ca6e3c59112501a9d87fd86f21591ff77acd31331e8a73f8d80a65bbdd71" +dependencies = [ + "nix", + "windows-sys", +] + [[package]] name = "cuckoofilter" version = "0.5.0" @@ -945,6 +955,18 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "nix" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "static_assertions", +] + [[package]] name = "nom" version = "7.1.3" @@ -1746,6 +1768,7 @@ dependencies = [ "bytes", "clap", "console", + "ctrlc", "der", "ed25519-dalek", "futures", @@ -1941,6 +1964,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strsim" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index 966f72e..90b98b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ blake3 = "1.3.3" bytes = "1" clap = { version = "4", features = ["derive"], optional = true } console = { version = "0.15.5", optional = true } +ctrlc = "3.2.4" der = { version = "0.6", features = ["alloc", "derive"] } ed25519-dalek = { version = "1.0.1", features = ["serde"] } futures = "0.3.25" diff --git a/src/lib.rs b/src/lib.rs index e847b27..8c310c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -240,10 +240,9 @@ mod tests { ) .await?; - provider.abort(); - let _ = provider.join().await; + provider.shutdown().await?; - let events = events_task.await.unwrap(); + let events = events_task.await?; assert_eq!(events.len(), 3); Ok(()) diff --git a/src/main.rs b/src/main.rs index 428b3f8..e043881 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ use indicatif::{ use sendme::protocol::AuthToken; use sendme::provider::Ticket; use tokio::io::AsyncWriteExt; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, Mutex}; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use sendme::{get, provider, util, Keypair, PeerId}; @@ -204,7 +204,12 @@ async fn main() -> Result<()> { out_writer .println(format!("All-in-one ticket: {}", provider.ticket(hash))) .await; - provider.join().await?; + + let (s, mut r) = mpsc::channel(1); + ctrlc::set_handler(move || s.try_send(()).expect("failed to send shutdown signal"))?; + r.recv().await; + out_writer.println("Shutting down...").await; + provider.shutdown().await?; // Drop tempath to signal it can be destroyed drop(tmp_path); diff --git a/src/provider.rs b/src/provider.rs index 868ec22..3034c28 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -12,7 +12,7 @@ use s2n_quic::stream::BidirectionalStream; use s2n_quic::Server as QuicServer; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, oneshot}; use tokio::task::{JoinError, JoinHandle}; use tokio_util::io::SyncIoBridge; use tracing::{debug, warn}; @@ -103,10 +103,17 @@ impl Builder { let db2 = self.db.clone(); let (events_sender, _events_receiver) = broadcast::channel(8); let events = events_sender.clone(); - let task = - tokio::spawn( - async move { Self::run(server, db2, self.auth_token, events_sender).await }, - ); + let (shutdown_sender, shutdown_receiver) = oneshot::channel(); + let task = tokio::spawn(async move { + Self::run( + server, + db2, + self.auth_token, + events_sender, + shutdown_receiver, + ) + .await + }); Ok(Provider { listen_addr, @@ -114,6 +121,7 @@ impl Builder { auth_token: self.auth_token, task, events, + shutdown: shutdown_sender, }) } @@ -122,28 +130,36 @@ impl Builder { db: Database, token: AuthToken, events: broadcast::Sender, + mut shutdown: oneshot::Receiver<()>, ) { debug!("\nlistening at: {:#?}", server.local_addr().unwrap()); - while let Some(mut connection) = server.accept().await { - let db = db.clone(); - let events = events.clone(); - tokio::spawn(async move { - debug!("connection accepted from {:?}", connection.remote_addr()); - while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await { - let _ = events.send(Event::ClientConnected { - connection_id: connection.id(), - }); + loop { + tokio::select! { + Some(mut connection) = server.accept() => { let db = db.clone(); let events = events.clone(); tokio::spawn(async move { - if let Err(err) = handle_stream(db, token, stream, events).await { - warn!("error: {:#?}", err); + debug!("connection accepted from {:?}", connection.remote_addr()); + while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await { + let _ = events.send(Event::ClientConnected { + connection_id: connection.id(), + }); + let db = db.clone(); + let events = events.clone(); + tokio::spawn(async move { + if let Err(err) = handle_stream(db, token, stream, events).await { + warn!("error: {:#?}", err); + } + debug!("disconnected"); + }); } - debug!("disconnected"); }); } - }); + _ = &mut shutdown => { + break; + } + } } } } @@ -163,6 +179,7 @@ pub struct Provider { auth_token: AuthToken, task: JoinHandle<()>, events: broadcast::Sender, + shutdown: oneshot::Sender<()>, } /// Events emitted by the [`Provider`] informing about the current status. @@ -228,18 +245,11 @@ impl Provider { } } - /// Blocks until the provider task completes. - // TODO: Maybe implement Future directly? - pub async fn join(self) -> Result<(), JoinError> { + /// Gracefully shuts down the provider. + pub async fn shutdown(self) -> Result<(), JoinError> { + let _ = self.shutdown.send(()); self.task.await } - - /// Aborts the provider. - /// - /// TODO: temporary, do graceful shutdown instead. - pub fn abort(&self) { - self.task.abort(); - } } async fn handle_stream(