From b4506b2c4a288434ea55c36607f8fd839d58bf10 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=BCdiger=20Klaehn?=
Date: Mon, 22 Jul 2024 17:02:53 +0300
Subject: [PATCH] fix(iroh-blobs): properly handle Drop in local pool during shutdown (#2517)

## Description

The tokio_util `LocalPoolHandle` does not properly handle Drop during shutdown.
Its threads are spawned detached, so any drop impl that runs on a local pool
thread is cut short as soon as the process terminates. This can have bad
consequences if that drop operation performs IO, such as closing files or
committing database transactions.

Here is where the threads get spawned; the `std::thread::JoinHandle`s are just
dropped:
https://docs.rs/tokio-util/latest/src/tokio_util/task/spawn_pinned.rs.html#381

Here is some discussion of the observed effects:
https://discord.com/channels/949724860232392765/1260571544414064670

LocalPoolHandle also uses an unbounded channel:
https://docs.rs/tokio-util/latest/src/tokio_util/task/spawn_pinned.rs.html#372

## Breaking Changes

Public interfaces that used `tokio_util::task::LocalPoolHandle` now use our own
`LocalPool`/`LocalPoolHandle`.

## Notes & open questions

Should we use an unbounded channel, as `tokio::spawn` and
`LocalPoolHandle::spawn_pinned` do? That seems like a big footgun. But if not,
we need to handle the case where the queue is full.

## Change checklist

- [x] Self-review.
- [x] Documentation updates following the [style guide](https://rust-lang.github.io/rfcs/1574-more-api-documentation-conventions.html#appendix-a-full-conventions-text), if relevant.
- [x] Tests if relevant.
- [x] All breaking changes documented.
---
 Cargo.lock                           |   3 +-
 iroh-blobs/Cargo.toml                |   3 +-
 iroh-blobs/examples/provide-bytes.rs |   5 +-
 iroh-blobs/src/downloader.rs         |   6 +-
 iroh-blobs/src/downloader/test.rs    |  55 ++-
 iroh-blobs/src/provider.rs           |   4 +-
 iroh-blobs/src/store/bao_file.rs     |   7 +-
 iroh-blobs/src/store/traits.rs       |  11 +-
 iroh-blobs/src/util.rs               |   1 +
 iroh-blobs/src/util/local_pool.rs    | 654 +++++++++++++++++++++++++++
 iroh-cli/src/commands/start.rs       |   2 +-
 iroh/src/node.rs                     |  17 +-
 iroh/src/node/builder.rs             |  16 +-
 iroh/src/node/protocol.rs            |   5 +-
 iroh/src/node/rpc.rs                 |  25 +-
 iroh/tests/provide.rs                |   2 +-
 16 files changed, 752 insertions(+), 64 deletions(-)
 create mode 100644 iroh-blobs/src/util/local_pool.rs

diff --git a/Cargo.lock b/Cargo.lock
index 04af2e4c277..63778db1c5f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2564,6 +2564,7 @@ dependencies = [
  "iroh-test",
  "num_cpus",
  "parking_lot",
+ "pin-project",
  "postcard",
  "proptest",
  "rand",
@@ -5953,8 +5954,6 @@ dependencies = [
  "bytes",
  "futures-core",
  "futures-sink",
- "futures-util",
- "hashbrown 0.14.5",
  "pin-project-lite",
  "slab",
  "tokio",
diff --git a/iroh-blobs/Cargo.toml b/iroh-blobs/Cargo.toml
index 370a9929f37..a422d47351a 100644
--- a/iroh-blobs/Cargo.toml
+++ b/iroh-blobs/Cargo.toml
@@ -33,6 +33,7 @@ iroh-metrics = { version = "0.20.0", path = "../iroh-metrics", optional = true }
 iroh-net = { version = "0.20.0", path = "../iroh-net" }
 num_cpus = "1.15.0"
 parking_lot = { version = "0.12.1", optional = true }
+pin-project = "1.1.5"
 postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] }
 rand = "0.8"
 range-collections = "0.4.0"
@@ -45,7 +46,7 @@ smallvec = { version = "1.10.0", features = ["serde", "const_new"] }
 tempfile = { version = "3.10.0", optional = true }
 thiserror = "1"
 tokio = { version = "1", features = ["fs"] }
-tokio-util = { version = "0.7", features = ["io-util", "io", "rt"] }
+tokio-util
= { version = "0.7", features = ["io-util", "io"] } tracing = "0.1" tracing-futures = "0.2.5" diff --git a/iroh-blobs/examples/provide-bytes.rs b/iroh-blobs/examples/provide-bytes.rs index 73f7e6d8e39..eab9fddf5ab 100644 --- a/iroh-blobs/examples/provide-bytes.rs +++ b/iroh-blobs/examples/provide-bytes.rs @@ -10,11 +10,10 @@ //! cargo run --example provide-bytes collection //! To provide a collection (multiple blobs) use anyhow::Result; -use tokio_util::task::LocalPoolHandle; use tracing::warn; use tracing_subscriber::{prelude::*, EnvFilter}; -use iroh_blobs::{format::collection::Collection, Hash}; +use iroh_blobs::{format::collection::Collection, util::local_pool::LocalPool, Hash}; mod connect; use connect::{make_and_write_certs, make_server_endpoint, CERT_PATH}; @@ -82,7 +81,7 @@ async fn main() -> Result<()> { println!("\nfetch the content using a stream by running the following example:\n\ncargo run --example fetch-stream {hash} \"{addr}\" {format}\n"); // create a new local pool handle with 1 worker thread - let lp = LocalPoolHandle::new(1); + let lp = LocalPool::single(); let accept_task = tokio::spawn(async move { while let Some(incoming) = endpoint.accept().await { diff --git a/iroh-blobs/src/downloader.rs b/iroh-blobs/src/downloader.rs index 7d0eedd10b3..ca3d7a9b87c 100644 --- a/iroh-blobs/src/downloader.rs +++ b/iroh-blobs/src/downloader.rs @@ -45,13 +45,13 @@ use tokio::{ sync::{mpsc, oneshot}, task::JoinSet, }; -use tokio_util::{sync::CancellationToken, task::LocalPoolHandle, time::delay_queue}; +use tokio_util::{sync::CancellationToken, time::delay_queue}; use tracing::{debug, error_span, trace, warn, Instrument}; use crate::{ get::{db::DownloadProgress, Stats}, store::Store, - util::progress::ProgressSender, + util::{local_pool::LocalPoolHandle, progress::ProgressSender}, }; mod get; @@ -338,7 +338,7 @@ impl Downloader { service.run().instrument(error_span!("downloader", %me)) }; - rt.spawn_pinned(create_future); + rt.spawn_detached(create_future); Self { next_id: Arc::new(AtomicU64::new(0)), msg_tx, diff --git a/iroh-blobs/src/downloader/test.rs b/iroh-blobs/src/downloader/test.rs index ec54e0ef8c9..871b835ba7b 100644 --- a/iroh-blobs/src/downloader/test.rs +++ b/iroh-blobs/src/downloader/test.rs @@ -10,7 +10,10 @@ use iroh_net::key::SecretKey; use crate::{ get::{db::BlobId, progress::TransferState}, - util::progress::{FlumeProgressSender, IdGenerator}, + util::{ + local_pool::LocalPool, + progress::{FlumeProgressSender, IdGenerator}, + }, }; use super::*; @@ -23,7 +26,7 @@ impl Downloader { dialer: dialer::TestingDialer, getter: getter::TestingGetter, concurrency_limits: ConcurrencyLimits, - ) -> Self { + ) -> (Self, LocalPool) { Self::spawn_for_test_with_retry_config( dialer, getter, @@ -37,10 +40,11 @@ impl Downloader { getter: getter::TestingGetter, concurrency_limits: ConcurrencyLimits, retry_config: RetryConfig, - ) -> Self { + ) -> (Self, LocalPool) { let (msg_tx, msg_rx) = mpsc::channel(super::SERVICE_CHANNEL_CAPACITY); - LocalPoolHandle::new(1).spawn_pinned(move || async move { + let lp = LocalPool::default(); + lp.spawn_detached(move || async move { // we want to see the logs of the service let _guard = iroh_test::logging::setup(); @@ -48,10 +52,13 @@ impl Downloader { service.run().await }); - Downloader { - next_id: Arc::new(AtomicU64::new(0)), - msg_tx, - } + ( + Downloader { + next_id: Arc::new(AtomicU64::new(0)), + msg_tx, + }, + lp, + ) } } @@ -63,7 +70,8 @@ async fn smoke_test() { let getter = getter::TestingGetter::default(); let concurrency_limits = 
ConcurrencyLimits::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send a request and make sure the peer is requested the corresponding download let peer = SecretKey::generate().public(); @@ -88,7 +96,8 @@ async fn deduplication() { getter.set_request_duration(Duration::from_secs(1)); let concurrency_limits = ConcurrencyLimits::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); let peer = SecretKey::generate().public(); let kind: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into(); @@ -119,7 +128,8 @@ async fn cancellation() { getter.set_request_duration(Duration::from_millis(500)); let concurrency_limits = ConcurrencyLimits::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); let peer = SecretKey::generate().public(); let kind_1: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into(); @@ -158,7 +168,8 @@ async fn max_concurrent_requests_total() { ..Default::default() }; - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send the downloads let peer = SecretKey::generate().public(); @@ -201,7 +212,8 @@ async fn max_concurrent_requests_per_peer() { ..Default::default() }; - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send the downloads let peer = SecretKey::generate().public(); @@ -257,7 +269,8 @@ async fn concurrent_progress() { } .boxed() })); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); let peer = SecretKey::generate().public(); let hash = Hash::new([0u8; 32]); @@ -341,7 +354,8 @@ async fn long_queue() { ..Default::default() }; - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send the downloads let nodes = [ SecretKey::generate().public(), @@ -370,7 +384,8 @@ async fn fail_while_running() { let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); let blob_fail = HashAndFormat::raw(Hash::new([1u8; 32])); let blob_success = HashAndFormat::raw(Hash::new([2u8; 32])); @@ -407,7 +422,8 @@ async fn retry_nodes_simple() { let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); + let 
(downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); let node = SecretKey::generate().public(); let dial_attempts = Arc::new(AtomicUsize::new(0)); let dial_attempts2 = dial_attempts.clone(); @@ -432,7 +448,7 @@ async fn retry_nodes_fail() { max_retries_per_node: 3, }; - let downloader = Downloader::spawn_for_test_with_retry_config( + let (downloader, _lp) = Downloader::spawn_for_test_with_retry_config( dialer.clone(), getter.clone(), Default::default(), @@ -472,7 +488,8 @@ async fn retry_nodes_jump_queue() { ..Default::default() }; - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); let good_node = SecretKey::generate().public(); let bad_node = SecretKey::generate().public(); diff --git a/iroh-blobs/src/provider.rs b/iroh-blobs/src/provider.rs index 7fe4e13004a..54b25151583 100644 --- a/iroh-blobs/src/provider.rs +++ b/iroh-blobs/src/provider.rs @@ -13,13 +13,13 @@ use iroh_io::stats::{ use iroh_io::{AsyncSliceReader, AsyncStreamWriter, TokioStreamWriter}; use iroh_net::endpoint::{self, RecvStream, SendStream}; use serde::{Deserialize, Serialize}; -use tokio_util::task::LocalPoolHandle; use tracing::{debug, debug_span, info, trace, warn}; use tracing_futures::Instrument; use crate::hashseq::parse_hash_seq; use crate::protocol::{GetRequest, RangeSpec, Request}; use crate::store::*; +use crate::util::local_pool::LocalPoolHandle; use crate::util::Tag; use crate::{BlobFormat, Hash}; @@ -302,7 +302,7 @@ pub async fn handle_connection( }; events.send(Event::ClientConnected { connection_id }).await; let db = db.clone(); - rt.spawn_pinned(|| { + rt.spawn_detached(|| { async move { if let Err(err) = handle_stream(db, reader, writer).await { warn!("error: {err:#?}",); diff --git a/iroh-blobs/src/store/bao_file.rs b/iroh-blobs/src/store/bao_file.rs index 962a7240b7e..b94c1960348 100644 --- a/iroh-blobs/src/store/bao_file.rs +++ b/iroh-blobs/src/store/bao_file.rs @@ -878,7 +878,8 @@ mod tests { decode_response_into_batch, local, make_wire_data, random_test_data, trickle, validate, }; use tokio::task::JoinSet; - use tokio_util::task::LocalPoolHandle; + + use crate::util::local_pool::LocalPool; use super::*; @@ -957,7 +958,7 @@ mod tests { )), hash.into(), ); - let local = LocalPoolHandle::new(4); + let local = LocalPool::default(); let mut tasks = Vec::new(); for i in 0..4 { let file = handle.writer(); @@ -968,7 +969,7 @@ mod tests { .map(io::Result::Ok) .boxed(); let trickle = TokioStreamReader::new(tokio_util::io::StreamReader::new(trickle)); - let task = local.spawn_pinned(move || async move { + let task = local.spawn(move || async move { decode_response_into_batch(hash, IROH_BLOCK_SIZE, chunk_ranges, trickle, file).await }); tasks.push(task); diff --git a/iroh-blobs/src/store/traits.rs b/iroh-blobs/src/store/traits.rs index 762e511cd37..4d5162ac04e 100644 --- a/iroh-blobs/src/store/traits.rs +++ b/iroh-blobs/src/store/traits.rs @@ -12,12 +12,12 @@ use iroh_base::rpc::RpcError; use iroh_io::AsyncSliceReader; use serde::{Deserialize, Serialize}; use tokio::io::AsyncRead; -use tokio_util::task::LocalPoolHandle; use crate::{ hashseq::parse_hash_seq, protocol::RangeSpec, util::{ + local_pool::{self, LocalPool}, progress::{BoxedProgressSender, IdGenerator, ProgressSender}, Tag, }, @@ -423,7 +423,10 @@ async fn validate_impl( use futures_buffered::BufferedStreamExt; let validate_parallelism: usize = 
num_cpus::get(); - let lp = LocalPoolHandle::new(validate_parallelism); + let lp = LocalPool::new(local_pool::Config { + threads: validate_parallelism, + ..Default::default() + }); let complete = store.blobs().await?.collect::>>()?; let partial = store .partial_blobs() @@ -437,7 +440,7 @@ async fn validate_impl( .map(|hash| { let store = store.clone(); let tx = tx.clone(); - lp.spawn_pinned(move || async move { + lp.spawn(move || async move { let entry = store .get(&hash) .await? @@ -486,7 +489,7 @@ async fn validate_impl( .map(|hash| { let store = store.clone(); let tx = tx.clone(); - lp.spawn_pinned(move || async move { + lp.spawn(move || async move { let entry = store .get(&hash) .await? diff --git a/iroh-blobs/src/util.rs b/iroh-blobs/src/util.rs index be43dfaaffd..6b70d24c9a5 100644 --- a/iroh-blobs/src/util.rs +++ b/iroh-blobs/src/util.rs @@ -19,6 +19,7 @@ pub mod progress; pub use mem_or_file::MemOrFile; mod sparse_mem_file; pub use sparse_mem_file::SparseMemFile; +pub mod local_pool; /// A tag #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, From, Into)] diff --git a/iroh-blobs/src/util/local_pool.rs b/iroh-blobs/src/util/local_pool.rs new file mode 100644 index 00000000000..7473d70c872 --- /dev/null +++ b/iroh-blobs/src/util/local_pool.rs @@ -0,0 +1,654 @@ +//! A local task pool with proper shutdown +use futures_lite::FutureExt; +use std::{ + any::Any, + future::Future, + ops::Deref, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; +use tokio::{ + sync::{Notify, Semaphore}, + task::{JoinError, JoinSet, LocalSet}, +}; + +type BoxedFut = Pin>>; +type SpawnFn = Box BoxedFut + Send + 'static>; + +enum Message { + /// Create a new task and execute it locally + Execute(SpawnFn), + /// Shutdown the thread after finishing all tasks + Finish, +} + +/// A local task pool with proper shutdown +/// +/// Unlike +/// [`LocalPoolHandle`](https://docs.rs/tokio-util/latest/tokio_util/task/struct.LocalPoolHandle.html), +/// this pool will join all its threads when dropped, ensuring that all Drop +/// implementations are run to completion. +/// +/// On drop, this pool will immediately cancel all *tasks* that are currently +/// being executed, and will wait for all threads to finish executing their +/// loops before returning. This means that all drop implementations will be +/// able to run to completion before drop exits. +/// +/// On [`LocalPool::finish`], this pool will notify all threads to shut down, +/// and then wait for all threads to finish executing their loops before +/// returning. This means that all currently executing tasks will be allowed to +/// run to completion. +#[derive(Debug)] +pub struct LocalPool { + threads: Vec>, + shutdown_sem: Arc, + cancel_token: CancellationToken, + handle: LocalPoolHandle, +} + +impl Deref for LocalPool { + type Target = LocalPoolHandle; + + fn deref(&self) -> &Self::Target { + &self.handle + } +} + +/// A handle to a [`LocalPool`] +#[derive(Debug, Clone)] +pub struct LocalPoolHandle { + /// The sender half of the channel used to send tasks to the pool + send: flume::Sender, +} + +/// What to do when a panic occurs in a pool thread +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum PanicMode { + /// Log the panic and continue + /// + /// The panic will be re-thrown when the pool is dropped. + LogAndContinue, + /// Log the panic and immediately shut down the pool. + /// + /// The panic will be re-thrown when the pool is dropped. 
+ Shutdown, +} + +/// Local task pool configuration +#[derive(Clone, Debug)] +pub struct Config { + /// Number of threads in the pool + pub threads: usize, + /// Prefix for thread names + pub thread_name_prefix: &'static str, + /// Ignore panics in pool threads + pub panic_mode: PanicMode, +} + +impl Default for Config { + fn default() -> Self { + Self { + threads: num_cpus::get(), + thread_name_prefix: "local-pool", + panic_mode: PanicMode::Shutdown, + } + } +} + +impl Default for LocalPool { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl LocalPool { + /// Create a new local pool with a single std thread. + pub fn single() -> Self { + Self::new(Config { + threads: 1, + ..Default::default() + }) + } + + /// Create a new local pool with the given config. + /// + /// This will use the current tokio runtime handle, so it must be called + /// from within a tokio runtime. + pub fn new(config: Config) -> Self { + let Config { + threads, + thread_name_prefix, + panic_mode, + } = config; + let cancel_token = CancellationToken::new(); + let (send, recv) = flume::unbounded::(); + let shutdown_sem = Arc::new(Semaphore::new(0)); + let handle = tokio::runtime::Handle::current(); + let handles = (0..threads) + .map(|i| { + Self::spawn_pool_thread( + format!("{thread_name_prefix}-{i}"), + recv.clone(), + cancel_token.clone(), + panic_mode, + shutdown_sem.clone(), + handle.clone(), + ) + }) + .collect::>>() + .expect("invalid thread name"); + Self { + threads: handles, + handle: LocalPoolHandle { send }, + cancel_token, + shutdown_sem, + } + } + + /// Get a cheaply cloneable handle to the pool + /// + /// This is not strictly necessary since we implement deref for + /// LocalPoolHandle, but makes getting a handle more explicit. + pub fn handle(&self) -> &LocalPoolHandle { + &self.handle + } + + /// Spawn a new pool thread. + fn spawn_pool_thread( + thread_name: String, + recv: flume::Receiver, + cancel_token: CancellationToken, + panic_mode: PanicMode, + shutdown_sem: Arc, + handle: tokio::runtime::Handle, + ) -> std::io::Result> { + std::thread::Builder::new() + .name(thread_name) + .spawn(move || { + let mut s = JoinSet::new(); + let mut last_panic = None; + let mut handle_join = |res: Option>| -> bool { + if let Some(Err(e)) = res { + if let Ok(panic) = e.try_into_panic() { + let panic_info = get_panic_info(&panic); + let thread_name = get_thread_name(); + tracing::error!( + "Panic in local pool thread: {}\n{}", + thread_name, + panic_info + ); + last_panic = Some(panic); + } + } + panic_mode == PanicMode::LogAndContinue || last_panic.is_none() + }; + let ls = LocalSet::new(); + let shutdown_mode = handle.block_on(ls.run_until(async { + loop { + tokio::select! 
{
+                            // poll the set of futures
+                            res = s.join_next(), if !s.is_empty() => {
+                                if !handle_join(res) {
+                                    break ShutdownMode::Stop;
+                                }
+                            },
+                            // if the cancel token is cancelled, break the loop immediately
+                            _ = cancel_token.cancelled() => break ShutdownMode::Stop,
+                            // if we receive a message, execute it
+                            msg = recv.recv_async() => {
+                                match msg {
+                                    // just push into the FuturesUnordered
+                                    Ok(Message::Execute(f)) => {
+                                        s.spawn_local((f)());
+                                    }
+                                    // break with optional semaphore
+                                    Ok(Message::Finish) => break ShutdownMode::Finish,
+                                    // if the sender is dropped, break the loop immediately
+                                    Err(flume::RecvError::Disconnected) => break ShutdownMode::Stop,
+                                }
+                            }
+                        }
+                    }
+                }));
+                // soft shutdown mode is just like normal running, except that
+                // we don't add any more tasks and stop when there are no more
+                // tasks to run.
+                if shutdown_mode == ShutdownMode::Finish {
+                    // somebody is asking for a clean shutdown, wait for all tasks to finish
+                    handle.block_on(ls.run_until(async {
+                        loop {
+                            tokio::select! {
+                                res = s.join_next() => {
+                                    if res.is_none() || !handle_join(res) {
+                                        break;
+                                    }
+                                }
+                                _ = cancel_token.cancelled() => break,
+                            }
+                        }
+                    }));
+                }
+                // Always add the permit. If nobody is waiting for it, it does
+                // no harm.
+                shutdown_sem.add_permits(1);
+                if let Some(_panic) = last_panic {
+                    // std::panic::resume_unwind(panic);
+                }
+            })
+    }
+
+    /// A future that resolves when the pool is cancelled
+    pub async fn cancelled(&self) {
+        self.cancel_token.cancelled().await
+    }
+
+    /// Immediately stop polling all tasks and wait for all threads to finish.
+    ///
+    /// This is like drop, but waits for thread completion asynchronously.
+    ///
+    /// If there was a panic on any of the threads, it will be re-thrown here.
+    pub async fn shutdown(self) {
+        self.cancel_token.cancel();
+        self.await_thread_completion().await;
+        // just make it explicit that this is where drop runs
+        drop(self);
+    }
+
+    /// Gently shut down the pool
+    ///
+    /// Notifies all the pool threads to shut down and waits for them to finish.
+    ///
+    /// If you just want to drop the pool without giving the threads a chance to
+    /// process their remaining tasks, just use [`Self::shutdown`].
+    ///
+    /// If you want to wait for only a limited time for the tasks to finish,
+    /// you can race this function with a timeout.
+    pub async fn finish(self) {
+        // we assume that there are exactly as many threads as there are handles.
+        // also, we assume that the threads are still running.
+        for _ in 0..self.threads_u32() {
+            println!("sending shutdown message");
+            // send the shutdown message
+            // sending will fail if all threads are already finished, but
+            // in that case we don't need to do anything.
+            //
+            // Threads will add a permit in any case, so await_thread_completion
+            // will then immediately return.
+            self.send.send(Message::Finish).ok();
+        }
+        self.await_thread_completion().await;
+    }
+
+    fn threads_u32(&self) -> u32 {
+        self.threads
+            .len()
+            .try_into()
+            .expect("invalid number of threads")
+    }
+
+    async fn await_thread_completion(&self) {
+        // wait for all threads to finish.
+        // Each thread will add a permit to the semaphore.
+        let wait_for_semaphore = async move {
+            let _ = self
+                .shutdown_sem
+                .acquire_many(self.threads_u32())
+                .await
+                .expect("semaphore closed");
+        };
+        // race the semaphore wait with the cancel token in case somebody
+        // cancels the pool while we are waiting.
+        tokio::select!
{ + _ = wait_for_semaphore => {} + _ = self.cancel_token.cancelled() => {} + } + } +} + +impl Drop for LocalPool { + fn drop(&mut self) { + self.cancel_token.cancel(); + let current_thread_id = std::thread::current().id(); + for handle in self.threads.drain(..) { + // we have no control over from where Drop is called, especially + // if the pool ends up in an Arc. So we need to check if we are + // dropping from within a pool thread and skip it in that case. + if handle.thread().id() == current_thread_id { + tracing::error!("Dropping LocalPool from within a pool thread."); + continue; + } + // Log any panics and resume them + if let Err(panic) = handle.join() { + let panic_info = get_panic_info(&panic); + let thread_name = get_thread_name(); + tracing::error!("Error joining thread: {}\n{}", thread_name, panic_info); + // std::panic::resume_unwind(panic); + } + } + } +} + +/// Errors for spawn failures +#[derive(thiserror::Error, Debug)] +pub enum SpawnError { + /// Task was dropped, either due to a panic or because the pool was shut down. + #[error("cancelled")] + Cancelled, +} + +type SpawnResult = std::result::Result; + +/// Future returned by [`LocalPoolHandle::spawn`] and [`LocalPoolHandle::try_spawn`]. +/// +/// Dropping this future will immediately cancel the task. The task can fail if +/// the pool is shut down or if the task panics. In both cases the future will +/// resolve to [`SpawnError::Cancelled`]. +#[repr(transparent)] +#[derive(Debug)] +pub struct Run(tokio::sync::oneshot::Receiver); + +impl Run { + /// Abort the task + /// + /// Dropping the future will also abort the task. + pub fn abort(&mut self) { + self.0.close(); + } +} + +impl Future for Run { + type Output = std::result::Result; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + // map a RecvError (other side was dropped) to a SpawnError::Shutdown + // + // The only way the receiver can be dropped is if the pool is shut down. + self.0.poll(cx).map_err(|_| SpawnError::Cancelled) + } +} + +impl From for std::io::Error { + fn from(e: SpawnError) -> Self { + std::io::Error::new(std::io::ErrorKind::Other, e) + } +} + +impl LocalPoolHandle { + /// Get the number of tasks in the queue + /// + /// This is *not* the number of tasks being executed, but the number of + /// tasks waiting to be scheduled for execution. If this number is high, + /// it indicates that the pool is very busy. + /// + /// You might want to use this to throttle or reject requests. + pub fn waiting_tasks(&self) -> usize { + self.send.len() + } + + /// Spawn a task in the pool and return a future that resolves when the task + /// is done. + /// + /// If you don't care about the result, prefer [`LocalPoolHandle::spawn_detached`] + /// since it is more efficient. + pub fn try_spawn(&self, gen: F) -> SpawnResult> + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + 'static, + T: Send + 'static, + { + let (mut send_res, recv_res) = tokio::sync::oneshot::channel(); + let item = move || async move { + let fut = (gen)(); + tokio::select! { + // send the result to the receiver + res = fut => { send_res.send(res).ok(); } + // immediately stop the task if the receiver is dropped + _ = send_res.closed() => {} + } + }; + self.try_spawn_detached(item)?; + Ok(Run(recv_res)) + } + + /// Spawn a task in the pool. + /// + /// The task will run to completion unless the pool is shut down or the task + /// panics. 
In case of panic, the pool will either log the panic and continue + /// or immediately shut down, depending on the [`PanicMode`]. + pub fn try_spawn_detached(&self, gen: F) -> SpawnResult<()> + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + 'static, + { + let gen: SpawnFn = Box::new(move || Box::pin(gen())); + self.try_spawn_detached_boxed(gen) + } + + /// Spawn a task in the pool and await the result. + /// + /// Like [`LocalPoolHandle::try_spawn`], but panics if the pool is shut down. + pub fn spawn(&self, gen: F) -> Run + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + 'static, + T: Send + 'static, + { + self.try_spawn(gen).expect("pool is shut down") + } + + /// Spawn a task in the pool. + /// + /// Like [`LocalPoolHandle::try_spawn_detached`], but panics if the pool is shut down. + pub fn spawn_detached(&self, gen: F) + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + 'static, + { + self.try_spawn_detached(gen).expect("pool is shut down") + } + + /// Spawn a task in the pool. + /// + /// This is like [`LocalPoolHandle::try_spawn_detached`], but assuming that the + /// generator function is already boxed. This is the lowest overhead way to + /// spawn a task in the pool. + pub fn try_spawn_detached_boxed(&self, gen: SpawnFn) -> SpawnResult<()> { + self.send + .send(Message::Execute(gen)) + .map_err(|_| SpawnError::Cancelled) + } +} + +/// Thread shutdown mode +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ShutdownMode { + /// Finish all tasks and then stop + Finish, + /// Stop immediately + Stop, +} + +fn get_panic_info(panic: &Box) -> String { + if let Some(s) = panic.downcast_ref::<&str>() { + s.to_string() + } else if let Some(s) = panic.downcast_ref::() { + s.clone() + } else { + "Panic info unavailable".to_string() + } +} + +fn get_thread_name() -> String { + std::thread::current() + .name() + .unwrap_or("unnamed") + .to_string() +} + +/// A lightweight cancellation token +#[derive(Debug, Clone)] +struct CancellationToken { + inner: Arc, +} + +#[derive(Debug)] +struct CancellationTokenInner { + is_cancelled: AtomicBool, + notify: Notify, +} + +impl CancellationToken { + fn new() -> Self { + Self { + inner: Arc::new(CancellationTokenInner { + is_cancelled: AtomicBool::new(false), + notify: Notify::new(), + }), + } + } + + fn cancel(&self) { + if !self.inner.is_cancelled.swap(true, Ordering::SeqCst) { + self.inner.notify.notify_waiters(); + } + } + + async fn cancelled(&self) { + if self.is_cancelled() { + return; + } + + // Wait for notification if not cancelled + self.inner.notify.notified().await; + } + + fn is_cancelled(&self) -> bool { + self.inner.is_cancelled.load(Ordering::SeqCst) + } +} + +#[cfg(test)] +mod tests { + use std::{sync::atomic::AtomicU64, time::Duration}; + + use super::*; + + /// A struct that simulates a long running drop operation + #[derive(Debug)] + struct TestDrop(Option>); + + impl Drop for TestDrop { + fn drop(&mut self) { + // delay to make sure the drop is executed completely + std::thread::sleep(Duration::from_millis(100)); + // increment the drop counter + if let Some(counter) = self.0.take() { + counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + } + } + + impl TestDrop { + fn new(counter: Arc) -> Self { + Self(Some(counter)) + } + + fn forget(mut self) { + self.0.take(); + } + } + + /// Create a non-send test future that captures a TestDrop instance + async fn delay_then_drop(x: TestDrop) { + tokio::time::sleep(Duration::from_millis(100)).await; + // drop x at the end. 
we will never get here when the future is + // no longer polled, but drop should still be called + drop(x); + } + + /// Use a TestDrop instance to test cancellation + async fn delay_then_forget(x: TestDrop, delay: Duration) { + tokio::time::sleep(delay).await; + x.forget(); + } + + #[tokio::test] + async fn test_drop() { + let _ = tracing_subscriber::fmt::try_init(); + let pool = LocalPool::new(Config::default()); + let counter = Arc::new(AtomicU64::new(0)); + let n = 4; + for _ in 0..n { + let td = TestDrop::new(counter.clone()); + pool.spawn_detached(move || delay_then_drop(td)); + } + drop(pool); + assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), n); + } + + #[tokio::test] + async fn test_shutdown() { + let _ = tracing_subscriber::fmt::try_init(); + let pool = LocalPool::new(Config::default()); + let counter = Arc::new(AtomicU64::new(0)); + let n = 4; + for _ in 0..n { + let td = TestDrop::new(counter.clone()); + pool.spawn_detached(move || delay_then_drop(td)); + } + pool.finish().await; + assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), n); + } + + #[tokio::test] + async fn test_cancel() { + let _ = tracing_subscriber::fmt::try_init(); + let pool = LocalPool::new(Config { + threads: 2, + ..Config::default() + }); + let c1 = Arc::new(AtomicU64::new(0)); + let td1 = TestDrop::new(c1.clone()); + let handle = pool.spawn(move || { + // this one will be aborted anyway, so use a long delay to make sure + // that it does not accidentally run to completion + delay_then_forget(td1, Duration::from_secs(10)) + }); + drop(handle); + let c2 = Arc::new(AtomicU64::new(0)); + let td2 = TestDrop::new(c2.clone()); + let _handle = pool.spawn(move || { + // this one will not be aborted, so use a short delay so the test + // does not take too long + delay_then_forget(td2, Duration::from_millis(100)) + }); + pool.finish().await; + // c1 will be aborted, so drop will run before forget, so the counter will be increased + assert_eq!(c1.load(std::sync::atomic::Ordering::SeqCst), 1); + // c2 will not be aborted, so drop will run after forget, so the counter will not be increased + assert_eq!(c2.load(std::sync::atomic::Ordering::SeqCst), 0); + } + + #[tokio::test] + #[should_panic] + #[ignore = "todo"] + async fn test_panic() { + let _ = tracing_subscriber::fmt::try_init(); + let pool = LocalPool::new(Config { + threads: 2, + ..Config::default() + }); + pool.spawn_detached(|| async { + panic!("test panic"); + }); + // we can't use shutdown here, because we need to allow time for the + // panic to happen. 
+ pool.finish().await; + } +} diff --git a/iroh-cli/src/commands/start.rs b/iroh-cli/src/commands/start.rs index be61d611790..c6c91c49f89 100644 --- a/iroh-cli/src/commands/start.rs +++ b/iroh-cli/src/commands/start.rs @@ -83,7 +83,7 @@ where let client = node.client().clone(); - let mut command_task = node.local_pool_handle().spawn_pinned(move || { + let mut command_task = node.local_pool_handle().spawn(move || { async move { match command(client).await { Err(err) => Err(err), diff --git a/iroh/src/node.rs b/iroh/src/node.rs index f45e724eac5..65cf8b4c60f 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -44,6 +44,7 @@ use anyhow::{anyhow, Result}; use futures_lite::StreamExt; use iroh_base::key::PublicKey; use iroh_blobs::store::{GcMarkEvent, GcSweepEvent, Store as BaoStore}; +use iroh_blobs::util::local_pool::{LocalPool, LocalPoolHandle}; use iroh_blobs::{downloader::Downloader, protocol::Closed}; use iroh_gossip::dispatcher::GossipDispatcher; use iroh_gossip::net::Gossip; @@ -54,7 +55,6 @@ use quic_rpc::transport::ServerEndpoint as _; use quic_rpc::RpcServer; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; -use tokio_util::task::LocalPoolHandle; use tracing::{debug, error, info, warn}; use crate::node::{docs::DocsEngine, protocol::ProtocolMap}; @@ -107,10 +107,9 @@ struct NodeInner { secret_key: SecretKey, cancel_token: CancellationToken, client: crate::client::Iroh, - #[debug("rt")] - rt: LocalPoolHandle, downloader: Downloader, gossip_dispatcher: GossipDispatcher, + local_pool_handle: LocalPoolHandle, } /// In memory node. @@ -186,7 +185,7 @@ impl Node { /// Returns a reference to the used `LocalPoolHandle`. pub fn local_pool_handle(&self) -> &LocalPoolHandle { - &self.inner.rt + &self.inner.local_pool_handle } /// Get the relay server we are connected to. @@ -257,6 +256,7 @@ impl NodeInner { protocols: Arc, gc_policy: GcPolicy, gc_done_callback: Option>, + local_pool: LocalPool, ) { let (ipv4, ipv6) = self.endpoint.bound_sockets(); debug!( @@ -284,9 +284,7 @@ impl NodeInner { // Spawn a task for the garbage collection. if let GcPolicy::Interval(gc_period) = gc_policy { let inner = self.clone(); - let handle = self - .rt - .spawn_pinned(move || inner.run_gc_loop(gc_period, gc_done_callback)); + let handle = local_pool.spawn(move || inner.run_gc_loop(gc_period, gc_done_callback)); // We cannot spawn tasks that run on the local pool directly into the join set, // so instead we create a new task that supervises the local task. join_set.spawn({ @@ -377,6 +375,11 @@ impl NodeInner { // Abort remaining tasks. join_set.shutdown().await; + tracing::info!("Shutting down remaining tasks"); + + // Abort remaining local tasks. + tracing::info!("Shutting down local pool"); + local_pool.shutdown().await; } /// Shutdown the different parts of the node concurrently. 
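The node.rs hunk above is where the new pool gets wired into node shutdown: `NodeInner::run` now receives the `LocalPool`, spawns the gc loop on it, and awaits `local_pool.shutdown()` after the join set is drained. For illustration only (not part of the patch), here is a minimal sketch of the two shutdown paths the new module offers; the pool configuration, thread name prefix, and sleep-based tasks are invented for the example, and a tokio runtime with the `time` feature is assumed.

```rust
use std::time::Duration;

use iroh_blobs::util::local_pool::{Config, LocalPool, PanicMode};

#[tokio::main]
async fn main() {
    // A small pool that keeps running when a task panics.
    let pool = LocalPool::new(Config {
        threads: 2,
        thread_name_prefix: "example-pool",
        panic_mode: PanicMode::LogAndContinue,
    });

    // `spawn` returns a future that resolves to the task result. The future
    // produced by the closure may be !Send, since it runs on a pool thread.
    let answer = pool.spawn(|| async {
        tokio::time::sleep(Duration::from_millis(10)).await;
        42
    });
    assert_eq!(answer.await.unwrap(), 42);

    // Fire-and-forget work, like the gc loop spawned in `NodeInner::run`.
    pool.spawn_detached(|| async {
        tokio::time::sleep(Duration::from_millis(10)).await;
    });

    // Graceful path: let queued tasks run to completion before the threads exit.
    pool.finish().await;
    // Immediate path (alternative): `pool.shutdown().await`, or simply dropping
    // the pool, stops polling tasks right away but still joins the worker
    // threads, so Drop impls of in-flight state run to completion.
}
```

Racing `finish()` against a timeout gives a bounded graceful shutdown, as the doc comment on `finish` suggests.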
diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index a1d5ee89aa9..2dd60a423f5 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -11,6 +11,7 @@ use iroh_base::key::SecretKey; use iroh_blobs::{ downloader::Downloader, store::{Map, Store as BaoStore}, + util::local_pool::{self, LocalPool, LocalPoolHandle, PanicMode}, }; use iroh_docs::engine::DefaultAuthorStorage; use iroh_docs::net::DOCS_ALPN; @@ -29,7 +30,7 @@ use iroh_net::{ use quic_rpc::transport::{boxed::BoxableServerEndpoint, quinn::QuinnServerEndpoint}; use serde::{Deserialize, Serialize}; -use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; +use tokio_util::sync::CancellationToken; use tracing::{debug, error_span, trace, Instrument}; use crate::{ @@ -454,7 +455,10 @@ where async fn build_inner(self) -> Result> { trace!("building node"); - let lp = LocalPoolHandle::new(num_cpus::get()); + let lp = LocalPool::new(local_pool::Config { + panic_mode: PanicMode::LogAndContinue, + ..Default::default() + }); let endpoint = { let mut transport_config = quinn::TransportConfig::default(); transport_config @@ -564,10 +568,10 @@ where secret_key: self.secret_key, client, cancel_token: CancellationToken::new(), - rt: lp, downloader, gossip, gossip_dispatcher, + local_pool_handle: lp.handle().clone(), }); let protocol_builder = ProtocolBuilder { @@ -577,6 +581,7 @@ where external_rpc: self.rpc_endpoint, gc_policy: self.gc_policy, gc_done_callback: self.gc_done_callback, + local_pool: lp, }; let protocol_builder = protocol_builder.register_iroh_protocols(); @@ -602,6 +607,7 @@ pub struct ProtocolBuilder { #[debug("callback")] gc_done_callback: Option>, gc_policy: GcPolicy, + local_pool: LocalPool, } impl ProtocolBuilder { @@ -678,7 +684,7 @@ impl ProtocolBuilder { /// Returns a reference to the used [`LocalPoolHandle`]. pub fn local_pool_handle(&self) -> &LocalPoolHandle { - &self.inner.rt + self.local_pool.handle() } /// Returns a reference to the [`Downloader`] used by the node. @@ -727,6 +733,7 @@ impl ProtocolBuilder { protocols, gc_done_callback, gc_policy, + local_pool: rt, } = self; let protocols = Arc::new(protocols); let node_id = inner.endpoint.node_id(); @@ -750,6 +757,7 @@ impl ProtocolBuilder { protocols.clone(), gc_policy, gc_done_callback, + rt, ) .instrument(error_span!("node", me=%node_id.fmt_short())); let task = tokio::task::spawn(fut); diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index a0f5b53be5c..ce342ab2497 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -3,6 +3,7 @@ use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use futures_util::future::join_all; +use iroh_blobs::util::local_pool::LocalPoolHandle; use iroh_net::endpoint::Connecting; /// Handler for incoming connections. 
@@ -78,12 +79,12 @@ impl ProtocolMap { #[derive(Debug)] pub(crate) struct BlobsProtocol { - rt: tokio_util::task::LocalPoolHandle, + rt: LocalPoolHandle, store: S, } impl BlobsProtocol { - pub fn new(store: S, rt: tokio_util::task::LocalPoolHandle) -> Self { + pub fn new(store: S, rt: LocalPoolHandle) -> Self { Self { rt, store } } } diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index f95a43ec1a9..0796a0d86eb 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -14,6 +14,7 @@ use iroh_blobs::format::collection::Collection; use iroh_blobs::get::db::DownloadProgress; use iroh_blobs::get::Stats; use iroh_blobs::store::{ConsistencyCheckProgress, ExportFormat, ImportProgress, MapEntry}; +use iroh_blobs::util::local_pool::LocalPoolHandle; use iroh_blobs::util::progress::ProgressSender; use iroh_blobs::util::SetTagOption; use iroh_blobs::BlobFormat; @@ -28,7 +29,7 @@ use iroh_net::relay::RelayUrl; use iroh_net::{Endpoint, NodeAddr, NodeId}; use quic_rpc::server::{RpcChannel, RpcServerError}; use tokio::task::JoinSet; -use tokio_util::{either::Either, task::LocalPoolHandle}; +use tokio_util::either::Either; use tracing::{debug, info, warn}; use crate::client::{ @@ -428,8 +429,8 @@ impl Handler { } } - fn rt(&self) -> LocalPoolHandle { - self.inner.rt.clone() + fn local_pool_handle(&self) -> LocalPoolHandle { + self.inner.local_pool_handle.clone() } async fn blob_list_impl(self, co: &Co>) -> io::Result<()> { @@ -565,7 +566,7 @@ impl Handler { // provide a little buffer so that we don't slow down the sender let (tx, rx) = flume::bounded(32); let tx2 = tx.clone(); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { if let Err(e) = self.blob_add_from_path0(msg, tx).await { tx2.send_async(AddProgress::Abort(e.into())).await.ok(); } @@ -577,7 +578,7 @@ impl Handler { // provide a little buffer so that we don't slow down the sender let (tx, rx) = flume::bounded(32); let tx2 = tx.clone(); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { if let Err(e) = self.doc_import_file0(msg, tx).await { tx2.send_async(crate::client::docs::ImportProgress::Abort(e.into())) .await @@ -661,7 +662,7 @@ impl Handler { fn doc_export_file(self, msg: ExportFileRequest) -> impl Stream { let (tx, rx) = flume::bounded(1024); let tx2 = tx.clone(); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { if let Err(e) = self.doc_export_file0(msg, tx).await { tx2.send_async(ExportProgress::Abort(e.into())).await.ok(); } @@ -704,7 +705,7 @@ impl Handler { let downloader = self.inner.downloader.clone(); let endpoint = self.inner.endpoint.clone(); let progress = FlumeProgressSender::new(sender); - self.inner.rt.spawn_pinned(move || async move { + self.local_pool_handle().spawn_detached(move || async move { if let Err(err) = download(&db, endpoint, &downloader, msg, progress.clone()).await { progress .send(DownloadProgress::Abort(err.into())) @@ -719,7 +720,7 @@ impl Handler { fn blob_export(self, msg: ExportRequest) -> impl Stream { let (tx, rx) = flume::bounded(1024); let progress = FlumeProgressSender::new(tx); - self.rt().spawn_pinned(move || async move { + self.local_pool_handle().spawn_detached(move || async move { let res = iroh_blobs::export::export( &self.inner.db, msg.hash, @@ -732,7 +733,7 @@ impl Handler { match res { Ok(()) => progress.send(ExportProgress::AllDone).await.ok(), Err(err) => progress.send(ExportProgress::Abort(err.into())).await.ok(), - } + }; }); 
rx.into_stream().map(ExportResponse) } @@ -925,7 +926,7 @@ impl Handler { let (tx, rx) = flume::bounded(32); let this = self.clone(); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { if let Err(err) = this.blob_add_stream0(msg, stream, tx.clone()).await { tx.send_async(AddProgress::Abort(err.into())).await.ok(); } @@ -994,7 +995,7 @@ impl Handler { ) -> impl Stream> + Send + 'static { let (tx, rx) = flume::bounded(RPC_BLOB_GET_CHANNEL_CAP); let db = self.inner.db.clone(); - self.inner.rt.spawn_pinned(move || async move { + self.local_pool_handle().spawn_detached(move || async move { if let Err(err) = read_loop(req, db, tx.clone(), RPC_BLOB_GET_CHUNK_SIZE).await { tx.send_async(RpcResult::Err(err.into())).await.ok(); } @@ -1058,7 +1059,7 @@ impl Handler { let (tx, rx) = flume::bounded(32); let mut conn_infos = self.inner.endpoint.connection_infos(); conn_infos.sort_by_key(|n| n.node_id.to_string()); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { for conn_info in conn_infos { tx.send_async(Ok(ConnectionsResponse { conn_info })) .await diff --git a/iroh/tests/provide.rs b/iroh/tests/provide.rs index 0a07b5d8e9b..461ad33e70e 100644 --- a/iroh/tests/provide.rs +++ b/iroh/tests/provide.rs @@ -153,7 +153,7 @@ async fn multiple_clients() -> Result<()> { let peer_id = node.node_id(); let content = content.to_vec(); - tasks.push(node.local_pool_handle().spawn_pinned(move || { + tasks.push(node.local_pool_handle().spawn(move || { async move { let (secret_key, peer) = get_options(peer_id, addrs); let expected_data = &content;
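Call sites like the `iroh/tests/provide.rs` test above switch from `spawn_pinned`, which returned a tokio `JoinHandle`, to `spawn`, which returns the pool's own `Run` future and fails with `SpawnError::Cancelled` if the pool is shut down or the task is dropped. A minimal sketch of awaiting such a task (not part of the patch; the helper function and its body are invented for illustration, and `anyhow` is assumed to be available):

```rust
use iroh_blobs::util::local_pool::{LocalPoolHandle, SpawnError};

/// Run a small job on the pool and translate a dropped task into an error.
async fn run_on_pool(handle: &LocalPoolHandle) -> anyhow::Result<u64> {
    // `spawn` panics if the pool is already shut down; `try_spawn` returns a
    // result instead.
    let task = handle.spawn(|| async {
        // hypothetical (possibly !Send) work running on a pool thread
        42u64
    });
    match task.await {
        Ok(value) => Ok(value),
        // the pool was shut down, or the task panicked and was dropped
        Err(SpawnError::Cancelled) => Err(anyhow::anyhow!("local pool task was cancelled")),
    }
}
```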