
Commit

more comments
Colin Ho authored and Colin Ho committed Nov 18, 2024
1 parent 1fc2019 commit 460b060
Showing 2 changed files with 30 additions and 14 deletions.
7 changes: 7 additions & 0 deletions src/daft-parquet/src/semaphore.rs
@@ -8,6 +8,9 @@ use std::{

use tokio::sync::Semaphore;

/// A semaphore that dynamically adjusts the number of permits based on the
/// observed compute and IO times.
/// Used to control the number of concurrent parquet deserialization tasks.
pub(crate) struct DynamicParquetReadingSemaphore {
semaphore: Arc<Semaphore>,
timings: Mutex<RunningTimings>,
@@ -45,7 +48,9 @@ impl RunningAverage {
}

impl DynamicParquetReadingSemaphore {
/// The ratio of compute time to IO time that allows for permit increase. This is a minimum value.
const COMPUTE_THRESHOLD: f64 = 1.2;
/// The ratio of waiting time to compute time that allows for permit increase. This is a maximum value.
const WAIT_THRESHOLD: f64 = 0.5;

pub(crate) fn new(max_permits: usize) -> Arc<Self> {
@@ -108,6 +113,8 @@ impl DynamicParquetReadingSemaphore {
let wait_ratio =
wait_avg.average.as_millis() as f64 / compute_avg.average.as_millis() as f64;

// Only increase permits if compute time is significantly higher than IO time,
// and waiting time is not too high.
if compute_ratio > Self::COMPUTE_THRESHOLD && wait_ratio < Self::WAIT_THRESHOLD {
let current_permits = self.current_permits.load(Ordering::Relaxed);
let optimal_permits = (compute_ratio.ceil() as usize).min(self.max_permits);
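
The comments added in this file describe the adjustment heuristic: permits are only increased when deserialization (compute) time clearly dominates IO time and the tasks are not already spending much of their time waiting. A minimal standalone sketch of that logic, assuming hypothetical names for the parts the diff does not show (the running averages and the caller that actually adds permits are stand-ins, not the actual Daft implementation):

```rust
use std::time::Duration;

/// Sketch of the permit-adjustment heuristic; the threshold constants mirror
/// the diff, everything else is illustrative.
fn maybe_increase_permits(
    compute_avg: Duration, // running average of deserialization (compute) time
    io_avg: Duration,      // running average of IO time
    wait_avg: Duration,    // running average of time spent waiting on channels
    current_permits: usize,
    max_permits: usize,
) -> Option<usize> {
    const COMPUTE_THRESHOLD: f64 = 1.2; // minimum compute/IO ratio
    const WAIT_THRESHOLD: f64 = 0.5; // maximum wait/compute ratio

    let compute_ratio = compute_avg.as_millis() as f64 / io_avg.as_millis().max(1) as f64;
    let wait_ratio = wait_avg.as_millis() as f64 / compute_avg.as_millis().max(1) as f64;

    // Only scale up when compute dominates IO and tasks are not already
    // spending much of their time waiting on downstream consumers.
    if compute_ratio > COMPUTE_THRESHOLD && wait_ratio < WAIT_THRESHOLD {
        let optimal_permits = (compute_ratio.ceil() as usize).min(max_permits);
        if optimal_permits > current_permits {
            return Some(optimal_permits);
        }
    }
    None
}

fn main() {
    let suggestion = maybe_increase_permits(
        Duration::from_millis(300), // compute
        Duration::from_millis(100), // IO
        Duration::from_millis(20),  // waiting
        2,
        8,
    );
    // compute/IO = 3.0 > 1.2 and wait/compute ≈ 0.07 < 0.5, so this suggests 3 permits.
    println!("suggested permits: {:?}", suggestion);
}
```

Using the ceiling of the compute/IO ratio as the permit target, capped at `max_permits`, means concurrency grows roughly in proportion to how compute-bound the workload appears.
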
37 changes: 23 additions & 14 deletions src/daft-parquet/src/stream_reader.rs
@@ -191,6 +191,7 @@ fn arrow_chunk_to_table(
Ok(table)
}

/// Spawns a task that reads the column iterators and converts them into a table.
#[allow(clippy::too_many_arguments)]
pub fn spawn_column_iters_to_table_task(
arr_iters: ArrowChunkIters,
@@ -203,18 +204,18 @@ pub fn spawn_column_iters_to_table_task(
delete_rows: Option<Vec<i64>>,
output_sender: tokio::sync::mpsc::Sender<DaftResult<Table>>,
permit: tokio::sync::OwnedSemaphorePermit,
semaphore: Arc<DynamicParquetReadingSemaphore>,
record_compute_times_fn: impl Fn(Duration, Duration) + Send + Sync + 'static,
) {
let (senders, mut receivers): (Vec<_>, Vec<_>) = arr_iters
let (arrow_chunk_senders, mut arrow_chunk_receivers): (Vec<_>, Vec<_>) = arr_iters
.iter()
.map(|_| tokio::sync::mpsc::channel(1))
.unzip();

let compute_runtime = get_compute_runtime();

let mut deserializer_handles = Vec::with_capacity(arr_iters.len());
for (sender, arr_iter) in senders.into_iter().zip(arr_iters.into_iter()) {
deserializer_handles.push(compute_runtime.spawn(async move {
for (sender, arr_iter) in arrow_chunk_senders.into_iter().zip(arr_iters.into_iter()) {
let deserialization_task = async move {
let mut total_deserialization_time = std::time::Duration::ZERO;
let mut deserialization_iteration_time = std::time::Instant::now();
for arr in arr_iter {
@@ -226,7 +227,8 @@ pub fn spawn_column_iters_to_table_task(
deserialization_iteration_time = std::time::Instant::now();
}
total_deserialization_time
}));
};
deserializer_handles.push(compute_runtime.spawn(deserialization_task));
}

compute_runtime.spawn_detached(async move {
@@ -244,7 +246,7 @@ pub fn spawn_column_iters_to_table_task(
let mut total_waiting_time = std::time::Duration::ZERO;
loop {
let arr_results =
futures::future::join_all(receivers.iter_mut().map(|s| s.recv())).await;
futures::future::join_all(arrow_chunk_receivers.iter_mut().map(|s| s.recv())).await;
if arr_results.iter().any(|r| r.is_none()) {
break;
}
@@ -280,13 +282,14 @@ pub fn spawn_column_iters_to_table_task(
let waiting_elapsed = waiting_start_time.elapsed();
total_waiting_time += waiting_elapsed;
}
let total_deserialization_time = futures::future::join_all(deserializer_handles)
let average_deserialization_time = futures::future::join_all(deserializer_handles)
.await
.into_iter()
.filter_map(|r| r.ok())
.sum::<Duration>();
total_compute_time += total_deserialization_time;
semaphore.record_compute_times(total_compute_time, total_waiting_time);
.sum::<Duration>()
.div_f32(arrow_chunk_receivers.len() as f32);
total_compute_time += average_deserialization_time;
record_compute_times_fn(total_compute_time, total_waiting_time);
drop(permit);
});
}
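
A behavioral change in the hunk above: the deserialization times reported by the per-column tasks are now averaged (`div_f32` by the number of column receivers) instead of summed before being folded into the task's compute time, presumably so a wide table does not appear many times more compute-bound than it is. A small illustrative sketch of that averaging, with made-up names:

```rust
use std::time::Duration;

/// Illustrative only: average the deserialization time reported by each
/// column task instead of summing, so the compute estimate reflects
/// per-column wall time rather than scaling with the column count.
fn average_column_time(per_column_times: &[Duration]) -> Duration {
    if per_column_times.is_empty() {
        return Duration::ZERO;
    }
    let total: Duration = per_column_times.iter().sum();
    total.div_f32(per_column_times.len() as f32)
}

fn main() {
    let times = [
        Duration::from_millis(40),
        Duration::from_millis(60),
        Duration::from_millis(50),
    ];
    // Summing would report 150ms of "compute"; the average (~50ms) is what
    // feeds the semaphore's permit heuristic.
    println!("average column time: {:?}", average_column_time(&times));
}
```
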
@@ -674,6 +677,8 @@ pub async fn local_parquet_stream(
BoxStream<'static, DaftResult<Table>>,
)> {
let chunk_size = 128 * 1024;
// We use a semaphore to limit the number of concurrent row group deserialization tasks.
// Set the maximum number of concurrent tasks to 2 * number of available threads.
let semaphore = DynamicParquetReadingSemaphore::new(
std::thread::available_parallelism()
.unwrap_or(NonZeroUsize::new(2).unwrap())
@@ -694,21 +699,22 @@ pub async fn local_parquet_stream(
semaphore.clone(),
)?;

let (senders, receivers) = row_ranges
let (output_senders, output_receivers) = row_ranges
.iter()
.map(|_| tokio::sync::mpsc::channel(1))
.unzip::<_, _, Vec<_>, Vec<_>>();

let owned_uri = uri.to_string();
let compute_runtime = get_compute_runtime();
compute_runtime.spawn_detached(async move {
for ((column_iters, sender), rg_range) in column_iters.zip(senders).zip(row_ranges) {
for ((column_iters, sender), rg_range) in column_iters.zip(output_senders).zip(row_ranges) {
if let Err(e) = column_iters {
let _ = sender.send(Err(e.into())).await;
break;
}

let permit = semaphore.acquire().await;
let semaphore_ref = semaphore.clone();
spawn_column_iters_to_table_task(
column_iters.unwrap(),
rg_range,
@@ -720,12 +726,15 @@
delete_rows.clone(),
sender,
permit,
semaphore.clone(),
move |compute_time, waiting_time| {
semaphore_ref.record_compute_times(compute_time, waiting_time);
},
);
}
});

let result_stream = futures::stream::iter(receivers.into_iter().map(ReceiverStream::new));
let result_stream =
futures::stream::iter(output_receivers.into_iter().map(ReceiverStream::new));

match maintain_order {
true => Ok((metadata, Box::pin(result_stream.flatten()))),
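
The final hunks replace the `Arc<DynamicParquetReadingSemaphore>` argument with a `Fn(Duration, Duration)` callback built at the call site: the caller clones the `Arc` into `semaphore_ref` and moves it into a closure that forwards to `record_compute_times`. A minimal sketch of that injection pattern, using a stand-in recorder and a plain thread rather than the Daft compute runtime (none of these names are the actual Daft APIs):

```rust
use std::sync::{Arc, Mutex};
use std::time::Duration;

// Stand-in for the semaphore's timing bookkeeping; the real logic lives in
// src/daft-parquet/src/semaphore.rs.
struct TimingRecorder {
    samples: Mutex<Vec<(Duration, Duration)>>,
}

impl TimingRecorder {
    fn record_compute_times(&self, compute: Duration, waiting: Duration) {
        self.samples.lock().unwrap().push((compute, waiting));
    }
}

// The worker only sees an opaque callback, mirroring the new
// `record_compute_times_fn` parameter in the diff.
fn spawn_worker(
    record_compute_times_fn: impl Fn(Duration, Duration) + Send + Sync + 'static,
) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || {
        // Pretend some deserialization work happened, then report the timings.
        record_compute_times_fn(Duration::from_millis(120), Duration::from_millis(30));
    })
}

fn main() {
    let recorder = Arc::new(TimingRecorder {
        samples: Mutex::new(Vec::new()),
    });
    let recorder_ref = recorder.clone();
    // Mirrors the call site in local_parquet_stream: clone the Arc, move the
    // clone into the closure, and hand only the closure to the task.
    let handle = spawn_worker(move |compute_time, waiting_time| {
        recorder_ref.record_compute_times(compute_time, waiting_time);
    });
    handle.join().unwrap();
    println!("recorded {} timing sample(s)", recorder.samples.lock().unwrap().len());
}
```

Handing the task an opaque closure keeps `spawn_column_iters_to_table_task` decoupled from the semaphore type while still letting timings flow back to it.
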
