From 460b060830413e353bdb1c789dd5d5a490c7a9c5 Mon Sep 17 00:00:00 2001
From: Colin Ho
Date: Tue, 19 Nov 2024 07:51:14 +0800
Subject: [PATCH] more comments

---
 src/daft-parquet/src/semaphore.rs     |  7 +++++
 src/daft-parquet/src/stream_reader.rs | 37 +++++++++++++++++++++++--------------
 2 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/src/daft-parquet/src/semaphore.rs b/src/daft-parquet/src/semaphore.rs
index e645984159..4fb9754763 100644
--- a/src/daft-parquet/src/semaphore.rs
+++ b/src/daft-parquet/src/semaphore.rs
@@ -8,6 +8,9 @@ use std::{
 
 use tokio::sync::Semaphore;
 
+/// A semaphore that dynamically adjusts the number of permits based on the
+/// observed compute and IO times.
+/// Used to control the number of concurrent parquet deserialization tasks.
 pub(crate) struct DynamicParquetReadingSemaphore {
     semaphore: Arc<Semaphore>,
     timings: Mutex<Timings>,
@@ -45,7 +48,9 @@ impl RunningAverage {
 }
 
 impl DynamicParquetReadingSemaphore {
+    /// The minimum ratio of compute time to IO time that allows a permit increase.
     const COMPUTE_THRESHOLD: f64 = 1.2;
+    /// The maximum ratio of waiting time to compute time that allows a permit increase.
     const WAIT_THRESHOLD: f64 = 0.5;
 
     pub(crate) fn new(max_permits: usize) -> Arc<Self> {
@@ -108,6 +113,8 @@ impl DynamicParquetReadingSemaphore {
         let wait_ratio =
             wait_avg.average.as_millis() as f64 / compute_avg.average.as_millis() as f64;
 
+        // Only increase permits if compute time is significantly higher than IO time,
+        // and waiting time is not too high.
         if compute_ratio > Self::COMPUTE_THRESHOLD && wait_ratio < Self::WAIT_THRESHOLD {
             let current_permits = self.current_permits.load(Ordering::Relaxed);
             let optimal_permits = (compute_ratio.ceil() as usize).min(self.max_permits);
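
A minimal sketch of the permit-adjustment heuristic that these two thresholds drive, with plain f64 inputs standing in for the semaphore's running averages (the function and its signature are illustrative, not part of the patch):

    // Standalone restatement of the adjustment rule: scale permits up only when
    // compute dominates IO and tasks are not already stuck waiting on each other.
    fn desired_permits(compute_ms: f64, io_ms: f64, wait_ms: f64, current: usize, max: usize) -> usize {
        let compute_ratio = compute_ms / io_ms; // must exceed COMPUTE_THRESHOLD (1.2)
        let wait_ratio = wait_ms / compute_ms; // must stay under WAIT_THRESHOLD (0.5)
        if compute_ratio > 1.2 && wait_ratio < 0.5 {
            // Mirrors `(compute_ratio.ceil() as usize).min(self.max_permits)` above;
            // this sketch additionally never shrinks below the current permit count.
            (compute_ratio.ceil() as usize).min(max).max(current)
        } else {
            current
        }
    }

For example, 300ms of average compute against 100ms of IO and 50ms of waiting gives compute_ratio = 3.0 and wait_ratio ≈ 0.17, so the permit target grows toward 3, capped at the configured maximum.
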
diff --git a/src/daft-parquet/src/stream_reader.rs b/src/daft-parquet/src/stream_reader.rs
index ae9f8ae0ce..feacd42b28 100644
--- a/src/daft-parquet/src/stream_reader.rs
+++ b/src/daft-parquet/src/stream_reader.rs
@@ -191,6 +191,7 @@ fn arrow_chunk_to_table(
     Ok(table)
 }
 
+/// Spawns a task that reads the column iterators and converts them into a table.
 #[allow(clippy::too_many_arguments)]
 pub fn spawn_column_iters_to_table_task(
     arr_iters: ArrowChunkIters,
@@ -203,9 +204,9 @@ pub fn spawn_column_iters_to_table_task(
     delete_rows: Option<Vec<i64>>,
     output_sender: tokio::sync::mpsc::Sender<DaftResult<Table>>,
     permit: tokio::sync::OwnedSemaphorePermit,
-    semaphore: Arc<DynamicParquetReadingSemaphore>,
+    record_compute_times_fn: impl Fn(Duration, Duration) + Send + Sync + 'static,
 ) {
-    let (senders, mut receivers): (Vec<_>, Vec<_>) = arr_iters
+    let (arrow_chunk_senders, mut arrow_chunk_receivers): (Vec<_>, Vec<_>) = arr_iters
         .iter()
         .map(|_| tokio::sync::mpsc::channel(1))
         .unzip();
@@ -213,8 +214,8 @@ pub fn spawn_column_iters_to_table_task(
 
     let compute_runtime = get_compute_runtime();
     let mut deserializer_handles = Vec::with_capacity(arr_iters.len());
-    for (sender, arr_iter) in senders.into_iter().zip(arr_iters.into_iter()) {
-        deserializer_handles.push(compute_runtime.spawn(async move {
+    for (sender, arr_iter) in arrow_chunk_senders.into_iter().zip(arr_iters.into_iter()) {
+        let deserialization_task = async move {
             let mut total_deserialization_time = std::time::Duration::ZERO;
             let mut deserialization_iteration_time = std::time::Instant::now();
             for arr in arr_iter {
@@ -226,7 +227,8 @@ pub fn spawn_column_iters_to_table_task(
                 deserialization_iteration_time = std::time::Instant::now();
             }
             total_deserialization_time
-        }));
+        };
+        deserializer_handles.push(compute_runtime.spawn(deserialization_task));
     }
 
     compute_runtime.spawn_detached(async move {
@@ -244,7 +246,7 @@ pub fn spawn_column_iters_to_table_task(
         let mut total_waiting_time = std::time::Duration::ZERO;
         loop {
             let arr_results =
-                futures::future::join_all(receivers.iter_mut().map(|s| s.recv())).await;
+                futures::future::join_all(arrow_chunk_receivers.iter_mut().map(|s| s.recv())).await;
             if arr_results.iter().any(|r| r.is_none()) {
                 break;
             }
@@ -280,13 +282,14 @@ pub fn spawn_column_iters_to_table_task(
             let waiting_elapsed = waiting_start_time.elapsed();
             total_waiting_time += waiting_elapsed;
         }
-        let total_deserialization_time = futures::future::join_all(deserializer_handles)
+        let average_deserialization_time = futures::future::join_all(deserializer_handles)
             .await
             .into_iter()
             .filter_map(|r| r.ok())
-            .sum::<Duration>();
-        total_compute_time += total_deserialization_time;
-        semaphore.record_compute_times(total_compute_time, total_waiting_time);
+            .sum::<Duration>()
+            .div_f32(arrow_chunk_receivers.len() as f32);
+        total_compute_time += average_deserialization_time;
+        record_compute_times_fn(total_compute_time, total_waiting_time);
         drop(permit);
     });
 }
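
A note on the averaging change in the last hunk: the per-column deserializer tasks run concurrently, so summing their individual times would overstate compute time relative to wall clock; the patch instead divides the summed time by the number of columns. A self-contained sketch of that computation (the helper name is ours, not the patch's):

    use std::time::Duration;

    // Approximate the wall-clock compute time of N concurrent tasks by the mean
    // of their individual run times. Assumes a non-empty slice, since
    // Duration::div_f32 panics when the divisor is zero.
    fn average_task_time(task_times: &[Duration]) -> Duration {
        task_times.iter().sum::<Duration>().div_f32(task_times.len() as f32)
    }
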
@@ -674,6 +677,8 @@ pub async fn local_parquet_stream(
     BoxStream<'static, DaftResult<Table>>,
 )> {
     let chunk_size = 128 * 1024;
+    // We use a semaphore to limit the number of concurrent row group deserialization tasks.
+    // Set the maximum number of concurrent tasks to 2 * the number of available threads.
     let semaphore = DynamicParquetReadingSemaphore::new(
         std::thread::available_parallelism()
             .unwrap_or(NonZeroUsize::new(2).unwrap())
@@ -694,7 +699,7 @@ pub async fn local_parquet_stream(
         semaphore.clone(),
     )?;
 
-    let (senders, receivers) = row_ranges
+    let (output_senders, output_receivers) = row_ranges
         .iter()
         .map(|_| tokio::sync::mpsc::channel(1))
         .unzip::<_, _, Vec<_>, Vec<_>>();
@@ -702,13 +707,14 @@ pub async fn local_parquet_stream(
     let owned_uri = uri.to_string();
     let compute_runtime = get_compute_runtime();
     compute_runtime.spawn_detached(async move {
-        for ((column_iters, sender), rg_range) in column_iters.zip(senders).zip(row_ranges) {
+        for ((column_iters, sender), rg_range) in column_iters.zip(output_senders).zip(row_ranges) {
             if let Err(e) = column_iters {
                 let _ = sender.send(Err(e.into())).await;
                 break;
             }
 
             let permit = semaphore.acquire().await;
+            let semaphore_ref = semaphore.clone();
             spawn_column_iters_to_table_task(
                 column_iters.unwrap(),
                 rg_range,
@@ -720,12 +726,15 @@ pub async fn local_parquet_stream(
                 delete_rows.clone(),
                 sender,
                 permit,
-                semaphore.clone(),
+                move |compute_time, waiting_time| {
+                    semaphore_ref.record_compute_times(compute_time, waiting_time);
+                },
             );
         }
     });
 
-    let result_stream = futures::stream::iter(receivers.into_iter().map(ReceiverStream::new));
+    let result_stream =
+        futures::stream::iter(output_receivers.into_iter().map(ReceiverStream::new));
 
     match maintain_order {
         true => Ok((metadata, Box::pin(result_stream.flatten()))),
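
One way to read the closure change above: after this patch, spawn_column_iters_to_table_task no longer depends on the DynamicParquetReadingSemaphore type at all; it only needs some Fn(Duration, Duration) sink for the (compute, waiting) timings. A hedged usage sketch with a stand-in recorder (TimingLog and both function names are hypothetical, not from the patch):

    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    // Stand-in for the semaphore: anything that can record (compute, waiting) pairs.
    #[derive(Default)]
    struct TimingLog(Mutex<Vec<(Duration, Duration)>>);

    // Mirrors the new parameter shape: the task only ever sees the closure.
    fn run_task(record_compute_times_fn: impl Fn(Duration, Duration) + Send + Sync + 'static) {
        record_compute_times_fn(Duration::from_millis(120), Duration::from_millis(30));
    }

    fn main() {
        let log = Arc::new(TimingLog::default());
        let log_ref = log.clone();
        run_task(move |compute, waiting| {
            log_ref.0.lock().unwrap().push((compute, waiting));
        });
        assert_eq!(log.0.lock().unwrap().len(), 1);
    }

This keeps the table-conversion task testable in isolation and removes its hard dependency on the semaphore module.
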