diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs
index 3ae9f5e794..cd6230c166 100644
--- a/src/daft-scan/src/scan_task_iters.rs
+++ b/src/daft-scan/src/scan_task_iters.rs
@@ -178,125 +178,56 @@ impl<'a> Iterator for MergeByFileSize<'a> {
     }
 }
 
-struct ParquetScanTaskRowGroupSplitter<'a> {
-    // Accumulators
-    curr_row_group_indices: Vec<usize>,
-    curr_size_bytes: f64,
-    curr_num_rows: usize,
-
-    // Context variables
-    config: &'a DaftExecutionConfig,
-    file_metadata: FileMetaData,
-    scan_task_ref: ScanTaskRef,
-    scantask_estimated_size_bytes: Option<usize>,
+#[derive(Default)]
+struct SplitParquetByRowGroupsAccumulator {
+    row_group_indices: Vec<usize>,
+    size_bytes: f64,
+    num_rows: usize,
 }
 
-impl<'a> ParquetScanTaskRowGroupSplitter<'a> {
-    pub fn new(
-        scan_task_ref: ScanTaskRef,
-        file_metadata: FileMetaData,
-        config: &'a DaftExecutionConfig,
-    ) -> Self {
-        let scantask_estimated_size_bytes =
-            scan_task_ref.estimate_in_memory_size_bytes(Some(config));
-        Self {
-            curr_row_group_indices: Vec::new(),
-            curr_size_bytes: 0.,
-            curr_num_rows: 0,
-            config,
-            scan_task_ref,
-            file_metadata,
-            scantask_estimated_size_bytes,
-        }
-    }
-
-    pub fn split_by_row_groups(mut self) -> Vec<ScanTaskRef> {
-        let rg_indices = self.file_metadata.row_groups.keys().copied().collect_vec();
-        let mut new_data_sources = Vec::new();
-        for idx in rg_indices {
-            let maybe_accumulated = self.accumulate(idx);
-            if let Some(accumulated) = maybe_accumulated {
-                new_data_sources.push(accumulated);
-            }
-        }
-        if let Some(accumulated) = self.flush() {
-            new_data_sources.push(accumulated);
-        }
-
-        // Construct ScanTasks with the new DataSources
-        new_data_sources
-            .into_iter()
-            .map(|data_source| {
-                ScanTask::new(
-                    vec![data_source],
-                    self.scan_task_ref.file_format_config.clone(),
-                    self.scan_task_ref.schema.clone(),
-                    self.scan_task_ref.storage_config.clone(),
-                    self.scan_task_ref.pushdowns.clone(),
-                    self.scan_task_ref.generated_fields.clone(),
-                )
-                .into()
-            })
-            .collect_vec()
-    }
-
-    #[inline]
-    fn total_rg_compressed_size(&self) -> usize {
-        self.file_metadata
-            .row_groups
-            .iter()
-            .map(|rg| rg.1.compressed_size())
-            .sum()
-    }
-
-    #[inline]
-    fn data_source(&self) -> &DataSource {
-        match self.scan_task_ref.sources.as_slice() {
-            [source] => source,
-            _ => unreachable!(
-                "SplitByRowGroupsAccumulator should only have one DataSource in its ScanTask"
-            ),
-        }
-    }
-
-    fn accumulate(&mut self, rg_idx: usize) -> Option<DataSource> {
-        let rg = self.file_metadata.row_groups.get(&rg_idx).unwrap();
+impl SplitParquetByRowGroupsAccumulator {
+    fn accumulate(
+        &mut self,
+        rg_idx: &usize,
+        ctx: &SplitParquetByRowGroupsContext,
+    ) -> Option<DataSource> {
+        let rg = ctx.file_metadata.row_groups.get(rg_idx).unwrap();
 
         // Estimate the materialized size of this rowgroup and add it to curr_size_bytes
-        self.curr_size_bytes +=
-            self.scantask_estimated_size_bytes
-                .map_or(0., |est_materialized_size| {
-                    (rg.compressed_size() as f64 / self.total_rg_compressed_size() as f64)
-                        * est_materialized_size as f64
-                });
-        self.curr_num_rows += rg.num_rows();
-        self.curr_row_group_indices.push(rg_idx);
+        self.size_bytes += ctx
+            .scantask_estimated_size_bytes
+            .map_or(0., |est_materialized_size| {
+                (rg.compressed_size() as f64 / ctx.total_rg_compressed_size() as f64)
+                    * est_materialized_size as f64
+            });
+        self.num_rows += rg.num_rows();
+        self.row_group_indices.push(*rg_idx);
 
         // Flush the accumulator if necessary
         let reached_accumulator_limit =
-            self.curr_size_bytes >= self.config.scan_tasks_min_size_bytes as f64;
+            self.size_bytes >= ctx.config.scan_tasks_min_size_bytes as f64;
         let materialized_size_estimation_not_available =
-            self.scantask_estimated_size_bytes.is_none();
+            ctx.scantask_estimated_size_bytes.is_none();
         if materialized_size_estimation_not_available || reached_accumulator_limit {
-            self.flush()
+            self.flush(ctx)
         } else {
             None
         }
     }
 
-    fn flush(&mut self) -> Option<DataSource> {
+    fn flush(&mut self, ctx: &SplitParquetByRowGroupsContext) -> Option<DataSource> {
         // If nothing to flush, return None early
-        if self.curr_row_group_indices.is_empty() {
+        if self.row_group_indices.is_empty() {
            return None;
         }
 
         // Grab accumulated values and reset accumulators
-        let curr_row_group_indices = std::mem::take(&mut self.curr_row_group_indices);
-        let curr_size_bytes = std::mem::take(&mut self.curr_size_bytes);
-        let curr_num_rows = std::mem::take(&mut self.curr_num_rows);
+        let curr_row_group_indices = std::mem::take(&mut self.row_group_indices);
+        let curr_size_bytes = std::mem::take(&mut self.size_bytes);
+        let curr_num_rows = std::mem::take(&mut self.num_rows);
 
         // Create a new DataSource by mutating the old one
-        let mut new_source = self.data_source().clone();
+        let mut new_source = ctx.data_source().clone();
 
         if let DataSource::File {
             chunk_spec,
@@ -307,14 +238,12 @@ impl<'a> ParquetScanTaskRowGroupSplitter<'a> {
         } = &mut new_source
         {
             // Create a new Parquet FileMetaData, only keeping the relevant row groups
-            let row_group_list =
-                RowGroupList::from_iter(curr_row_group_indices.iter().map(|idx| {
-                    (
-                        *idx,
-                        self.file_metadata.row_groups.get(idx).unwrap().clone(),
-                    )
-                }));
-            let new_metadata = self
+            let row_group_list = RowGroupList::from_iter(
+                curr_row_group_indices
+                    .iter()
+                    .map(|idx| (*idx, ctx.file_metadata.row_groups.get(idx).unwrap().clone())),
+            );
+            let new_metadata = ctx
                 .file_metadata
                 .clone_with_row_groups(curr_num_rows, row_group_list);
             *parquet_metadata = Some(Arc::new(new_metadata));
@@ -336,6 +265,49 @@ impl<'a> ParquetScanTaskRowGroupSplitter<'a> {
     }
 }
 
+struct SplitParquetByRowGroupsContext<'a> {
+    config: &'a DaftExecutionConfig,
+    file_metadata: FileMetaData,
+    scan_task_ref: ScanTaskRef,
+    scantask_estimated_size_bytes: Option<usize>,
+}
+
+impl<'a> SplitParquetByRowGroupsContext<'a> {
+    pub fn new(
+        scan_task_ref: ScanTaskRef,
+        file_metadata: FileMetaData,
+        config: &'a DaftExecutionConfig,
+    ) -> Self {
+        let scantask_estimated_size_bytes =
+            scan_task_ref.estimate_in_memory_size_bytes(Some(config));
+        Self {
+            config,
+            file_metadata,
+            scan_task_ref,
+            scantask_estimated_size_bytes,
+        }
+    }
+
+    #[inline]
+    fn total_rg_compressed_size(&self) -> usize {
+        self.file_metadata
+            .row_groups
+            .iter()
+            .map(|rg| rg.1.compressed_size())
+            .sum()
+    }
+
+    #[inline]
+    fn data_source(&self) -> &DataSource {
+        match self.scan_task_ref.sources.as_slice() {
+            [source] => source,
+            _ => unreachable!(
+                "SplitByRowGroupsAccumulator should only have one DataSource in its ScanTask"
+            ),
+        }
+    }
+}
+
 #[must_use]
 pub(crate) fn split_by_row_groups<'a>(
     scan_tasks: BoxScanTaskIter<'a>,
@@ -375,7 +347,7 @@ pub(crate) fn split_by_row_groups<'a>(
                         .get_iceberg_delete_files()
                         .map_or(true, std::vec::Vec::is_empty)
                 {
-                    // Retrieve Parquet FileMetaData and construct a ParquetScanTaskRowGroupSplitter
+                    // Retrieve Parquet FileMetaData and construct SplitParquetByRowGroupsAccumulator
                     let (io_runtime, io_client) = t.storage_config.get_io_client_and_runtime()?;
                     let path = source.get_path();
                     let io_stats =
@@ -387,11 +359,36 @@ pub(crate) fn split_by_row_groups<'a>(
                         Some(io_stats),
                         field_id_mapping.clone(),
                     ))?;
-                    let accumulator =
-                        ParquetScanTaskRowGroupSplitter::new(t, file_metadata, config);
-
-                    // Materialize and convert into new ScanTasks
-                    let new_scan_tasks = accumulator.split_by_row_groups();
+                    let mut accumulator = SplitParquetByRowGroupsAccumulator::default();
+                    let accumulator_ctx =
+                        SplitParquetByRowGroupsContext::new(t, file_metadata, config);
+
+                    // Run accumulation over all rowgroups, make sure to flush last result
+                    let mut new_data_sources = Vec::new();
+                    for idx in accumulator_ctx.file_metadata.row_groups.keys() {
+                        if let Some(accumulated) = accumulator.accumulate(idx, &accumulator_ctx) {
+                            new_data_sources.push(accumulated);
+                        }
+                    }
+                    if let Some(accumulated) = accumulator.flush(&accumulator_ctx) {
+                        new_data_sources.push(accumulated);
+                    }
+
+                    // Construct ScanTasks with the new DataSources
+                    let new_scan_tasks = new_data_sources
+                        .into_iter()
+                        .map(|data_source| {
+                            ScanTask::new(
+                                vec![data_source],
+                                accumulator_ctx.scan_task_ref.file_format_config.clone(),
+                                accumulator_ctx.scan_task_ref.schema.clone(),
+                                accumulator_ctx.scan_task_ref.storage_config.clone(),
+                                accumulator_ctx.scan_task_ref.pushdowns.clone(),
+                                accumulator_ctx.scan_task_ref.generated_fields.clone(),
+                            )
+                            .into()
+                        })
+                        .collect_vec();
                     Ok(Box::new(new_scan_tasks.into_iter().map(Ok)))
                 } else {
                     Ok(Box::new(std::iter::once(Ok(t))))