From 58564501b93ddf9a658570ff93078630a09fb381 Mon Sep 17 00:00:00 2001 From: Jay Chia Date: Mon, 26 Feb 2024 18:21:05 -0800 Subject: [PATCH] Add recursive renaming to Series --- src/daft-parquet/src/file.rs | 210 +++++++++++++++++++++++++---------- 1 file changed, 149 insertions(+), 61 deletions(-) diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs index 6cda2804c8..210f7f8192 100644 --- a/src/daft-parquet/src/file.rs +++ b/src/daft-parquet/src/file.rs @@ -6,7 +6,8 @@ use std::{ use arrow2::io::parquet::read::schema::infer_schema_with_options; use common_error::DaftResult; use daft_core::{ - datatypes::Field, schema::Schema, utils::arrow::cast_array_for_daft_if_needed, DataType, Series, + datatypes::Field, schema::Schema, utils::arrow::cast_array_for_daft_if_needed, DataType, + IntoSeries, Series, }; use daft_dsl::ExprRef; use daft_io::{IOClient, IOStatsRef}; @@ -101,71 +102,150 @@ where } } -fn rename_schema_recursively( - daft_schema: Schema, +fn rename_dtype_recursively( + dtype: &DataType, field_id_mapping: &BTreeMap, -) -> DaftResult { - fn rename_dtype_recursively( - dtype: &DataType, - field_id_mapping: &BTreeMap, - ) -> Option { - match dtype { - // Ensure recursive renaming for nested types - DataType::List(child) => rename_dtype_recursively(child.as_ref(), field_id_mapping) - .map(|new_dtype| DataType::List(Box::new(new_dtype))), - DataType::FixedSizeList(child, size) => { - rename_dtype_recursively(child.as_ref(), field_id_mapping) - .map(|new_dtype| DataType::FixedSizeList(Box::new(new_dtype), *size)) - } - DataType::Struct(original_children) => { - let new_fields = original_children - .iter() - .map(|field| rename_field_recursively(field, field_id_mapping)) - .collect::>(); - if new_fields.iter().all(|f| f.is_none()) { - None - } else { - Some(DataType::Struct( - new_fields - .into_iter() - .zip(original_children.iter()) - .map(|(maybe_new_field, old_field)| { - maybe_new_field.unwrap_or_else(|| old_field.clone()) - }) - .collect(), - )) - } +) -> Option { + match dtype { + // Ensure recursive renaming for nested types + DataType::List(child) => rename_dtype_recursively(child.as_ref(), field_id_mapping) + .map(|new_dtype| DataType::List(Box::new(new_dtype))), + DataType::FixedSizeList(child, size) => { + rename_dtype_recursively(child.as_ref(), field_id_mapping) + .map(|new_dtype| DataType::FixedSizeList(Box::new(new_dtype), *size)) + } + DataType::Struct(original_children) => { + let new_fields = original_children + .iter() + .map(|field| rename_field_recursively(field, field_id_mapping)) + .collect::>(); + if new_fields.iter().all(|f| f.is_none()) { + None + } else { + Some(DataType::Struct( + new_fields + .into_iter() + .zip(original_children.iter()) + .map(|(maybe_new_field, old_field)| { + maybe_new_field.unwrap_or_else(|| old_field.clone()) + }) + .collect(), + )) } - // All other types are renamed only at the top-level - _ => None, } + // All other types are renamed only at the top-level + _ => None, } +} - fn rename_field_recursively( - field: &Field, - field_id_mapping: &BTreeMap, - ) -> Option { - let new_name = if let Some(field_id) = field.metadata.get("field_id") { - let field_id = str::parse::(field_id).unwrap(); - let mapped_field = field_id_mapping.get(&field_id); - mapped_field.map(|mapped_field| &mapped_field.name) - } else { - None - }; - let new_dtype = rename_dtype_recursively(&field.dtype, field_id_mapping); - - match (new_name, new_dtype) { - (None, None) => None, - (new_name, new_dtype) => Some( - Field::new( - new_name.unwrap_or(&field.name), - new_dtype.unwrap_or(field.dtype.clone()), - ) - .with_metadata(field.metadata.clone()), - ), +fn rename_field_recursively( + field: &Field, + field_id_mapping: &BTreeMap, +) -> Option { + let new_name = if let Some(field_id) = field.metadata.get("field_id") { + let field_id = str::parse::(field_id).unwrap(); + let mapped_field = field_id_mapping.get(&field_id); + mapped_field.map(|mapped_field| &mapped_field.name) + } else { + None + }; + let new_dtype = rename_dtype_recursively(&field.dtype, field_id_mapping); + + match (new_name, new_dtype) { + (None, None) => None, + (new_name, new_dtype) => Some( + Field::new( + new_name.unwrap_or(&field.name), + new_dtype.unwrap_or(field.dtype.clone()), + ) + .with_metadata(field.metadata.clone()), + ), + } +} + +fn rename_series_recursively( + daft_series: &Series, + field_id_mapping: &BTreeMap, +) -> Option { + let new_field = rename_field_recursively(daft_series.field(), field_id_mapping); + match new_field { + Some(new_field) => { + // Rename the current series then recursively rename child Series objects if a nested dtype is detected + let new_series = daft_series.rename(&new_field.name); + match &new_field.dtype { + DataType::List(..) => { + use daft_core::array::ListArray; + + let new_array = new_series.list().expect( + "Series renaming: Expected a ListArray for a Series with DataType::List", + ); + let child = &new_array.flat_child; + rename_series_recursively(child, field_id_mapping) + .map(|new_child| { + ListArray::new( + new_field, + new_child, + new_array.offsets().clone(), + new_array.validity().cloned(), + ) + .into_series() + }) + .or(Some(new_series)) + } + DataType::FixedSizeList(..) => { + use daft_core::array::FixedSizeListArray; + + let new_array = new_series.fixed_size_list().expect("Series renaming: Expected a FixedSizeListArray for a Series with DataType::FixedSizeList"); + let child = &new_array.flat_child; + rename_series_recursively(child, field_id_mapping) + .map(|new_child| { + FixedSizeListArray::new( + new_field, + new_child, + new_array.validity().cloned(), + ) + .into_series() + }) + .or(Some(new_series)) + } + DataType::Struct(..) => { + use daft_core::array::StructArray; + + let new_array = new_series.struct_().expect("Series renaming: Expected a StructArray for a Series with DataType::Struct"); + let new_children = new_array + .children + .iter() + .map(|child| rename_series_recursively(child, field_id_mapping)) + .collect::>(); + + if new_children.iter().all(|maybe_child| maybe_child.is_none()) { + Some(new_series) + } else { + Some( + StructArray::new( + new_field, + new_children + .into_iter() + .zip(new_array.children.iter()) + .map(|(renamed, original)| renamed.unwrap_or(original.clone())) + .collect(), + new_array.validity().cloned(), + ) + .into_series(), + ) + } + } + _ => Some(new_series), + } } + _ => None, } +} +fn rename_schema_recursively( + daft_schema: Schema, + field_id_mapping: &BTreeMap, +) -> DaftResult { Schema::new( daft_schema .fields @@ -640,8 +720,6 @@ impl ParquetFileReader { all_arrays .into_iter() .map(|a| { - // TODO: Need to perform recursive renaming of Series here. Hopefully arrow array - // has the correct metadata and that was correctly transferred to the Series... Series::try_from(( cloned_target_field_name.as_str(), cast_array_for_daft_if_needed(a), @@ -693,10 +771,20 @@ impl ParquetFileReader { path: self.uri.to_string(), })? .into_iter() + .map(|series| { + series.map(|series| { + if let Some(field_id_mapping) = self.field_id_mapping.as_ref() { + rename_series_recursively(&series, field_id_mapping.as_ref()) + .unwrap_or(series) + } else { + series + } + }) + }) .collect::>>()?; let daft_schema = daft_core::schema::Schema::try_from(self.arrow_schema_from_pq.as_ref())?; - let daft_schema = if let Some(field_id_mapping) = self.field_id_mapping { + let daft_schema = if let Some(field_id_mapping) = self.field_id_mapping.as_ref() { rename_schema_recursively(daft_schema, field_id_mapping.as_ref())? } else { daft_schema