diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs
index 790f759059..8a8d9d97a4 100644
--- a/src/daft-micropartition/src/micropartition.rs
+++ b/src/daft-micropartition/src/micropartition.rs
@@ -338,7 +338,12 @@ impl MicroPartition {
         // Check and validate invariants with asserts
         for table in tables.iter() {
             assert!(
-                table.schema == schema,
+                table.schema.fields.len() == schema.fields.len()
+                    && table.schema.fields.iter().zip(schema.fields.iter()).all(
+                        |((s1, f1), (s2, f2))| s1 == s2
+                            && f1.name == f2.name
+                            && f1.dtype == f2.dtype
+                    ),
                 "Loaded MicroPartition's tables' schema must match its own schema exactly"
             );
         }
@@ -692,6 +697,31 @@ pub(crate) fn read_json_into_micropartition(
     }
 }
 
+// TODO: Deduplicate this with the other `rename_schema_recursively` function in file.rs
+fn rename_schema_recursively(
+    daft_schema: Schema,
+    field_id_mapping: &BTreeMap<i32, Field>,
+) -> DaftResult<Schema> {
+    Schema::new(
+        daft_schema
+            .fields
+            .into_iter()
+            .map(|(_, field)| {
+                if let Some(field_id) = field.metadata.get("field_id") {
+                    let field_id = str::parse::<i32>(field_id).unwrap();
+                    let mapped_field = field_id_mapping.get(&field_id);
+                    match mapped_field {
+                        None => field,
+                        Some(mapped_field) => field.rename(&mapped_field.name),
+                    }
+                } else {
+                    field
+                }
+            })
+            .collect(),
+    )
+}
+
 #[allow(clippy::too_many_arguments)]
 pub(crate) fn read_parquet_into_micropartition(
     uris: &[&str],
@@ -762,7 +792,7 @@ pub(crate) fn read_parquet_into_micropartition(
         })?;
 
     // Deserialize and collect relevant TableStatistics
-    let schemas = metadata
+    let pq_file_schemas = metadata
         .iter()
         .map(|m| {
             let schema = infer_schema_with_options(m, &Some((*schema_infer_options).into()))?;
@@ -778,11 +808,15 @@
     let stats = if any_stats_avail {
         let stat_per_table = metadata
             .iter()
-            .zip(schemas.iter())
-            .flat_map(|(fm, schema)| {
-                fm.row_groups
-                    .iter()
-                    .map(|rgm| daft_parquet::row_group_metadata_to_table_stats(rgm, schema))
+            .zip(pq_file_schemas.iter())
+            .flat_map(|(fm, pq_file_schema)| {
+                fm.row_groups.iter().map(|rgm| {
+                    daft_parquet::row_group_metadata_to_table_stats(
+                        rgm,
+                        pq_file_schema,
+                        field_id_mapping,
+                    )
+                })
             })
             .collect::<DaftResult<Vec<TableStatistics>>>()?;
        stat_per_table.into_iter().try_reduce(|a, b| a.union(&b))?
@@ -791,7 +825,19 @@
    };

    // Union and prune the schema using the specified `columns`
-    let unioned_schema = schemas.into_iter().try_reduce(|l, r| l.union(&r))?;
+    let resolved_schemas = if let Some(field_id_mapping) = field_id_mapping {
+        pq_file_schemas
+            .into_iter()
+            .map(|pq_file_schema| {
+                rename_schema_recursively(pq_file_schema, field_id_mapping.as_ref())
+            })
+            .collect::<DaftResult<Vec<Schema>>>()?
+    } else {
+        pq_file_schemas
+    };
+    let unioned_schema = resolved_schemas
+        .into_iter()
+        .try_reduce(|l, r| l.union(&r))?;
    let full_daft_schema = unioned_schema.expect("we need at least 1 schema");

    let pruned_daft_schema = prune_fields_from_schema(full_daft_schema, columns)?;
@@ -868,6 +914,7 @@
                .keys()
                .map(|n| daft_dsl::col(n.as_str()))
                .collect::<Vec<_>>();
+            // use schema to update stats
            let stats = stats.eval_expression_list(exprs.as_slice(), daft_schema.as_ref())?;
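
For reference, the field-id-driven rename that the new `rename_schema_recursively` helper performs can be sketched in isolation. The `SimpleField` type, the `rename_by_field_id` function, and the example names below are illustrative stand-ins, not Daft's actual `Field`/`Schema` API:

// Standalone sketch, assuming a simplified field type whose metadata carries a
// "field_id" entry (as written by Iceberg/Parquet writers).
use std::collections::BTreeMap;

#[derive(Debug, Clone)]
struct SimpleField {
    name: String,
    metadata: BTreeMap<String, String>,
}

fn rename_by_field_id(
    fields: Vec<SimpleField>,
    mapping: &BTreeMap<i32, String>, // field id -> target (e.g. Iceberg) column name
) -> Vec<SimpleField> {
    fields
        .into_iter()
        .map(|mut field| {
            // A field is renamed only if it carries a parseable "field_id" entry
            // and that id is present in the mapping; otherwise it is left untouched.
            if let Some(id) = field
                .metadata
                .get("field_id")
                .and_then(|v| v.parse::<i32>().ok())
            {
                if let Some(target_name) = mapping.get(&id) {
                    field.name = target_name.clone();
                }
            }
            field
        })
        .collect()
}

fn main() {
    let fields = vec![SimpleField {
        name: "pos".to_string(),
        metadata: BTreeMap::from([("field_id".to_string(), "1".to_string())]),
    }];
    // Hypothetical mapping: field id 1 should surface under the name "idx_renamed".
    let mapping = BTreeMap::from([(1, "idx_renamed".to_string())]);
    let renamed = rename_by_field_id(fields, &mapping);
    assert_eq!(renamed[0].name, "idx_renamed");
}
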
diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs
index b6b96d9227..9d631fd747 100644
--- a/src/daft-parquet/src/file.rs
+++ b/src/daft-parquet/src/file.rs
@@ -126,6 +126,7 @@ fn rename_schema_recursively(
     )
 }
 
+#[allow(clippy::too_many_arguments)]
 pub(crate) fn build_row_ranges(
     limit: Option<usize>,
     row_start_offset: usize,
@@ -134,6 +135,7 @@
     schema: &Schema,
     metadata: &parquet2::metadata::FileMetaData,
     uri: &str,
+    field_id_mapping: &Option<Arc<BTreeMap<i32, Field>>>,
 ) -> super::Result<Vec<RowGroupRange>> {
     let limit = limit.map(|v| v as i64);
     let mut row_ranges = vec![];
@@ -155,10 +157,11 @@
            }
            let rg = metadata.row_groups.get(i).unwrap();
            if let Some(ref pred) = predicate {
-                let stats = statistics::row_group_metadata_to_table_stats(rg, schema)
-                    .with_context(|_| UnableToConvertRowGroupMetadataToStatsSnafu {
-                        path: uri.to_string(),
-                    })?;
+                let stats =
+                    statistics::row_group_metadata_to_table_stats(rg, schema, field_id_mapping)
+                        .with_context(|_| UnableToConvertRowGroupMetadataToStatsSnafu {
+                            path: uri.to_string(),
+                        })?;

                let evaled = stats.eval_expression(pred).with_context(|_| {
                    UnableToRunExpressionOnStatsSnafu {
@@ -187,10 +190,11 @@
            continue;
        } else if rows_to_add > 0 {
            if let Some(ref pred) = predicate {
-                let stats = statistics::row_group_metadata_to_table_stats(rg, schema)
-                    .with_context(|_| UnableToConvertRowGroupMetadataToStatsSnafu {
-                        path: uri.to_string(),
-                    })?;
+                let stats =
+                    statistics::row_group_metadata_to_table_stats(rg, schema, field_id_mapping)
+                        .with_context(|_| UnableToConvertRowGroupMetadataToStatsSnafu {
+                            path: uri.to_string(),
+                        })?;
                let evaled = stats.eval_expression(pred).with_context(|_| {
                    UnableToRunExpressionOnStatsSnafu {
                        path: uri.to_string(),
@@ -329,6 +333,7 @@ impl ParquetReaderBuilder {
            &daft_schema,
            &self.metadata,
            &self.uri,
+            &self.field_id_mapping,
        )?;

        ParquetFileReader::new(
@@ -596,7 +601,7 @@ impl ParquetFileReader {
                    rayon::spawn(move || {
                        let concated = if series_to_concat.is_empty() {
                            Ok(Series::empty(
-                                owned_field.name.as_str(),
+                                target_field_name.as_str(),
                                &owned_field.data_type().into(),
                            ))
                        } else {
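
A minimal sketch of why `build_row_ranges` now threads `field_id_mapping` through to the statistics call: predicate-based row-group pruning only works when min/max stats are keyed by the same (target) column names the predicate refers to. The types and names below are simplified assumptions, not Daft's implementation:

// Self-contained illustration of min/max row-group pruning.
use std::collections::BTreeMap;

#[derive(Debug, Clone, Copy)]
struct ColumnRange {
    min: i64,
    max: i64,
}

// Keep the indices of row groups whose stats for `column` could contain a value in [lo, hi].
fn prune_row_groups(
    row_group_stats: &[BTreeMap<String, ColumnRange>],
    column: &str,
    lo: i64,
    hi: i64,
) -> Vec<usize> {
    row_group_stats
        .iter()
        .enumerate()
        .filter_map(|(idx, stats)| match stats.get(column) {
            // No stats under that column name: we cannot prove the row group is
            // irrelevant, so it must be kept. This is what happens when stats are
            // keyed by the file's old column names while the predicate uses the
            // renamed ones -- pruning silently degrades to "keep everything".
            None => Some(idx),
            Some(range) => (range.max >= lo && range.min <= hi).then_some(idx),
        })
        .collect()
}

fn main() {
    let row_group_stats = vec![
        BTreeMap::from([("idx_renamed".to_string(), ColumnRange { min: 0, max: 1 })]),
        BTreeMap::from([("idx_renamed".to_string(), ColumnRange { min: 2, max: 9 })]),
    ];
    // A predicate like `idx_renamed <= 1` (values in [i64::MIN, 1]) prunes the second row group.
    assert_eq!(
        prune_row_groups(&row_group_stats, "idx_renamed", i64::MIN, 1),
        vec![0]
    );
}
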
diff --git a/src/daft-parquet/src/statistics/table_stats.rs b/src/daft-parquet/src/statistics/table_stats.rs
index 25200fdfed..9dc67de712 100644
--- a/src/daft-parquet/src/statistics/table_stats.rs
+++ b/src/daft-parquet/src/statistics/table_stats.rs
@@ -1,5 +1,7 @@
+use std::{collections::BTreeMap, sync::Arc};
+
 use common_error::DaftResult;
-use daft_core::schema::Schema;
+use daft_core::{datatypes::Field, schema::Schema};
 use daft_stats::{ColumnRangeStatistics, TableStatistics};
 use snafu::ResultExt;
 
@@ -9,7 +11,8 @@ use indexmap::IndexMap;
 
 pub fn row_group_metadata_to_table_stats(
     metadata: &crate::metadata::RowGroupMetaData,
-    schema: &Schema,
+    pq_file_schema: &Schema,
+    field_id_mapping: &Option<Arc<BTreeMap<i32, Field>>>,
 ) -> DaftResult<TableStatistics> {
     // Create a map from {field_name: statistics} from the RowGroupMetaData for easy access
     let mut parquet_column_metadata: IndexMap<_, _> = metadata
@@ -26,26 +29,64 @@ pub fn row_group_metadata_to_table_stats(
         .collect();
 
     // Iterate through the schema and construct ColumnRangeStatistics per field
-    let columns = schema
-        .fields
-        .iter()
-        .map(|(field_name, field)| {
-            if ColumnRangeStatistics::supports_dtype(&field.dtype) {
-                let stats: ColumnRangeStatistics = parquet_column_metadata
-                    .remove(field_name)
-                    .expect("Cannot find parsed Daft field in Parquet rowgroup metadata")
-                    .transpose()
-                    .context(super::UnableToParseParquetColumnStatisticsSnafu)?
-                    .and_then(|v| {
-                        parquet_statistics_to_column_range_statistics(v.as_ref(), &field.dtype).ok()
-                    })
-                    .unwrap_or(ColumnRangeStatistics::Missing);
-                Ok((field_name.clone(), stats))
-            } else {
-                Ok((field_name.clone(), ColumnRangeStatistics::Missing))
-            }
-        })
-        .collect::<DaftResult<IndexMap<_, _>>>()?;
+    let columns = pq_file_schema.fields.iter().map(|(field_name, field)| {
+        if ColumnRangeStatistics::supports_dtype(&field.dtype) {
+            let stats: ColumnRangeStatistics = parquet_column_metadata
+                .remove(field_name)
+                .expect("Cannot find parsed Daft field in Parquet rowgroup metadata")
+                .transpose()
+                .context(super::UnableToParseParquetColumnStatisticsSnafu)?
+                .and_then(|v| {
+                    parquet_statistics_to_column_range_statistics(v.as_ref(), &field.dtype).ok()
+                })
+                .unwrap_or(ColumnRangeStatistics::Missing);
+            Ok((field_name.clone(), stats))
+        } else {
+            Ok((field_name.clone(), ColumnRangeStatistics::Missing))
+        }
+    });
+
+    // Apply `field_id_mapping` against parsed statistics to rename the columns (if provided)
+    let file_to_target_colname_mapping: Option<BTreeMap<&String, String>> =
+        field_id_mapping.as_ref().map(|field_id_mapping| {
+            metadata
+                .columns()
+                .iter()
+                .filter_map(|col| {
+                    if let Some(target_colname) = col
+                        .descriptor()
+                        .base_type
+                        .get_field_info()
+                        .id
+                        .and_then(|field_id| field_id_mapping.get(&field_id))
+                    {
+                        let top_level_column_name = col.descriptor().path_in_schema.first().expect(
+                            "Parquet schema should have at least one entry in path_in_schema",
+                        );
+                        Some((top_level_column_name, target_colname.name.clone()))
+                    } else {
+                        None
+                    }
+                })
+                .collect()
+        });
+    let columns = columns.map(|result| {
+        if let Some(ref file_to_target_colname_mapping) = file_to_target_colname_mapping {
+            result.map(|(field_name, stats)| {
+                (
+                    file_to_target_colname_mapping
+                        .get(&field_name)
+                        .cloned()
+                        .unwrap_or(field_name.clone()),
+                    stats,
+                )
+            })
+        } else {
+            result
+        }
+    });
 
-    Ok(TableStatistics { columns })
+    Ok(TableStatistics {
+        columns: columns.collect::<DaftResult<IndexMap<_, _>>>()?,
+    })
 }
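
The statistics change above amounts to a re-keying step: column range stats are parsed under the Parquet file's own column names, then moved to the target names via the field-id mapping, with unmapped columns keeping their file names. A simplified, self-contained sketch (hypothetical types and names, not Daft's):

use std::collections::BTreeMap;

fn rekey_stats(
    stats_by_file_name: BTreeMap<String, (i64, i64)>, // file column name -> (min, max)
    file_field_ids: &BTreeMap<String, i32>,           // file column name -> Parquet field id
    field_id_to_target: &BTreeMap<i32, String>,       // field id -> target (e.g. Iceberg) name
) -> BTreeMap<String, (i64, i64)> {
    stats_by_file_name
        .into_iter()
        .map(|(file_name, range)| {
            let target_name = file_field_ids
                .get(&file_name)
                .and_then(|id| field_id_to_target.get(id))
                .cloned()
                // Columns without a mapped field id keep their file-level name.
                .unwrap_or_else(|| file_name.clone());
            (target_name, range)
        })
        .collect()
}

fn main() {
    let stats = BTreeMap::from([("pos".to_string(), (0_i64, 9_i64))]);
    let file_field_ids = BTreeMap::from([("pos".to_string(), 1)]);
    let field_id_to_target = BTreeMap::from([(1, "idx_renamed".to_string())]);
    let rekeyed = rekey_stats(stats, &file_field_ids, &field_id_to_target);
    assert!(rekeyed.contains_key("idx_renamed"));
}
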
diff --git a/src/daft-parquet/src/stream_reader.rs b/src/daft-parquet/src/stream_reader.rs
index 982be9c2da..ba7fe07aa3 100644
--- a/src/daft-parquet/src/stream_reader.rs
+++ b/src/daft-parquet/src/stream_reader.rs
@@ -109,6 +109,7 @@ pub(crate) fn local_parquet_read_into_arrow(
         &daft_schema,
         &metadata,
         uri,
+        &None,
     )?;
 
     let columns_iters_per_rg = row_ranges
diff --git a/tests/integration/iceberg/test_table_load.py b/tests/integration/iceberg/test_table_load.py
index a5a2070b69..f62d055d5d 100644
--- a/tests/integration/iceberg/test_table_load.py
+++ b/tests/integration/iceberg/test_table_load.py
@@ -60,3 +60,14 @@ def test_daft_iceberg_table_collect_correct(table_name, local_iceberg_catalog):
     daft_pandas = df.to_pandas()
     iceberg_pandas = tab.scan().to_arrow().to_pandas()
     assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])
+
+
+@pytest.mark.integration()
+def test_daft_iceberg_table_filtered_collect_correct(local_iceberg_catalog):
+    tab = local_iceberg_catalog.load_table("default.test_table_rename")
+    df = daft.read_iceberg(tab)
+    df = df.where(df["pos"] <= 1)
+    daft_pandas = df.to_pandas()
+    iceberg_pandas = tab.scan().to_arrow().to_pandas()
+    iceberg_pandas = iceberg_pandas[iceberg_pandas["pos"] <= 1]
+    assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])