diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index b2510b5fb9..189263bea6 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -359,13 +359,18 @@ fn materialize_scan_task( impl MicroPartition { /// Create a new "unloaded" MicroPartition using an associated [`ScanTask`] /// - /// Schema invariants: + /// Invariants: /// 1. Each Loaded column statistic in `statistics` must be castable to the corresponding column in the MicroPartition's schema + /// 2. Creating a new MicroPartition with a ScanTask that has any filter predicates or limits is not allowed and will panic pub fn new_unloaded( scan_task: Arc, metadata: TableMetadata, statistics: TableStatistics, ) -> Self { + if scan_task.pushdowns.filters.is_some() { + panic!("Cannot create unloaded MicroPartition from a ScanTask with pushdowns that have filters"); + } + let schema = scan_task.materialized_schema(); let fill_map = scan_task.partition_spec().map(|pspec| pspec.to_fill_map()); let statistics = statistics @@ -424,11 +429,19 @@ impl MicroPartition { ) { // CASE: ScanTask provides all required metadata. // If the scan_task provides metadata (e.g. retrieved from a catalog) we can use it to create an unloaded MicroPartition - (Some(metadata), Some(statistics), _, _) => Ok(Self::new_unloaded( - scan_task.clone(), - metadata.clone(), - statistics.clone(), - )), + (Some(metadata), Some(statistics), _, _) if scan_task.pushdowns.filters.is_none() => { + Ok(Self::new_unloaded( + scan_task.clone(), + scan_task + .pushdowns + .limit + .map(|limit| TableMetadata { + length: metadata.length.min(limit), + }) + .unwrap_or_else(|| metadata.clone()), + statistics.clone(), + )) + } // CASE: ScanTask does not provide metadata, but the file format supports metadata retrieval // We can perform an eager **metadata** read to create an unloaded MicroPartition @@ -830,7 +843,8 @@ pub(crate) fn read_parquet_into_micropartition( let runtime_handle = daft_io::get_runtime(multithreaded_io)?; let io_client = daft_io::get_io_client(multithreaded_io, io_config.clone())?; - // If we have a predicate, perform an eager read only reading what row groups we need. + // If we have a predicate then we no longer have an accurate accounting of required metadata + // on the MicroPartition (e.g. its length). Hence we need to perform an eager read. if predicate.is_some() { return _read_parquet_into_loaded_micropartition( io_client, diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs index 9c645bf8fc..e958680e32 100644 --- a/src/daft-scan/src/scan_task_iters.rs +++ b/src/daft-scan/src/scan_task_iters.rs @@ -169,13 +169,16 @@ pub fn split_by_row_groups( let mut new_tasks: Vec> = Vec::new(); let mut curr_row_groups = Vec::new(); let mut curr_size_bytes = 0; + let mut curr_num_rows = 0; for (i, rg) in file.row_groups.iter().enumerate() { curr_row_groups.push(i as i64); curr_size_bytes += rg.compressed_size(); + curr_num_rows += rg.num_rows(); if curr_size_bytes >= min_size_bytes || i == file.row_groups.len() - 1 { let mut new_source = source.clone(); + match &mut new_source { DataFileSource::AnonymousDataFile { chunk_spec, @@ -189,11 +192,26 @@ pub fn split_by_row_groups( } | DataFileSource::DatabaseDataSource { chunk_spec, size_bytes, .. } => { *chunk_spec = Some(ChunkSpec::Parquet(curr_row_groups)); *size_bytes = Some(curr_size_bytes as u64); - - curr_row_groups = Vec::new(); - curr_size_bytes = 0; } }; + match &mut new_source { + DataFileSource::AnonymousDataFile { + metadata: Some(metadata), + .. + } + | DataFileSource::CatalogDataFile { + metadata, + .. + } | DataFileSource::DatabaseDataSource { metadata: Some(metadata), .. } => { + metadata.length = curr_num_rows; + } + _ => (), + } + + // Reset accumulators + curr_row_groups = Vec::new(); + curr_size_bytes = 0; + curr_num_rows = 0; new_tasks.push(Ok(ScanTask::new( vec![new_source], diff --git a/tests/io/delta_lake/test_table_read.py b/tests/io/delta_lake/test_table_read.py index 3f84dac2ee..aec8029521 100644 --- a/tests/io/delta_lake/test_table_read.py +++ b/tests/io/delta_lake/test_table_read.py @@ -1,5 +1,7 @@ from __future__ import annotations +import contextlib + import pytest from daft.delta_lake.delta_lake_scan import _io_config_to_storage_options @@ -12,6 +14,19 @@ from daft.logical.schema import Schema from tests.utils import assert_pyarrow_tables_equal + +@contextlib.contextmanager +def split_small_pq_files(): + old_config = daft.context.get_context().daft_execution_config + daft.set_execution_config( + # Splits any parquet files >100 bytes in size + scan_tasks_min_size_bytes=1, + scan_tasks_max_size_bytes=100, + ) + yield + daft.set_execution_config(config=old_config) + + PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0) pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="deltalake only supported if pyarrow >= 8.0.0") @@ -38,3 +53,44 @@ def test_deltalake_read_show(deltalake_table): path, catalog_table, io_config, _ = deltalake_table df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config) df.show() + + +def test_deltalake_read_row_group_splits(tmp_path, base_table): + path = tmp_path / "some_table" + + # Force 2 rowgroups + deltalake.write_deltalake(path, base_table, min_rows_per_group=1, max_rows_per_group=2) + + # Force file splitting + with split_small_pq_files(): + df = daft.read_delta_lake(str(path)) + df.collect() + assert len(df) == 3, "Length of non-materialized data when read through deltalake should be correct" + + +def test_deltalake_read_row_group_splits_with_filter(tmp_path, base_table): + path = tmp_path / "some_table" + + # Force 2 rowgroups + deltalake.write_deltalake(path, base_table, min_rows_per_group=1, max_rows_per_group=2) + + # Force file splitting + with split_small_pq_files(): + df = daft.read_delta_lake(str(path)) + df = df.where(df["a"] > 1) + df.collect() + assert len(df) == 2, "Length of non-materialized data when read through deltalake should be correct" + + +def test_deltalake_read_row_group_splits_with_limit(tmp_path, base_table): + path = tmp_path / "some_table" + + # Force 2 rowgroups + deltalake.write_deltalake(path, base_table, min_rows_per_group=1, max_rows_per_group=2) + + # Force file splitting + with split_small_pq_files(): + df = daft.read_delta_lake(str(path)) + df = df.limit(2) + df.collect() + assert len(df) == 2, "Length of non-materialized data when read through deltalake should be correct"