Cleanup into separate accumulator and accumulator context
Jay Chia committed Nov 30, 2024
1 parent a36fffb commit d17e91d
Showing 1 changed file with 109 additions and 112 deletions.
221 changes: 109 additions & 112 deletions src/daft-scan/src/scan_task_iters.rs
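
The change in a nutshell, as a minimal self-contained sketch (illustrative only; Acc and Ctx are hypothetical names, not the Daft types): the mutable accumulator keeps only its running state, while the read-only inputs (execution config, Parquet file metadata, the source ScanTask and its size estimate) move into a context struct that is passed by shared reference into accumulate and flush.

// Minimal sketch of the accumulator/context split (hypothetical names, not the Daft code).
#[derive(Default)]
struct Acc {
    // Running state only.
    total: f64,
    count: usize,
}

struct Ctx<'a> {
    // Read-only inputs shared by every accumulate() call.
    threshold: f64,
    items: &'a [f64],
}

impl Acc {
    // Fold in one item; emit an accumulated batch once the threshold is reached.
    fn accumulate(&mut self, idx: usize, ctx: &Ctx) -> Option<(f64, usize)> {
        self.total += ctx.items[idx];
        self.count += 1;
        if self.total >= ctx.threshold {
            self.flush()
        } else {
            None
        }
    }

    // Hand back whatever has accumulated so far and reset the running state.
    fn flush(&mut self) -> Option<(f64, usize)> {
        if self.count == 0 {
            return None;
        }
        Some((std::mem::take(&mut self.total), std::mem::take(&mut self.count)))
    }
}

fn main() {
    let items = [3.0, 4.0, 1.0, 7.0];
    let ctx = Ctx { threshold: 5.0, items: &items };
    let mut acc = Acc::default();
    let mut batches = Vec::new();
    for idx in 0..items.len() {
        batches.extend(acc.accumulate(idx, &ctx));
    }
    batches.extend(acc.flush());
    println!("{batches:?}"); // [(7.0, 2), (8.0, 2)]
}

In the diff below, SplitParquetByRowGroupsAccumulator plays the role of Acc and SplitParquetByRowGroupsContext the role of Ctx, with the driver loop living in split_by_row_groups.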
@@ -178,125 +178,56 @@ impl<'a> Iterator for MergeByFileSize<'a> {
}
}

struct ParquetScanTaskRowGroupSplitter<'a> {
// Accumulators
curr_row_group_indices: Vec<usize>,
curr_size_bytes: f64,
curr_num_rows: usize,

// Context variables
config: &'a DaftExecutionConfig,
file_metadata: FileMetaData,
scan_task_ref: ScanTaskRef,
scantask_estimated_size_bytes: Option<usize>,
#[derive(Default)]
struct SplitParquetByRowGroupsAccumulator {
row_group_indices: Vec<usize>,
size_bytes: f64,
num_rows: usize,
}

impl<'a> ParquetScanTaskRowGroupSplitter<'a> {
pub fn new(
scan_task_ref: ScanTaskRef,
file_metadata: FileMetaData,
config: &'a DaftExecutionConfig,
) -> Self {
let scantask_estimated_size_bytes =
scan_task_ref.estimate_in_memory_size_bytes(Some(config));
Self {
curr_row_group_indices: Vec::new(),
curr_size_bytes: 0.,
curr_num_rows: 0,
config,
scan_task_ref,
file_metadata,
scantask_estimated_size_bytes,
}
}

pub fn split_by_row_groups(mut self) -> Vec<ScanTaskRef> {
let rg_indices = self.file_metadata.row_groups.keys().copied().collect_vec();
let mut new_data_sources = Vec::new();
for idx in rg_indices {
let maybe_accumulated = self.accumulate(idx);
if let Some(accumulated) = maybe_accumulated {
new_data_sources.push(accumulated);
}
}
if let Some(accumulated) = self.flush() {
new_data_sources.push(accumulated);
}

// Construct ScanTasks with the new DataSources
new_data_sources
.into_iter()
.map(|data_source| {
ScanTask::new(
vec![data_source],
self.scan_task_ref.file_format_config.clone(),
self.scan_task_ref.schema.clone(),
self.scan_task_ref.storage_config.clone(),
self.scan_task_ref.pushdowns.clone(),
self.scan_task_ref.generated_fields.clone(),
)
.into()
})
.collect_vec()
}

#[inline]
fn total_rg_compressed_size(&self) -> usize {
self.file_metadata
.row_groups
.iter()
.map(|rg| rg.1.compressed_size())
.sum()
}

#[inline]
fn data_source(&self) -> &DataSource {
match self.scan_task_ref.sources.as_slice() {
[source] => source,
_ => unreachable!(
"SplitByRowGroupsAccumulator should only have one DataSource in its ScanTask"
),
}
}

fn accumulate(&mut self, rg_idx: usize) -> Option<DataSource> {
let rg = self.file_metadata.row_groups.get(&rg_idx).unwrap();
impl SplitParquetByRowGroupsAccumulator {
fn accumulate(
&mut self,
rg_idx: &usize,
ctx: &SplitParquetByRowGroupsContext,
) -> Option<DataSource> {
let rg = ctx.file_metadata.row_groups.get(rg_idx).unwrap();

// Estimate the materialized size of this rowgroup and add it to curr_size_bytes
self.curr_size_bytes +=
self.scantask_estimated_size_bytes
.map_or(0., |est_materialized_size| {
(rg.compressed_size() as f64 / self.total_rg_compressed_size() as f64)
* est_materialized_size as f64
});
self.curr_num_rows += rg.num_rows();
self.curr_row_group_indices.push(rg_idx);
self.size_bytes += ctx
.scantask_estimated_size_bytes
.map_or(0., |est_materialized_size| {
(rg.compressed_size() as f64 / ctx.total_rg_compressed_size() as f64)
* est_materialized_size as f64
});
self.num_rows += rg.num_rows();
self.row_group_indices.push(*rg_idx);

// Flush the accumulator if necessary
let reached_accumulator_limit =
self.curr_size_bytes >= self.config.scan_tasks_min_size_bytes as f64;
self.size_bytes >= ctx.config.scan_tasks_min_size_bytes as f64;
let materialized_size_estimation_not_available =
self.scantask_estimated_size_bytes.is_none();
ctx.scantask_estimated_size_bytes.is_none();
if materialized_size_estimation_not_available || reached_accumulator_limit {
self.flush()
self.flush(ctx)
} else {
None
}
}

fn flush(&mut self) -> Option<DataSource> {
fn flush(&mut self, ctx: &SplitParquetByRowGroupsContext) -> Option<DataSource> {
// If nothing to flush, return None early
if self.curr_row_group_indices.is_empty() {
if self.row_group_indices.is_empty() {
return None;
}

// Grab accumulated values and reset accumulators
let curr_row_group_indices = std::mem::take(&mut self.curr_row_group_indices);
let curr_size_bytes = std::mem::take(&mut self.curr_size_bytes);
let curr_num_rows = std::mem::take(&mut self.curr_num_rows);
let curr_row_group_indices = std::mem::take(&mut self.row_group_indices);
let curr_size_bytes = std::mem::take(&mut self.size_bytes);
let curr_num_rows = std::mem::take(&mut self.num_rows);

// Create a new DataSource by mutating the old one
let mut new_source = self.data_source().clone();
let mut new_source = ctx.data_source().clone();

if let DataSource::File {
chunk_spec,
@@ -307,14 +238,12 @@ impl<'a> ParquetScanTaskRowGroupSplitter<'a> {
} = &mut new_source
{
// Create a new Parquet FileMetaData, only keeping the relevant row groups
let row_group_list =
RowGroupList::from_iter(curr_row_group_indices.iter().map(|idx| {
(
*idx,
self.file_metadata.row_groups.get(idx).unwrap().clone(),
)
}));
let new_metadata = self
let row_group_list = RowGroupList::from_iter(
curr_row_group_indices
.iter()
.map(|idx| (*idx, ctx.file_metadata.row_groups.get(idx).unwrap().clone())),
);
let new_metadata = ctx
.file_metadata
.clone_with_row_groups(curr_num_rows, row_group_list);
*parquet_metadata = Some(Arc::new(new_metadata));
@@ -336,6 +265,49 @@ impl<'a> ParquetScanTaskRowGroupSplitter<'a> {
}
}

struct SplitParquetByRowGroupsContext<'a> {
config: &'a DaftExecutionConfig,
file_metadata: FileMetaData,
scan_task_ref: ScanTaskRef,
scantask_estimated_size_bytes: Option<usize>,
}

impl<'a> SplitParquetByRowGroupsContext<'a> {
pub fn new(
scan_task_ref: ScanTaskRef,
file_metadata: FileMetaData,
config: &'a DaftExecutionConfig,
) -> Self {
let scantask_estimated_size_bytes =
scan_task_ref.estimate_in_memory_size_bytes(Some(config));
Self {
config,
file_metadata,
scan_task_ref,
scantask_estimated_size_bytes,
}
}

#[inline]
fn total_rg_compressed_size(&self) -> usize {
self.file_metadata
.row_groups
.iter()
.map(|rg| rg.1.compressed_size())
.sum()
}

#[inline]
fn data_source(&self) -> &DataSource {
match self.scan_task_ref.sources.as_slice() {
[source] => source,
_ => unreachable!(
"SplitByRowGroupsAccumulator should only have one DataSource in its ScanTask"
),
}
}
}

#[must_use]
pub(crate) fn split_by_row_groups<'a>(
scan_tasks: BoxScanTaskIter<'a>,
@@ -375,7 +347,7 @@ pub(crate) fn split_by_row_groups<'a>(
.get_iceberg_delete_files()
.map_or(true, std::vec::Vec::is_empty)
{
// Retrieve Parquet FileMetaData and construct a ParquetScanTaskRowGroupSplitter
// Retrieve Parquet FileMetaData and construct SplitParquetByRowGroupsAccumulator
let (io_runtime, io_client) = t.storage_config.get_io_client_and_runtime()?;
let path = source.get_path();
let io_stats =
@@ -387,11 +359,36 @@ pub(crate) fn split_by_row_groups<'a>(
Some(io_stats),
field_id_mapping.clone(),
))?;
let accumulator =
ParquetScanTaskRowGroupSplitter::new(t, file_metadata, config);

// Materialize and convert into new ScanTasks
let new_scan_tasks = accumulator.split_by_row_groups();
let mut accumulator = SplitParquetByRowGroupsAccumulator::default();
let accumulator_ctx =
SplitParquetByRowGroupsContext::new(t, file_metadata, config);

// Run accumulation over all rowgroups, make sure to flush last result
let mut new_data_sources = Vec::new();
for idx in accumulator_ctx.file_metadata.row_groups.keys() {
if let Some(accumulated) = accumulator.accumulate(idx, &accumulator_ctx) {
new_data_sources.push(accumulated);
}
}
if let Some(accumulated) = accumulator.flush(&accumulator_ctx) {
new_data_sources.push(accumulated);
}

// Construct ScanTasks with the new DataSources
let new_scan_tasks = new_data_sources
.into_iter()
.map(|data_source| {
ScanTask::new(
vec![data_source],
accumulator_ctx.scan_task_ref.file_format_config.clone(),
accumulator_ctx.scan_task_ref.schema.clone(),
accumulator_ctx.scan_task_ref.storage_config.clone(),
accumulator_ctx.scan_task_ref.pushdowns.clone(),
accumulator_ctx.scan_task_ref.generated_fields.clone(),
)
.into()
})
.collect_vec();
Ok(Box::new(new_scan_tasks.into_iter().map(Ok)))
} else {
Ok(Box::new(std::iter::once(Ok(t))))
Expand Down

0 comments on commit d17e91d

Please sign in to comment.
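
One closing note on the estimate that drives flushing (the numbers below are hypothetical, chosen only to illustrate the arithmetic): accumulate() credits each row group with a slice of the ScanTask's estimated in-memory size, proportional to that row group's share of the file's total compressed bytes, and flushes once the running total reaches config.scan_tasks_min_size_bytes, or immediately if no size estimate is available.

fn main() {
    // Hypothetical numbers for the proportional estimate used in accumulate().
    let est_materialized_size = 1_000.0 * 1024.0 * 1024.0; // ScanTask estimated at 1000 MiB in memory
    let rg_compressed = 25.0 * 1024.0 * 1024.0;            // this row group: 25 MiB compressed on disk
    let total_compressed = 100.0 * 1024.0 * 1024.0;        // whole file: 100 MiB compressed on disk
    let contribution = (rg_compressed / total_compressed) * est_materialized_size;
    println!("row group adds {} MiB to size_bytes", contribution / (1024.0 * 1024.0)); // 250 MiB
}

With, say, a 96 MiB scan_tasks_min_size_bytes threshold, that single row group already pushes the accumulator past the limit and triggers a flush.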