Skip to content
This repository has been archived by the owner on Dec 13, 2023. It is now read-only.

Commit

Permalink
feat(provider): graceful shutdown
Browse files Browse the repository at this point in the history
This stops accepting new connections, but does not currently abort transfers.

Closes #77
  • Loading branch information
dignifiedquire committed Feb 3, 2023
1 parent 01c711a commit 563671e
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 33 deletions.
29 changes: 29 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ blake3 = "1.3.3"
bytes = "1"
clap = { version = "4", features = ["derive"] }
console = "0.15.5"
ctrlc = "3.2.4"
der = { version = "0.6", features = ["alloc", "derive"] }
ed25519-dalek = { version = "1.0.1", features = ["serde"] }
futures = "0.3.25"
Expand Down
5 changes: 2 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,9 @@ mod tests {
)
.await?;

provider.abort();
let _ = provider.join().await;
provider.shutdown().await?;

let events = events_task.await.unwrap();
let events = events_task.await?;
assert_eq!(events.len(), 3);

Ok(())
Expand Down
9 changes: 7 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use indicatif::{
use sendme::protocol::AuthToken;
use sendme::provider::Ticket;
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use tokio::sync::{mpsc, Mutex};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

use sendme::{get, provider, Keypair, PeerId};
Expand Down Expand Up @@ -204,7 +204,12 @@ async fn main() -> Result<()> {
out_writer
.println(format!("All-in-one ticket: {}", provider.ticket(hash)))
.await;
provider.join().await?;

let (s, mut r) = mpsc::channel(1);
ctrlc::set_handler(move || s.try_send(()).expect("failed to send shutdown signal"))?;
r.recv().await;
out_writer.println("Shutting down...").await;
provider.shutdown().await?;

// Drop tempath to signal it can be destroyed
drop(tmp_path);
Expand Down
66 changes: 38 additions & 28 deletions src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use s2n_quic::stream::BidirectionalStream;
use s2n_quic::Server as QuicServer;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::broadcast;
use tokio::sync::{broadcast, oneshot};
use tokio::task::{JoinError, JoinHandle};
use tokio_util::io::SyncIoBridge;
use tracing::{debug, warn};
Expand Down Expand Up @@ -102,17 +102,25 @@ impl Builder {
let db2 = self.db.clone();
let (events_sender, _events_receiver) = broadcast::channel(8);
let events = events_sender.clone();
let task =
tokio::spawn(
async move { Self::run(server, db2, self.auth_token, events_sender).await },
);
let (shutdown_sender, shutdown_receiver) = oneshot::channel();
let task = tokio::spawn(async move {
Self::run(
server,
db2,
self.auth_token,
events_sender,
shutdown_receiver,
)
.await
});

Ok(Provider {
listen_addr,
keypair: self.keypair,
auth_token: self.auth_token,
task,
events,
shutdown: shutdown_sender,
})
}

Expand All @@ -121,28 +129,36 @@ impl Builder {
db: Database,
token: AuthToken,
events: broadcast::Sender<Event>,
mut shutdown: oneshot::Receiver<()>,
) {
debug!("\nlistening at: {:#?}", server.local_addr().unwrap());

while let Some(mut connection) = server.accept().await {
let db = db.clone();
let events = events.clone();
tokio::spawn(async move {
debug!("connection accepted from {:?}", connection.remote_addr());
while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await {
let _ = events.send(Event::ClientConnected {
connection_id: connection.id(),
});
loop {
tokio::select! {
Some(mut connection) = server.accept() => {
let db = db.clone();
let events = events.clone();
tokio::spawn(async move {
if let Err(err) = handle_stream(db, token, stream, events).await {
warn!("error: {:#?}", err);
debug!("connection accepted from {:?}", connection.remote_addr());
while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await {
let _ = events.send(Event::ClientConnected {
connection_id: connection.id(),
});
let db = db.clone();
let events = events.clone();
tokio::spawn(async move {
if let Err(err) = handle_stream(db, token, stream, events).await {
warn!("error: {:#?}", err);
}
debug!("disconnected");
});
}
debug!("disconnected");
});
}
});
_ = &mut shutdown => {
break;
}
}
}
}
}
Expand All @@ -162,6 +178,7 @@ pub struct Provider {
auth_token: AuthToken,
task: JoinHandle<()>,
events: broadcast::Sender<Event>,
shutdown: oneshot::Sender<()>,
}

/// Events emitted by the [`Provider`] informing about the current status.
Expand Down Expand Up @@ -227,18 +244,11 @@ impl Provider {
}
}

/// Blocks until the provider task completes.
// TODO: Maybe implement Future directly?
pub async fn join(self) -> Result<(), JoinError> {
/// Gracefully shuts down the provider.
pub async fn shutdown(self) -> Result<(), JoinError> {
let _ = self.shutdown.send(());
self.task.await
}

/// Aborts the provider.
///
/// TODO: temporary, do graceful shutdown instead.
pub fn abort(&self) {
self.task.abort();
}
}

async fn handle_stream(
Expand Down

0 comments on commit 563671e

Please sign in to comment.