Skip to content
This repository has been archived by the owner on Dec 13, 2023. It is now read-only.

Commit

Permalink
refactor: fully shutdown nested tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Feb 3, 2023
1 parent 06963a8 commit e7be7ec
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 39 deletions.
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ bytes = "1"
clap = { version = "4", features = ["derive"], optional = true }
console = { version = "0.15.5", optional = true }
der = { version = "0.6", features = ["alloc", "derive"] }
derivative = "2.2.0"
ed25519-dalek = { version = "1.0.1", features = ["serde"] }
futures = "0.3.25"
indicatif = { version = "0.17", features = ["tokio"], optional = true }
Expand All @@ -33,6 +34,7 @@ ssh-key = { version = "0.5.1", features = ["ed25519", "std", "rand_core"] }
tempfile = "3"
thiserror = "1"
tokio = { version = "1", features = ["full"] }
tokio-context = "0.1.3"
tokio-util = { version = "0.7", features = ["io-util", "io"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
Expand Down
104 changes: 65 additions & 39 deletions src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ use std::path::PathBuf;
use std::str::FromStr;
use std::{collections::HashMap, sync::Arc};

use anyhow::{anyhow, bail, ensure, Context, Result};
use anyhow::{anyhow, bail, ensure, Context as _, Result};
use bao::encode::SliceExtractor;
use bytes::{Bytes, BytesMut};
use derivative::Derivative;
use s2n_quic::stream::BidirectionalStream;
use s2n_quic::Server as QuicServer;
use s2n_quic::{Connection, Server as QuicServer};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::{broadcast, oneshot};
use tokio::sync::broadcast;
use tokio::task::{JoinError, JoinHandle};
use tokio_context::context::{Context, Handle as ContextHandle, RefContext};
use tokio_util::io::SyncIoBridge;
use tracing::{debug, warn};

Expand Down Expand Up @@ -103,16 +105,9 @@ impl Builder {
let db2 = self.db.clone();
let (events_sender, _events_receiver) = broadcast::channel(8);
let events = events_sender.clone();
let (shutdown_sender, shutdown_receiver) = oneshot::channel();
let task = tokio::spawn(async move {
Self::run(
server,
db2,
self.auth_token,
events_sender,
shutdown_receiver,
)
.await
let (ctx, handle) = RefContext::new();
let task = tokio::task::spawn(async move {
Self::run(ctx, server, db2, self.auth_token, events_sender).await
});

Ok(Provider {
Expand All @@ -121,52 +116,81 @@ impl Builder {
auth_token: self.auth_token,
task,
events,
shutdown: shutdown_sender,
handle,
})
}

async fn run(
ctx: RefContext,
mut server: s2n_quic::server::Server,
db: Database,
token: AuthToken,
events: broadcast::Sender<Event>,
mut shutdown: oneshot::Receiver<()>,
) {
debug!("\nlistening at: {:#?}", server.local_addr().unwrap());
let (mut current_ctx, _handle) = Context::with_parent(&ctx, None);

loop {
tokio::select! {
biased;

_ = &mut shutdown => {
break;
_ = current_ctx.done() => {
return;
}

Some(mut connection) = server.accept() => {
Some(connection) = server.accept() => {
let db = db.clone();
let events = events.clone();
tokio::spawn(async move {
debug!("connection accepted from {:?}", connection.remote_addr());
while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await {
let _ = events.send(Event::ClientConnected {
connection_id: connection.id(),
});
let db = db.clone();
let events = events.clone();
tokio::spawn(async move {
if let Err(err) = handle_stream(db, token, stream, events).await {
warn!("error: {:#?}", err);
}
debug!("disconnected");
});
}
});
let (current_ctx, _handle) = RefContext::with_parent(&ctx, None);
tokio::spawn(async move { handle_connection(current_ctx, connection, db, token, events).await });
}
}
}
}
}

async fn handle_connection(
ctx: RefContext,
mut connection: Connection,
db: Database,
token: AuthToken,
events: broadcast::Sender<Event>,
) {
debug!("connection accepted from {:?}", connection.remote_addr());
let (mut current_ctx, _handle) = Context::with_parent(&ctx, None);
loop {
tokio::select! {
biased;

_ = current_ctx.done() => {
break;
}
Ok(Some(stream)) = connection.accept_bidirectional_stream() => {
let _ = events.send(Event::ClientConnected {
connection_id: connection.id(),
});
let db = db.clone();
let events = events.clone();
let (mut current_ctx, _handle) = Context::with_parent(&ctx, None);

tokio::spawn(async move {
tokio::select! {
biased;

_ = current_ctx.done() => {
return;
}
res = handle_stream(db, token, stream, events) => {
if let Err(err) = res {
warn!("error: {:#?}", err);
}
}
}
debug!("disconnected");
});
}
}
}
}

/// A server which implements the sendme provider.
///
/// Clients can connect to this server and requests hashes from it.
Expand All @@ -175,14 +199,16 @@ impl Builder {
/// is a shorthand to create a suitable [`Builder`].
///
/// This runs a tokio task which can be aborted and joined if desired.
#[derive(Debug)]
#[derive(Derivative)]
#[derivative(Debug)]
pub struct Provider {
listen_addr: SocketAddr,
keypair: Keypair,
auth_token: AuthToken,
task: JoinHandle<()>,
events: broadcast::Sender<Event>,
shutdown: oneshot::Sender<()>,
#[derivative(Debug = "ignore")]
handle: ContextHandle,
}

/// Events emitted by the [`Provider`] informing about the current status.
Expand Down Expand Up @@ -250,7 +276,7 @@ impl Provider {

/// Gracefully shuts down the provider.
pub async fn shutdown(self) -> Result<(), JoinError> {
let _ = self.shutdown.send(());
self.handle.cancel();
self.task.await
}
}
Expand Down

0 comments on commit e7be7ec

Please sign in to comment.