diff --git a/Cargo.lock b/Cargo.lock index 3e519a9aa26..18c031120b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2564,6 +2564,7 @@ dependencies = [ "iroh-test", "num_cpus", "parking_lot", + "pin-project", "postcard", "proptest", "rand", @@ -5954,8 +5955,6 @@ dependencies = [ "bytes", "futures-core", "futures-sink", - "futures-util", - "hashbrown 0.14.5", "pin-project-lite", "slab", "tokio", diff --git a/iroh-blobs/Cargo.toml b/iroh-blobs/Cargo.toml index 46d6d317709..44df64134d0 100644 --- a/iroh-blobs/Cargo.toml +++ b/iroh-blobs/Cargo.toml @@ -33,6 +33,7 @@ iroh-metrics = { version = "0.20.0", path = "../iroh-metrics", default-features iroh-net = { version = "0.20.0", path = "../iroh-net" } num_cpus = "1.15.0" parking_lot = { version = "0.12.1", optional = true } +pin-project = "1.1.5" postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] } rand = "0.8" range-collections = "0.4.0" @@ -45,7 +46,7 @@ smallvec = { version = "1.10.0", features = ["serde", "const_new"] } tempfile = { version = "3.10.0", optional = true } thiserror = "1" tokio = { version = "1", features = ["fs"] } -tokio-util = { version = "0.7", features = ["io-util", "io", "rt"] } +tokio-util = { version = "0.7", features = ["io-util", "io"] } tracing = "0.1" tracing-futures = "0.2.5" diff --git a/iroh-blobs/examples/provide-bytes.rs b/iroh-blobs/examples/provide-bytes.rs index 73f7e6d8e39..eab9fddf5ab 100644 --- a/iroh-blobs/examples/provide-bytes.rs +++ b/iroh-blobs/examples/provide-bytes.rs @@ -10,11 +10,10 @@ //! cargo run --example provide-bytes collection //! To provide a collection (multiple blobs) use anyhow::Result; -use tokio_util::task::LocalPoolHandle; use tracing::warn; use tracing_subscriber::{prelude::*, EnvFilter}; -use iroh_blobs::{format::collection::Collection, Hash}; +use iroh_blobs::{format::collection::Collection, util::local_pool::LocalPool, Hash}; mod connect; use connect::{make_and_write_certs, make_server_endpoint, CERT_PATH}; @@ -82,7 +81,7 @@ async fn main() -> Result<()> { println!("\nfetch the content using a stream by running the following example:\n\ncargo run --example fetch-stream {hash} \"{addr}\" {format}\n"); // create a new local pool handle with 1 worker thread - let lp = LocalPoolHandle::new(1); + let lp = LocalPool::single(); let accept_task = tokio::spawn(async move { while let Some(incoming) = endpoint.accept().await { diff --git a/iroh-blobs/src/downloader.rs b/iroh-blobs/src/downloader.rs index 4e02c947f4c..dd26a8bc6de 100644 --- a/iroh-blobs/src/downloader.rs +++ b/iroh-blobs/src/downloader.rs @@ -46,14 +46,14 @@ use tokio::{ sync::{mpsc, oneshot}, task::JoinSet, }; -use tokio_util::{sync::CancellationToken, task::LocalPoolHandle, time::delay_queue}; +use tokio_util::{sync::CancellationToken, time::delay_queue}; use tracing::{debug, error_span, trace, warn, Instrument}; use crate::{ get::{db::DownloadProgress, Stats}, metrics::Metrics, store::Store, - util::progress::ProgressSender, + util::{local_pool::LocalPoolHandle, progress::ProgressSender}, }; mod get; @@ -340,7 +340,7 @@ impl Downloader { service.run().instrument(error_span!("downloader", %me)) }; - rt.spawn_pinned(create_future); + rt.spawn_detached(create_future); Self { next_id: Arc::new(AtomicU64::new(0)), msg_tx, diff --git a/iroh-blobs/src/downloader/test.rs b/iroh-blobs/src/downloader/test.rs index ec54e0ef8c9..871b835ba7b 100644 --- a/iroh-blobs/src/downloader/test.rs +++ b/iroh-blobs/src/downloader/test.rs @@ -10,7 +10,10 @@ use iroh_net::key::SecretKey; use crate::{ get::{db::BlobId, progress::TransferState}, - util::progress::{FlumeProgressSender, IdGenerator}, + util::{ + local_pool::LocalPool, + progress::{FlumeProgressSender, IdGenerator}, + }, }; use super::*; @@ -23,7 +26,7 @@ impl Downloader { dialer: dialer::TestingDialer, getter: getter::TestingGetter, concurrency_limits: ConcurrencyLimits, - ) -> Self { + ) -> (Self, LocalPool) { Self::spawn_for_test_with_retry_config( dialer, getter, @@ -37,10 +40,11 @@ impl Downloader { getter: getter::TestingGetter, concurrency_limits: ConcurrencyLimits, retry_config: RetryConfig, - ) -> Self { + ) -> (Self, LocalPool) { let (msg_tx, msg_rx) = mpsc::channel(super::SERVICE_CHANNEL_CAPACITY); - LocalPoolHandle::new(1).spawn_pinned(move || async move { + let lp = LocalPool::default(); + lp.spawn_detached(move || async move { // we want to see the logs of the service let _guard = iroh_test::logging::setup(); @@ -48,10 +52,13 @@ impl Downloader { service.run().await }); - Downloader { - next_id: Arc::new(AtomicU64::new(0)), - msg_tx, - } + ( + Downloader { + next_id: Arc::new(AtomicU64::new(0)), + msg_tx, + }, + lp, + ) } } @@ -63,7 +70,8 @@ async fn smoke_test() { let getter = getter::TestingGetter::default(); let concurrency_limits = ConcurrencyLimits::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send a request and make sure the peer is requested the corresponding download let peer = SecretKey::generate().public(); @@ -88,7 +96,8 @@ async fn deduplication() { getter.set_request_duration(Duration::from_secs(1)); let concurrency_limits = ConcurrencyLimits::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); let peer = SecretKey::generate().public(); let kind: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into(); @@ -119,7 +128,8 @@ async fn cancellation() { getter.set_request_duration(Duration::from_millis(500)); let concurrency_limits = ConcurrencyLimits::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); let peer = SecretKey::generate().public(); let kind_1: DownloadKind = HashAndFormat::raw(Hash::new([0u8; 32])).into(); @@ -158,7 +168,8 @@ async fn max_concurrent_requests_total() { ..Default::default() }; - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send the downloads let peer = SecretKey::generate().public(); @@ -201,7 +212,8 @@ async fn max_concurrent_requests_per_peer() { ..Default::default() }; - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send the downloads let peer = SecretKey::generate().public(); @@ -257,7 +269,8 @@ async fn concurrent_progress() { } .boxed() })); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); let peer = SecretKey::generate().public(); let hash = Hash::new([0u8; 32]); @@ -341,7 +354,8 @@ async fn long_queue() { ..Default::default() }; - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); // send the downloads let nodes = [ SecretKey::generate().public(), @@ -370,7 +384,8 @@ async fn fail_while_running() { let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); let blob_fail = HashAndFormat::raw(Hash::new([1u8; 32])); let blob_success = HashAndFormat::raw(Hash::new([2u8; 32])); @@ -407,7 +422,8 @@ async fn retry_nodes_simple() { let _guard = iroh_test::logging::setup(); let dialer = dialer::TestingDialer::default(); let getter = getter::TestingGetter::default(); - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), Default::default()); let node = SecretKey::generate().public(); let dial_attempts = Arc::new(AtomicUsize::new(0)); let dial_attempts2 = dial_attempts.clone(); @@ -432,7 +448,7 @@ async fn retry_nodes_fail() { max_retries_per_node: 3, }; - let downloader = Downloader::spawn_for_test_with_retry_config( + let (downloader, _lp) = Downloader::spawn_for_test_with_retry_config( dialer.clone(), getter.clone(), Default::default(), @@ -472,7 +488,8 @@ async fn retry_nodes_jump_queue() { ..Default::default() }; - let downloader = Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); + let (downloader, _lp) = + Downloader::spawn_for_test(dialer.clone(), getter.clone(), concurrency_limits); let good_node = SecretKey::generate().public(); let bad_node = SecretKey::generate().public(); diff --git a/iroh-blobs/src/provider.rs b/iroh-blobs/src/provider.rs index 7fe4e13004a..54b25151583 100644 --- a/iroh-blobs/src/provider.rs +++ b/iroh-blobs/src/provider.rs @@ -13,13 +13,13 @@ use iroh_io::stats::{ use iroh_io::{AsyncSliceReader, AsyncStreamWriter, TokioStreamWriter}; use iroh_net::endpoint::{self, RecvStream, SendStream}; use serde::{Deserialize, Serialize}; -use tokio_util::task::LocalPoolHandle; use tracing::{debug, debug_span, info, trace, warn}; use tracing_futures::Instrument; use crate::hashseq::parse_hash_seq; use crate::protocol::{GetRequest, RangeSpec, Request}; use crate::store::*; +use crate::util::local_pool::LocalPoolHandle; use crate::util::Tag; use crate::{BlobFormat, Hash}; @@ -302,7 +302,7 @@ pub async fn handle_connection( }; events.send(Event::ClientConnected { connection_id }).await; let db = db.clone(); - rt.spawn_pinned(|| { + rt.spawn_detached(|| { async move { if let Err(err) = handle_stream(db, reader, writer).await { warn!("error: {err:#?}",); diff --git a/iroh-blobs/src/store/bao_file.rs b/iroh-blobs/src/store/bao_file.rs index 962a7240b7e..b94c1960348 100644 --- a/iroh-blobs/src/store/bao_file.rs +++ b/iroh-blobs/src/store/bao_file.rs @@ -878,7 +878,8 @@ mod tests { decode_response_into_batch, local, make_wire_data, random_test_data, trickle, validate, }; use tokio::task::JoinSet; - use tokio_util::task::LocalPoolHandle; + + use crate::util::local_pool::LocalPool; use super::*; @@ -957,7 +958,7 @@ mod tests { )), hash.into(), ); - let local = LocalPoolHandle::new(4); + let local = LocalPool::default(); let mut tasks = Vec::new(); for i in 0..4 { let file = handle.writer(); @@ -968,7 +969,7 @@ mod tests { .map(io::Result::Ok) .boxed(); let trickle = TokioStreamReader::new(tokio_util::io::StreamReader::new(trickle)); - let task = local.spawn_pinned(move || async move { + let task = local.spawn(move || async move { decode_response_into_batch(hash, IROH_BLOCK_SIZE, chunk_ranges, trickle, file).await }); tasks.push(task); diff --git a/iroh-blobs/src/store/traits.rs b/iroh-blobs/src/store/traits.rs index 762e511cd37..4d5162ac04e 100644 --- a/iroh-blobs/src/store/traits.rs +++ b/iroh-blobs/src/store/traits.rs @@ -12,12 +12,12 @@ use iroh_base::rpc::RpcError; use iroh_io::AsyncSliceReader; use serde::{Deserialize, Serialize}; use tokio::io::AsyncRead; -use tokio_util::task::LocalPoolHandle; use crate::{ hashseq::parse_hash_seq, protocol::RangeSpec, util::{ + local_pool::{self, LocalPool}, progress::{BoxedProgressSender, IdGenerator, ProgressSender}, Tag, }, @@ -423,7 +423,10 @@ async fn validate_impl( use futures_buffered::BufferedStreamExt; let validate_parallelism: usize = num_cpus::get(); - let lp = LocalPoolHandle::new(validate_parallelism); + let lp = LocalPool::new(local_pool::Config { + threads: validate_parallelism, + ..Default::default() + }); let complete = store.blobs().await?.collect::>>()?; let partial = store .partial_blobs() @@ -437,7 +440,7 @@ async fn validate_impl( .map(|hash| { let store = store.clone(); let tx = tx.clone(); - lp.spawn_pinned(move || async move { + lp.spawn(move || async move { let entry = store .get(&hash) .await? @@ -486,7 +489,7 @@ async fn validate_impl( .map(|hash| { let store = store.clone(); let tx = tx.clone(); - lp.spawn_pinned(move || async move { + lp.spawn(move || async move { let entry = store .get(&hash) .await? diff --git a/iroh-blobs/src/util.rs b/iroh-blobs/src/util.rs index be43dfaaffd..6b70d24c9a5 100644 --- a/iroh-blobs/src/util.rs +++ b/iroh-blobs/src/util.rs @@ -19,6 +19,7 @@ pub mod progress; pub use mem_or_file::MemOrFile; mod sparse_mem_file; pub use sparse_mem_file::SparseMemFile; +pub mod local_pool; /// A tag #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, From, Into)] diff --git a/iroh-blobs/src/util/local_pool.rs b/iroh-blobs/src/util/local_pool.rs new file mode 100644 index 00000000000..7473d70c872 --- /dev/null +++ b/iroh-blobs/src/util/local_pool.rs @@ -0,0 +1,654 @@ +//! A local task pool with proper shutdown +use futures_lite::FutureExt; +use std::{ + any::Any, + future::Future, + ops::Deref, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; +use tokio::{ + sync::{Notify, Semaphore}, + task::{JoinError, JoinSet, LocalSet}, +}; + +type BoxedFut = Pin>>; +type SpawnFn = Box BoxedFut + Send + 'static>; + +enum Message { + /// Create a new task and execute it locally + Execute(SpawnFn), + /// Shutdown the thread after finishing all tasks + Finish, +} + +/// A local task pool with proper shutdown +/// +/// Unlike +/// [`LocalPoolHandle`](https://docs.rs/tokio-util/latest/tokio_util/task/struct.LocalPoolHandle.html), +/// this pool will join all its threads when dropped, ensuring that all Drop +/// implementations are run to completion. +/// +/// On drop, this pool will immediately cancel all *tasks* that are currently +/// being executed, and will wait for all threads to finish executing their +/// loops before returning. This means that all drop implementations will be +/// able to run to completion before drop exits. +/// +/// On [`LocalPool::finish`], this pool will notify all threads to shut down, +/// and then wait for all threads to finish executing their loops before +/// returning. This means that all currently executing tasks will be allowed to +/// run to completion. +#[derive(Debug)] +pub struct LocalPool { + threads: Vec>, + shutdown_sem: Arc, + cancel_token: CancellationToken, + handle: LocalPoolHandle, +} + +impl Deref for LocalPool { + type Target = LocalPoolHandle; + + fn deref(&self) -> &Self::Target { + &self.handle + } +} + +/// A handle to a [`LocalPool`] +#[derive(Debug, Clone)] +pub struct LocalPoolHandle { + /// The sender half of the channel used to send tasks to the pool + send: flume::Sender, +} + +/// What to do when a panic occurs in a pool thread +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum PanicMode { + /// Log the panic and continue + /// + /// The panic will be re-thrown when the pool is dropped. + LogAndContinue, + /// Log the panic and immediately shut down the pool. + /// + /// The panic will be re-thrown when the pool is dropped. + Shutdown, +} + +/// Local task pool configuration +#[derive(Clone, Debug)] +pub struct Config { + /// Number of threads in the pool + pub threads: usize, + /// Prefix for thread names + pub thread_name_prefix: &'static str, + /// Ignore panics in pool threads + pub panic_mode: PanicMode, +} + +impl Default for Config { + fn default() -> Self { + Self { + threads: num_cpus::get(), + thread_name_prefix: "local-pool", + panic_mode: PanicMode::Shutdown, + } + } +} + +impl Default for LocalPool { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl LocalPool { + /// Create a new local pool with a single std thread. + pub fn single() -> Self { + Self::new(Config { + threads: 1, + ..Default::default() + }) + } + + /// Create a new local pool with the given config. + /// + /// This will use the current tokio runtime handle, so it must be called + /// from within a tokio runtime. + pub fn new(config: Config) -> Self { + let Config { + threads, + thread_name_prefix, + panic_mode, + } = config; + let cancel_token = CancellationToken::new(); + let (send, recv) = flume::unbounded::(); + let shutdown_sem = Arc::new(Semaphore::new(0)); + let handle = tokio::runtime::Handle::current(); + let handles = (0..threads) + .map(|i| { + Self::spawn_pool_thread( + format!("{thread_name_prefix}-{i}"), + recv.clone(), + cancel_token.clone(), + panic_mode, + shutdown_sem.clone(), + handle.clone(), + ) + }) + .collect::>>() + .expect("invalid thread name"); + Self { + threads: handles, + handle: LocalPoolHandle { send }, + cancel_token, + shutdown_sem, + } + } + + /// Get a cheaply cloneable handle to the pool + /// + /// This is not strictly necessary since we implement deref for + /// LocalPoolHandle, but makes getting a handle more explicit. + pub fn handle(&self) -> &LocalPoolHandle { + &self.handle + } + + /// Spawn a new pool thread. + fn spawn_pool_thread( + thread_name: String, + recv: flume::Receiver, + cancel_token: CancellationToken, + panic_mode: PanicMode, + shutdown_sem: Arc, + handle: tokio::runtime::Handle, + ) -> std::io::Result> { + std::thread::Builder::new() + .name(thread_name) + .spawn(move || { + let mut s = JoinSet::new(); + let mut last_panic = None; + let mut handle_join = |res: Option>| -> bool { + if let Some(Err(e)) = res { + if let Ok(panic) = e.try_into_panic() { + let panic_info = get_panic_info(&panic); + let thread_name = get_thread_name(); + tracing::error!( + "Panic in local pool thread: {}\n{}", + thread_name, + panic_info + ); + last_panic = Some(panic); + } + } + panic_mode == PanicMode::LogAndContinue || last_panic.is_none() + }; + let ls = LocalSet::new(); + let shutdown_mode = handle.block_on(ls.run_until(async { + loop { + tokio::select! { + // poll the set of futures + res = s.join_next(), if !s.is_empty() => { + if !handle_join(res) { + break ShutdownMode::Stop; + } + }, + // if the cancel token is cancelled, break the loop immediately + _ = cancel_token.cancelled() => break ShutdownMode::Stop, + // if we receive a message, execute it + msg = recv.recv_async() => { + match msg { + // just push into the FuturesUnordered + Ok(Message::Execute(f)) => { + s.spawn_local((f)()); + } + // break with optional semaphore + Ok(Message::Finish) => break ShutdownMode::Finish, + // if the sender is dropped, break the loop immediately + Err(flume::RecvError::Disconnected) => break ShutdownMode::Stop, + } + } + } + } + })); + // soft shutdown mode is just like normal running, except that + // we don't add any more tasks and stop when there are no more + // tasks to run. + if shutdown_mode == ShutdownMode::Finish { + // somebody is asking for a clean shutdown, wait for all tasks to finish + handle.block_on(ls.run_until(async { + loop { + tokio::select! { + res = s.join_next() => { + if res.is_none() || !handle_join(res) { + break; + } + } + _ = cancel_token.cancelled() => break, + } + } + })); + } + // Always add the permit. If nobody is waiting for it, it does + // no harm. + shutdown_sem.add_permits(1); + if let Some(_panic) = last_panic { + // std::panic::resume_unwind(panic); + } + }) + } + + /// A future that resolves when the pool is cancelled + pub async fn cancelled(&self) { + self.cancel_token.cancelled().await + } + + /// Immediately stop polling all tasks and wait for all threads to finish. + /// + /// This is like droo, but waits for thread completion asynchronously. + /// + /// If there was a panic on any of the threads, it will be re-thrown here. + pub async fn shutdown(self) { + self.cancel_token.cancel(); + self.await_thread_completion().await; + // just make it explicit that this is where drop runs + drop(self); + } + + /// Gently shut down the pool + /// + /// Notifies all the pool threads to shut down and waits for them to finish. + /// + /// If you just want to drop the pool without giving the threads a chance to + /// process their remaining tasks, just use [`Self::shutdown`]. + /// + /// If you want to wait for only a limited time for the tasks to finish, + /// you can race this function with a timeout. + pub async fn finish(self) { + // we assume that there are exactly as many threads as there are handles. + // also, we assume that the threads are still running. + for _ in 0..self.threads_u32() { + println!("sending shutdown message"); + // send the shutdown message + // sending will fail if all threads are already finished, but + // in that case we don't need to do anything. + // + // Threads will add a permit in any case, so await_thread_completion + // will then immediately return. + self.send.send(Message::Finish).ok(); + } + self.await_thread_completion().await; + } + + fn threads_u32(&self) -> u32 { + self.threads + .len() + .try_into() + .expect("invalid number of threads") + } + + async fn await_thread_completion(&self) { + // wait for all threads to finish. + // Each thread will add a permit to the semaphore. + let wait_for_semaphore = async move { + let _ = self + .shutdown_sem + .acquire_many(self.threads_u32()) + .await + .expect("semaphore closed"); + }; + // race the semaphore wait with the cancel token in case somebody + // cancels the pool while we are waiting. + tokio::select! { + _ = wait_for_semaphore => {} + _ = self.cancel_token.cancelled() => {} + } + } +} + +impl Drop for LocalPool { + fn drop(&mut self) { + self.cancel_token.cancel(); + let current_thread_id = std::thread::current().id(); + for handle in self.threads.drain(..) { + // we have no control over from where Drop is called, especially + // if the pool ends up in an Arc. So we need to check if we are + // dropping from within a pool thread and skip it in that case. + if handle.thread().id() == current_thread_id { + tracing::error!("Dropping LocalPool from within a pool thread."); + continue; + } + // Log any panics and resume them + if let Err(panic) = handle.join() { + let panic_info = get_panic_info(&panic); + let thread_name = get_thread_name(); + tracing::error!("Error joining thread: {}\n{}", thread_name, panic_info); + // std::panic::resume_unwind(panic); + } + } + } +} + +/// Errors for spawn failures +#[derive(thiserror::Error, Debug)] +pub enum SpawnError { + /// Task was dropped, either due to a panic or because the pool was shut down. + #[error("cancelled")] + Cancelled, +} + +type SpawnResult = std::result::Result; + +/// Future returned by [`LocalPoolHandle::spawn`] and [`LocalPoolHandle::try_spawn`]. +/// +/// Dropping this future will immediately cancel the task. The task can fail if +/// the pool is shut down or if the task panics. In both cases the future will +/// resolve to [`SpawnError::Cancelled`]. +#[repr(transparent)] +#[derive(Debug)] +pub struct Run(tokio::sync::oneshot::Receiver); + +impl Run { + /// Abort the task + /// + /// Dropping the future will also abort the task. + pub fn abort(&mut self) { + self.0.close(); + } +} + +impl Future for Run { + type Output = std::result::Result; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + // map a RecvError (other side was dropped) to a SpawnError::Shutdown + // + // The only way the receiver can be dropped is if the pool is shut down. + self.0.poll(cx).map_err(|_| SpawnError::Cancelled) + } +} + +impl From for std::io::Error { + fn from(e: SpawnError) -> Self { + std::io::Error::new(std::io::ErrorKind::Other, e) + } +} + +impl LocalPoolHandle { + /// Get the number of tasks in the queue + /// + /// This is *not* the number of tasks being executed, but the number of + /// tasks waiting to be scheduled for execution. If this number is high, + /// it indicates that the pool is very busy. + /// + /// You might want to use this to throttle or reject requests. + pub fn waiting_tasks(&self) -> usize { + self.send.len() + } + + /// Spawn a task in the pool and return a future that resolves when the task + /// is done. + /// + /// If you don't care about the result, prefer [`LocalPoolHandle::spawn_detached`] + /// since it is more efficient. + pub fn try_spawn(&self, gen: F) -> SpawnResult> + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + 'static, + T: Send + 'static, + { + let (mut send_res, recv_res) = tokio::sync::oneshot::channel(); + let item = move || async move { + let fut = (gen)(); + tokio::select! { + // send the result to the receiver + res = fut => { send_res.send(res).ok(); } + // immediately stop the task if the receiver is dropped + _ = send_res.closed() => {} + } + }; + self.try_spawn_detached(item)?; + Ok(Run(recv_res)) + } + + /// Spawn a task in the pool. + /// + /// The task will run to completion unless the pool is shut down or the task + /// panics. In case of panic, the pool will either log the panic and continue + /// or immediately shut down, depending on the [`PanicMode`]. + pub fn try_spawn_detached(&self, gen: F) -> SpawnResult<()> + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + 'static, + { + let gen: SpawnFn = Box::new(move || Box::pin(gen())); + self.try_spawn_detached_boxed(gen) + } + + /// Spawn a task in the pool and await the result. + /// + /// Like [`LocalPoolHandle::try_spawn`], but panics if the pool is shut down. + pub fn spawn(&self, gen: F) -> Run + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + 'static, + T: Send + 'static, + { + self.try_spawn(gen).expect("pool is shut down") + } + + /// Spawn a task in the pool. + /// + /// Like [`LocalPoolHandle::try_spawn_detached`], but panics if the pool is shut down. + pub fn spawn_detached(&self, gen: F) + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + 'static, + { + self.try_spawn_detached(gen).expect("pool is shut down") + } + + /// Spawn a task in the pool. + /// + /// This is like [`LocalPoolHandle::try_spawn_detached`], but assuming that the + /// generator function is already boxed. This is the lowest overhead way to + /// spawn a task in the pool. + pub fn try_spawn_detached_boxed(&self, gen: SpawnFn) -> SpawnResult<()> { + self.send + .send(Message::Execute(gen)) + .map_err(|_| SpawnError::Cancelled) + } +} + +/// Thread shutdown mode +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ShutdownMode { + /// Finish all tasks and then stop + Finish, + /// Stop immediately + Stop, +} + +fn get_panic_info(panic: &Box) -> String { + if let Some(s) = panic.downcast_ref::<&str>() { + s.to_string() + } else if let Some(s) = panic.downcast_ref::() { + s.clone() + } else { + "Panic info unavailable".to_string() + } +} + +fn get_thread_name() -> String { + std::thread::current() + .name() + .unwrap_or("unnamed") + .to_string() +} + +/// A lightweight cancellation token +#[derive(Debug, Clone)] +struct CancellationToken { + inner: Arc, +} + +#[derive(Debug)] +struct CancellationTokenInner { + is_cancelled: AtomicBool, + notify: Notify, +} + +impl CancellationToken { + fn new() -> Self { + Self { + inner: Arc::new(CancellationTokenInner { + is_cancelled: AtomicBool::new(false), + notify: Notify::new(), + }), + } + } + + fn cancel(&self) { + if !self.inner.is_cancelled.swap(true, Ordering::SeqCst) { + self.inner.notify.notify_waiters(); + } + } + + async fn cancelled(&self) { + if self.is_cancelled() { + return; + } + + // Wait for notification if not cancelled + self.inner.notify.notified().await; + } + + fn is_cancelled(&self) -> bool { + self.inner.is_cancelled.load(Ordering::SeqCst) + } +} + +#[cfg(test)] +mod tests { + use std::{sync::atomic::AtomicU64, time::Duration}; + + use super::*; + + /// A struct that simulates a long running drop operation + #[derive(Debug)] + struct TestDrop(Option>); + + impl Drop for TestDrop { + fn drop(&mut self) { + // delay to make sure the drop is executed completely + std::thread::sleep(Duration::from_millis(100)); + // increment the drop counter + if let Some(counter) = self.0.take() { + counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + } + } + + impl TestDrop { + fn new(counter: Arc) -> Self { + Self(Some(counter)) + } + + fn forget(mut self) { + self.0.take(); + } + } + + /// Create a non-send test future that captures a TestDrop instance + async fn delay_then_drop(x: TestDrop) { + tokio::time::sleep(Duration::from_millis(100)).await; + // drop x at the end. we will never get here when the future is + // no longer polled, but drop should still be called + drop(x); + } + + /// Use a TestDrop instance to test cancellation + async fn delay_then_forget(x: TestDrop, delay: Duration) { + tokio::time::sleep(delay).await; + x.forget(); + } + + #[tokio::test] + async fn test_drop() { + let _ = tracing_subscriber::fmt::try_init(); + let pool = LocalPool::new(Config::default()); + let counter = Arc::new(AtomicU64::new(0)); + let n = 4; + for _ in 0..n { + let td = TestDrop::new(counter.clone()); + pool.spawn_detached(move || delay_then_drop(td)); + } + drop(pool); + assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), n); + } + + #[tokio::test] + async fn test_shutdown() { + let _ = tracing_subscriber::fmt::try_init(); + let pool = LocalPool::new(Config::default()); + let counter = Arc::new(AtomicU64::new(0)); + let n = 4; + for _ in 0..n { + let td = TestDrop::new(counter.clone()); + pool.spawn_detached(move || delay_then_drop(td)); + } + pool.finish().await; + assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), n); + } + + #[tokio::test] + async fn test_cancel() { + let _ = tracing_subscriber::fmt::try_init(); + let pool = LocalPool::new(Config { + threads: 2, + ..Config::default() + }); + let c1 = Arc::new(AtomicU64::new(0)); + let td1 = TestDrop::new(c1.clone()); + let handle = pool.spawn(move || { + // this one will be aborted anyway, so use a long delay to make sure + // that it does not accidentally run to completion + delay_then_forget(td1, Duration::from_secs(10)) + }); + drop(handle); + let c2 = Arc::new(AtomicU64::new(0)); + let td2 = TestDrop::new(c2.clone()); + let _handle = pool.spawn(move || { + // this one will not be aborted, so use a short delay so the test + // does not take too long + delay_then_forget(td2, Duration::from_millis(100)) + }); + pool.finish().await; + // c1 will be aborted, so drop will run before forget, so the counter will be increased + assert_eq!(c1.load(std::sync::atomic::Ordering::SeqCst), 1); + // c2 will not be aborted, so drop will run after forget, so the counter will not be increased + assert_eq!(c2.load(std::sync::atomic::Ordering::SeqCst), 0); + } + + #[tokio::test] + #[should_panic] + #[ignore = "todo"] + async fn test_panic() { + let _ = tracing_subscriber::fmt::try_init(); + let pool = LocalPool::new(Config { + threads: 2, + ..Config::default() + }); + pool.spawn_detached(|| async { + panic!("test panic"); + }); + // we can't use shutdown here, because we need to allow time for the + // panic to happen. + pool.finish().await; + } +} diff --git a/iroh-cli/src/commands/start.rs b/iroh-cli/src/commands/start.rs index 39d5e44e50a..39449fd15cf 100644 --- a/iroh-cli/src/commands/start.rs +++ b/iroh-cli/src/commands/start.rs @@ -89,7 +89,7 @@ where let client = node.client().clone(); - let mut command_task = node.local_pool_handle().spawn_pinned(move || { + let mut command_task = node.local_pool_handle().spawn(move || { async move { match command(client).await { Err(err) => Err(err), diff --git a/iroh/src/node.rs b/iroh/src/node.rs index f45e724eac5..65cf8b4c60f 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -44,6 +44,7 @@ use anyhow::{anyhow, Result}; use futures_lite::StreamExt; use iroh_base::key::PublicKey; use iroh_blobs::store::{GcMarkEvent, GcSweepEvent, Store as BaoStore}; +use iroh_blobs::util::local_pool::{LocalPool, LocalPoolHandle}; use iroh_blobs::{downloader::Downloader, protocol::Closed}; use iroh_gossip::dispatcher::GossipDispatcher; use iroh_gossip::net::Gossip; @@ -54,7 +55,6 @@ use quic_rpc::transport::ServerEndpoint as _; use quic_rpc::RpcServer; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; -use tokio_util::task::LocalPoolHandle; use tracing::{debug, error, info, warn}; use crate::node::{docs::DocsEngine, protocol::ProtocolMap}; @@ -107,10 +107,9 @@ struct NodeInner { secret_key: SecretKey, cancel_token: CancellationToken, client: crate::client::Iroh, - #[debug("rt")] - rt: LocalPoolHandle, downloader: Downloader, gossip_dispatcher: GossipDispatcher, + local_pool_handle: LocalPoolHandle, } /// In memory node. @@ -186,7 +185,7 @@ impl Node { /// Returns a reference to the used `LocalPoolHandle`. pub fn local_pool_handle(&self) -> &LocalPoolHandle { - &self.inner.rt + &self.inner.local_pool_handle } /// Get the relay server we are connected to. @@ -257,6 +256,7 @@ impl NodeInner { protocols: Arc, gc_policy: GcPolicy, gc_done_callback: Option>, + local_pool: LocalPool, ) { let (ipv4, ipv6) = self.endpoint.bound_sockets(); debug!( @@ -284,9 +284,7 @@ impl NodeInner { // Spawn a task for the garbage collection. if let GcPolicy::Interval(gc_period) = gc_policy { let inner = self.clone(); - let handle = self - .rt - .spawn_pinned(move || inner.run_gc_loop(gc_period, gc_done_callback)); + let handle = local_pool.spawn(move || inner.run_gc_loop(gc_period, gc_done_callback)); // We cannot spawn tasks that run on the local pool directly into the join set, // so instead we create a new task that supervises the local task. join_set.spawn({ @@ -377,6 +375,11 @@ impl NodeInner { // Abort remaining tasks. join_set.shutdown().await; + tracing::info!("Shutting down remaining tasks"); + + // Abort remaining local tasks. + tracing::info!("Shutting down local pool"); + local_pool.shutdown().await; } /// Shutdown the different parts of the node concurrently. diff --git a/iroh/src/node/builder.rs b/iroh/src/node/builder.rs index a1d5ee89aa9..2dd60a423f5 100644 --- a/iroh/src/node/builder.rs +++ b/iroh/src/node/builder.rs @@ -11,6 +11,7 @@ use iroh_base::key::SecretKey; use iroh_blobs::{ downloader::Downloader, store::{Map, Store as BaoStore}, + util::local_pool::{self, LocalPool, LocalPoolHandle, PanicMode}, }; use iroh_docs::engine::DefaultAuthorStorage; use iroh_docs::net::DOCS_ALPN; @@ -29,7 +30,7 @@ use iroh_net::{ use quic_rpc::transport::{boxed::BoxableServerEndpoint, quinn::QuinnServerEndpoint}; use serde::{Deserialize, Serialize}; -use tokio_util::{sync::CancellationToken, task::LocalPoolHandle}; +use tokio_util::sync::CancellationToken; use tracing::{debug, error_span, trace, Instrument}; use crate::{ @@ -454,7 +455,10 @@ where async fn build_inner(self) -> Result> { trace!("building node"); - let lp = LocalPoolHandle::new(num_cpus::get()); + let lp = LocalPool::new(local_pool::Config { + panic_mode: PanicMode::LogAndContinue, + ..Default::default() + }); let endpoint = { let mut transport_config = quinn::TransportConfig::default(); transport_config @@ -564,10 +568,10 @@ where secret_key: self.secret_key, client, cancel_token: CancellationToken::new(), - rt: lp, downloader, gossip, gossip_dispatcher, + local_pool_handle: lp.handle().clone(), }); let protocol_builder = ProtocolBuilder { @@ -577,6 +581,7 @@ where external_rpc: self.rpc_endpoint, gc_policy: self.gc_policy, gc_done_callback: self.gc_done_callback, + local_pool: lp, }; let protocol_builder = protocol_builder.register_iroh_protocols(); @@ -602,6 +607,7 @@ pub struct ProtocolBuilder { #[debug("callback")] gc_done_callback: Option>, gc_policy: GcPolicy, + local_pool: LocalPool, } impl ProtocolBuilder { @@ -678,7 +684,7 @@ impl ProtocolBuilder { /// Returns a reference to the used [`LocalPoolHandle`]. pub fn local_pool_handle(&self) -> &LocalPoolHandle { - &self.inner.rt + self.local_pool.handle() } /// Returns a reference to the [`Downloader`] used by the node. @@ -727,6 +733,7 @@ impl ProtocolBuilder { protocols, gc_done_callback, gc_policy, + local_pool: rt, } = self; let protocols = Arc::new(protocols); let node_id = inner.endpoint.node_id(); @@ -750,6 +757,7 @@ impl ProtocolBuilder { protocols.clone(), gc_policy, gc_done_callback, + rt, ) .instrument(error_span!("node", me=%node_id.fmt_short())); let task = tokio::task::spawn(fut); diff --git a/iroh/src/node/protocol.rs b/iroh/src/node/protocol.rs index a0f5b53be5c..ce342ab2497 100644 --- a/iroh/src/node/protocol.rs +++ b/iroh/src/node/protocol.rs @@ -3,6 +3,7 @@ use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use anyhow::Result; use futures_lite::future::Boxed as BoxedFuture; use futures_util::future::join_all; +use iroh_blobs::util::local_pool::LocalPoolHandle; use iroh_net::endpoint::Connecting; /// Handler for incoming connections. @@ -78,12 +79,12 @@ impl ProtocolMap { #[derive(Debug)] pub(crate) struct BlobsProtocol { - rt: tokio_util::task::LocalPoolHandle, + rt: LocalPoolHandle, store: S, } impl BlobsProtocol { - pub fn new(store: S, rt: tokio_util::task::LocalPoolHandle) -> Self { + pub fn new(store: S, rt: LocalPoolHandle) -> Self { Self { rt, store } } } diff --git a/iroh/src/node/rpc.rs b/iroh/src/node/rpc.rs index f95a43ec1a9..0796a0d86eb 100644 --- a/iroh/src/node/rpc.rs +++ b/iroh/src/node/rpc.rs @@ -14,6 +14,7 @@ use iroh_blobs::format::collection::Collection; use iroh_blobs::get::db::DownloadProgress; use iroh_blobs::get::Stats; use iroh_blobs::store::{ConsistencyCheckProgress, ExportFormat, ImportProgress, MapEntry}; +use iroh_blobs::util::local_pool::LocalPoolHandle; use iroh_blobs::util::progress::ProgressSender; use iroh_blobs::util::SetTagOption; use iroh_blobs::BlobFormat; @@ -28,7 +29,7 @@ use iroh_net::relay::RelayUrl; use iroh_net::{Endpoint, NodeAddr, NodeId}; use quic_rpc::server::{RpcChannel, RpcServerError}; use tokio::task::JoinSet; -use tokio_util::{either::Either, task::LocalPoolHandle}; +use tokio_util::either::Either; use tracing::{debug, info, warn}; use crate::client::{ @@ -428,8 +429,8 @@ impl Handler { } } - fn rt(&self) -> LocalPoolHandle { - self.inner.rt.clone() + fn local_pool_handle(&self) -> LocalPoolHandle { + self.inner.local_pool_handle.clone() } async fn blob_list_impl(self, co: &Co>) -> io::Result<()> { @@ -565,7 +566,7 @@ impl Handler { // provide a little buffer so that we don't slow down the sender let (tx, rx) = flume::bounded(32); let tx2 = tx.clone(); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { if let Err(e) = self.blob_add_from_path0(msg, tx).await { tx2.send_async(AddProgress::Abort(e.into())).await.ok(); } @@ -577,7 +578,7 @@ impl Handler { // provide a little buffer so that we don't slow down the sender let (tx, rx) = flume::bounded(32); let tx2 = tx.clone(); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { if let Err(e) = self.doc_import_file0(msg, tx).await { tx2.send_async(crate::client::docs::ImportProgress::Abort(e.into())) .await @@ -661,7 +662,7 @@ impl Handler { fn doc_export_file(self, msg: ExportFileRequest) -> impl Stream { let (tx, rx) = flume::bounded(1024); let tx2 = tx.clone(); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { if let Err(e) = self.doc_export_file0(msg, tx).await { tx2.send_async(ExportProgress::Abort(e.into())).await.ok(); } @@ -704,7 +705,7 @@ impl Handler { let downloader = self.inner.downloader.clone(); let endpoint = self.inner.endpoint.clone(); let progress = FlumeProgressSender::new(sender); - self.inner.rt.spawn_pinned(move || async move { + self.local_pool_handle().spawn_detached(move || async move { if let Err(err) = download(&db, endpoint, &downloader, msg, progress.clone()).await { progress .send(DownloadProgress::Abort(err.into())) @@ -719,7 +720,7 @@ impl Handler { fn blob_export(self, msg: ExportRequest) -> impl Stream { let (tx, rx) = flume::bounded(1024); let progress = FlumeProgressSender::new(tx); - self.rt().spawn_pinned(move || async move { + self.local_pool_handle().spawn_detached(move || async move { let res = iroh_blobs::export::export( &self.inner.db, msg.hash, @@ -732,7 +733,7 @@ impl Handler { match res { Ok(()) => progress.send(ExportProgress::AllDone).await.ok(), Err(err) => progress.send(ExportProgress::Abort(err.into())).await.ok(), - } + }; }); rx.into_stream().map(ExportResponse) } @@ -925,7 +926,7 @@ impl Handler { let (tx, rx) = flume::bounded(32); let this = self.clone(); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { if let Err(err) = this.blob_add_stream0(msg, stream, tx.clone()).await { tx.send_async(AddProgress::Abort(err.into())).await.ok(); } @@ -994,7 +995,7 @@ impl Handler { ) -> impl Stream> + Send + 'static { let (tx, rx) = flume::bounded(RPC_BLOB_GET_CHANNEL_CAP); let db = self.inner.db.clone(); - self.inner.rt.spawn_pinned(move || async move { + self.local_pool_handle().spawn_detached(move || async move { if let Err(err) = read_loop(req, db, tx.clone(), RPC_BLOB_GET_CHUNK_SIZE).await { tx.send_async(RpcResult::Err(err.into())).await.ok(); } @@ -1058,7 +1059,7 @@ impl Handler { let (tx, rx) = flume::bounded(32); let mut conn_infos = self.inner.endpoint.connection_infos(); conn_infos.sort_by_key(|n| n.node_id.to_string()); - self.rt().spawn_pinned(|| async move { + self.local_pool_handle().spawn_detached(|| async move { for conn_info in conn_infos { tx.send_async(Ok(ConnectionsResponse { conn_info })) .await diff --git a/iroh/tests/provide.rs b/iroh/tests/provide.rs index 0a07b5d8e9b..461ad33e70e 100644 --- a/iroh/tests/provide.rs +++ b/iroh/tests/provide.rs @@ -153,7 +153,7 @@ async fn multiple_clients() -> Result<()> { let peer_id = node.node_id(); let content = content.to_vec(); - tasks.push(node.local_pool_handle().spawn_pinned(move || { + tasks.push(node.local_pool_handle().spawn(move || { async move { let (secret_key, peer) = get_options(peer_id, addrs); let expected_data = &content;