diff --git a/Cargo.lock b/Cargo.lock index bcd232d..4cbbe8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -451,6 +451,17 @@ dependencies = [ "syn", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "digest" version = "0.9.0" @@ -1747,6 +1758,7 @@ dependencies = [ "clap", "console", "der", + "derivative", "ed25519-dalek", "futures", "hex", @@ -1765,6 +1777,7 @@ dependencies = [ "testdir", "thiserror", "tokio", + "tokio-context", "tokio-util", "tracing", "tracing-subscriber", @@ -2103,6 +2116,15 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "tokio-context" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaf0b8394dd5ca9a1b726c629390154c19222dfd7467a4b56f1ced90adee3958" +dependencies = [ + "tokio", +] + [[package]] name = "tokio-macros" version = "1.8.2" diff --git a/Cargo.toml b/Cargo.toml index 966f72e..8614f79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ bytes = "1" clap = { version = "4", features = ["derive"], optional = true } console = { version = "0.15.5", optional = true } der = { version = "0.6", features = ["alloc", "derive"] } +derivative = "2.2.0" ed25519-dalek = { version = "1.0.1", features = ["serde"] } futures = "0.3.25" indicatif = { version = "0.17", features = ["tokio"], optional = true } @@ -33,6 +34,7 @@ ssh-key = { version = "0.5.1", features = ["ed25519", "std", "rand_core"] } tempfile = "3" thiserror = "1" tokio = { version = "1", features = ["full"] } +tokio-context = "0.1.3" tokio-util = { version = "0.7", features = ["io-util", "io"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/lib.rs b/src/lib.rs index e847b27..8c310c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -240,10 +240,9 @@ mod tests { ) .await?; - provider.abort(); - let _ = provider.join().await; + provider.shutdown().await?; - let events = events_task.await.unwrap(); + let events = events_task.await?; assert_eq!(events.len(), 3); Ok(()) diff --git a/src/main.rs b/src/main.rs index 428b3f8..50203cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -204,7 +204,10 @@ async fn main() -> Result<()> { out_writer .println(format!("All-in-one ticket: {}", provider.ticket(hash))) .await; - provider.join().await?; + + tokio::signal::ctrl_c().await?; + out_writer.println("Shutting down...").await; + provider.shutdown().await?; // Drop tempath to signal it can be destroyed drop(tmp_path); diff --git a/src/provider.rs b/src/provider.rs index 868ec22..5aa438c 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -5,15 +5,17 @@ use std::path::PathBuf; use std::str::FromStr; use std::{collections::HashMap, sync::Arc}; -use anyhow::{anyhow, bail, ensure, Context, Result}; +use anyhow::{anyhow, bail, ensure, Context as _, Result}; use bao::encode::SliceExtractor; use bytes::{Bytes, BytesMut}; +use derivative::Derivative; use s2n_quic::stream::BidirectionalStream; -use s2n_quic::Server as QuicServer; +use s2n_quic::{Connection, Server as QuicServer}; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::broadcast; use tokio::task::{JoinError, JoinHandle}; +use tokio_context::context::{Context, Handle as ContextHandle, RefContext}; use tokio_util::io::SyncIoBridge; use tracing::{debug, warn}; @@ -103,10 +105,10 @@ impl Builder { let db2 = self.db.clone(); let (events_sender, _events_receiver) = broadcast::channel(8); let events = events_sender.clone(); - let task = - tokio::spawn( - async move { Self::run(server, db2, self.auth_token, events_sender).await }, - ); + let (ctx, handle) = RefContext::new(); + let task = tokio::task::spawn(async move { + Self::run(ctx, server, db2, self.auth_token, events_sender).await + }); Ok(Provider { listen_addr, @@ -114,36 +116,77 @@ impl Builder { auth_token: self.auth_token, task, events, + handle, }) } async fn run( + ctx: RefContext, mut server: s2n_quic::server::Server, db: Database, token: AuthToken, events: broadcast::Sender, ) { debug!("\nlistening at: {:#?}", server.local_addr().unwrap()); + let (mut current_ctx, _handle) = Context::with_parent(&ctx, None); + + loop { + tokio::select! { + _ = current_ctx.done() => { + return; + } - while let Some(mut connection) = server.accept().await { - let db = db.clone(); - let events = events.clone(); - tokio::spawn(async move { - debug!("connection accepted from {:?}", connection.remote_addr()); - while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await { - let _ = events.send(Event::ClientConnected { - connection_id: connection.id(), - }); + Some(connection) = server.accept() => { let db = db.clone(); let events = events.clone(); - tokio::spawn(async move { - if let Err(err) = handle_stream(db, token, stream, events).await { - warn!("error: {:#?}", err); - } - debug!("disconnected"); - }); + let (current_ctx, _handle) = RefContext::with_parent(&ctx, None); + tokio::spawn(async move { handle_connection(current_ctx, connection, db, token, events).await }); } - }); + } + } + } +} + +async fn handle_connection( + ctx: RefContext, + mut connection: Connection, + db: Database, + token: AuthToken, + events: broadcast::Sender, +) { + debug!("connection accepted from {:?}", connection.remote_addr()); + let (mut current_ctx, _handle) = Context::with_parent(&ctx, None); + loop { + tokio::select! { + biased; + + _ = current_ctx.done() => { + break; + } + Ok(Some(stream)) = connection.accept_bidirectional_stream() => { + let _ = events.send(Event::ClientConnected { + connection_id: connection.id(), + }); + let db = db.clone(); + let events = events.clone(); + let (mut current_ctx, _handle) = Context::with_parent(&ctx, None); + + tokio::spawn(async move { + tokio::select! { + biased; + + _ = current_ctx.done() => { + return; + } + res = handle_stream(db, token, stream, events) => { + if let Err(err) = res { + warn!("error: {:#?}", err); + } + } + } + debug!("disconnected"); + }); + } } } } @@ -156,13 +199,16 @@ impl Builder { /// is a shorthand to create a suitable [`Builder`]. /// /// This runs a tokio task which can be aborted and joined if desired. -#[derive(Debug)] +#[derive(Derivative)] +#[derivative(Debug)] pub struct Provider { listen_addr: SocketAddr, keypair: Keypair, auth_token: AuthToken, task: JoinHandle<()>, events: broadcast::Sender, + #[derivative(Debug = "ignore")] + handle: ContextHandle, } /// Events emitted by the [`Provider`] informing about the current status. @@ -228,18 +274,11 @@ impl Provider { } } - /// Blocks until the provider task completes. - // TODO: Maybe implement Future directly? - pub async fn join(self) -> Result<(), JoinError> { + /// Gracefully shuts down the provider. + pub async fn shutdown(self) -> Result<(), JoinError> { + self.handle.cancel(); self.task.await } - - /// Aborts the provider. - /// - /// TODO: temporary, do graceful shutdown instead. - pub fn abort(&self) { - self.task.abort(); - } } async fn handle_stream(