From 715a9e688f172bf7dca0fa7cb5519bb74a0c037f Mon Sep 17 00:00:00 2001
From: Jay Chia <17691182+jaychia@users.noreply.github.com>
Date: Mon, 4 Mar 2024 17:29:00 -0800
Subject: [PATCH] [FEAT][2/2] Support Iceberg renaming of **nested** columns
 (#1956)

Adds support for renaming nested columns (columns renamed under structs, lists, and maps).

**Reviewers to note: this is a follow-on PR to #1937**
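For illustration, a minimal sketch of the user-facing behavior this enables — the catalog configuration and table name below are hypothetical, mirroring the integration-test provisioning:

```python
import daft
from pyiceberg.catalog import load_catalog

# Assumes an Iceberg table where a nested field was renamed, e.g.:
#   ALTER TABLE default.test_table_rename RENAME COLUMN structcol.idx TO pos
catalog = load_catalog("default")
tab = catalog.load_table("default.test_table_rename")

# Daft resolves Parquet field IDs to the renamed Iceberg names, including
# fields nested inside structs, lists, and maps.
df = daft.read_iceberg(tab)
df.select("structcol").show()
```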
---------

Co-authored-by: Jay Chia
---
 src/daft-core/src/datatypes/field.rs          |  15 +-
 src/daft-core/src/series/mod.rs               |   1 +
 src/daft-core/src/series/ops/downcast.rs      |   6 +
 src/daft-parquet/src/file.rs                  | 233 ++++++++++++++++--
 .../iceberg/docker-compose/provision.py       |  28 ++-
 tests/integration/iceberg/test_table_load.py  |   4 +-
 6 files changed, 234 insertions(+), 53 deletions(-)

diff --git a/src/daft-core/src/datatypes/field.rs b/src/daft-core/src/datatypes/field.rs
index 99b2d3a537..fca15f89aa 100644
--- a/src/daft-core/src/datatypes/field.rs
+++ b/src/daft-core/src/datatypes/field.rs
@@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};
 
 pub type Metadata = std::collections::BTreeMap<String, String>;
 
-#[derive(Clone, Debug, Eq, Deserialize, Serialize)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)]
 pub struct Field {
     pub name: String,
     pub dtype: DataType,
@@ -129,16 +129,3 @@ impl Display for Field {
         write!(f, "{}#{}", self.name, self.dtype)
     }
 }
-
-impl PartialEq for Field {
-    fn eq(&self, other: &Self) -> bool {
-        self.dtype == other.dtype && self.name == other.name
-    }
-}
-
-impl std::hash::Hash for Field {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.name.hash(state);
-        self.dtype.hash(state);
-    }
-}
diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs
index 1722014e0c..07f6d2b501 100644
--- a/src/daft-core/src/series/mod.rs
+++ b/src/daft-core/src/series/mod.rs
@@ -82,6 +82,7 @@ impl Series {
     pub fn field(&self) -> &Field {
         self.inner.field()
     }
+
     pub fn as_physical(&self) -> DaftResult<Series> {
         let physical_dtype = self.data_type().to_physical();
         if &physical_dtype == self.data_type() {
diff --git a/src/daft-core/src/series/ops/downcast.rs b/src/daft-core/src/series/ops/downcast.rs
index a9c5780fee..c4bb562c8a 100644
--- a/src/daft-core/src/series/ops/downcast.rs
+++ b/src/daft-core/src/series/ops/downcast.rs
@@ -7,6 +7,8 @@ use crate::series::array_impl::ArrayWrapper;
 use crate::series::Series;
 use common_error::DaftResult;
 
+use self::logical::MapArray;
+
 impl Series {
     pub fn downcast<Arr: DaftArrayType>(&self) -> DaftResult<&Arr> {
         match self.inner.as_any().downcast_ref() {
@@ -81,6 +83,10 @@ impl Series {
         self.downcast()
     }
 
+    pub fn map(&self) -> DaftResult<&MapArray> {
+        self.downcast()
+    }
+
     pub fn fixed_size_list(&self) -> DaftResult<&FixedSizeListArray> {
         self.downcast()
     }
diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs
index 782d9e81cb..0ab2e50c12 100644
--- a/src/daft-parquet/src/file.rs
+++ b/src/daft-parquet/src/file.rs
@@ -6,7 +6,8 @@ use std::{
 use arrow2::io::parquet::read::schema::infer_schema_with_options;
 use common_error::DaftResult;
 use daft_core::{
-    datatypes::Field, schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series,
+    datatypes::Field, schema::Schema, utils::arrow::cast_array_for_daft_if_needed, DataType,
+    IntoSeries, Series,
 };
 use daft_dsl::ExprRef;
 use daft_io::{IOClient, IOStatsRef};
@@ -101,27 +102,209 @@ where
     }
 }
 
-fn rename_schema_recursively(
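+/// Resolves a DataType against the provided `field_id_mapping`:
+/// returns Some(new_dtype) if any nested field was renamed or had its
+/// metadata stripped, and None if no change is needed.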
+fn resolve_dtype_recursively(
+    dtype: &DataType,
+    field_id_mapping: Option<&BTreeMap<i32, Field>>,
+) -> Option<DataType> {
+    match dtype {
+        // Ensure recursive renaming for nested types
+        DataType::List(child) => resolve_dtype_recursively(child.as_ref(), field_id_mapping)
+            .map(|new_dtype| DataType::List(Box::new(new_dtype))),
+        DataType::FixedSizeList(child, size) => {
+            resolve_dtype_recursively(child.as_ref(), field_id_mapping)
+                .map(|new_dtype| DataType::FixedSizeList(Box::new(new_dtype), *size))
+        }
+        DataType::Struct(original_children) => {
+            let new_fields = original_children
+                .iter()
+                .map(|field| resolve_field_recursively(field, field_id_mapping))
+                .collect::<Vec<_>>();
+            if new_fields.iter().all(|f| f.is_none()) {
+                None
+            } else {
+                Some(DataType::Struct(
+                    new_fields
+                        .into_iter()
+                        .zip(original_children.iter())
+                        .map(|(maybe_new_field, old_field)| {
+                            maybe_new_field
+                                .unwrap_or_else(|| old_field.clone().with_metadata(BTreeMap::new()))
+                        })
+                        .collect(),
+                ))
+            }
+        }
+        DataType::Map(list_child) => resolve_dtype_recursively(list_child, field_id_mapping)
+            .map(|new_dtype| DataType::Map(Box::new(new_dtype))),
+        // All other types are renamed only at the top-level
+        _ => None,
+    }
+}
+
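+/// Resolves a single Field: renames it if its Parquet `field_id` metadata has
+/// an entry in `field_id_mapping`, resolves its dtype recursively, and strips
+/// any field metadata. Returns None if the field needs no change.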
+fn resolve_field_recursively(
+    field: &Field,
+    field_id_mapping: Option<&BTreeMap<i32, Field>>,
+) -> Option<Field> {
+    let new_name = if let (Some(field_id), Some(field_id_mapping)) =
+        (field.metadata.get("field_id"), field_id_mapping)
+    {
+        let field_id = str::parse::<i32>(field_id).unwrap();
+        let mapped_field = field_id_mapping.get(&field_id);
+        mapped_field.map(|mapped_field| &mapped_field.name)
+    } else {
+        None
+    };
+    let new_dtype = resolve_dtype_recursively(&field.dtype, field_id_mapping);
+
+    match (new_name, new_dtype, &field.metadata) {
+        (None, None, meta) => {
+            if meta.is_empty() {
+                None
+            } else {
+                Some(Field::new(&field.name, field.dtype.clone()))
+            }
+        }
+        (new_name, new_dtype, _) => Some(
+            Field::new(
+                new_name.unwrap_or(&field.name),
+                new_dtype.unwrap_or(field.dtype.clone()),
+            )
+            .with_metadata(BTreeMap::new()),
+        ),
+    }
+}
+
+/// Resolves a Series that was retrieved from Parquet -> arrow2 -> Daft:
+/// 1. Renames any Series' names using the provided `field_id_mapping`
+/// 2. Sanitizes any Series' fields to remove any provided field metadata
+fn resolve_series_recursively(
+    daft_series: &Series,
+    field_id_mapping: Option<&BTreeMap<i32, Field>>,
+) -> Option<Series> {
+    let new_field = resolve_field_recursively(daft_series.field(), field_id_mapping);
+    match new_field {
+        Some(new_field) => {
+            // Rename the current series then recursively rename child Series objects if a nested dtype is detected
+            let new_series = daft_series.rename(&new_field.name);
+            match &new_field.dtype {
+                DataType::List(..) => {
+                    use daft_core::array::ListArray;
+
+                    let new_array = new_series.list().expect(
+                        "Series renaming: Expected a ListArray for a Series with DataType::List",
+                    );
+                    let child = &new_array.flat_child;
+                    resolve_series_recursively(child, field_id_mapping)
+                        .map(|new_child| {
+                            ListArray::new(
+                                new_field,
+                                new_child,
+                                new_array.offsets().clone(),
+                                new_array.validity().cloned(),
+                            )
+                            .into_series()
+                        })
+                        .or(Some(new_series))
+                }
+                DataType::FixedSizeList(..) => {
+                    use daft_core::array::FixedSizeListArray;
+
+                    let new_array = new_series.fixed_size_list().expect("Series renaming: Expected a FixedSizeListArray for a Series with DataType::FixedSizeList");
+                    let child = &new_array.flat_child;
+                    resolve_series_recursively(child, field_id_mapping)
+                        .map(|new_child| {
+                            FixedSizeListArray::new(
+                                new_field,
+                                new_child,
+                                new_array.validity().cloned(),
+                            )
+                            .into_series()
+                        })
+                        .or(Some(new_series))
+                }
+                DataType::Struct(..) => {
+                    use daft_core::array::StructArray;
+
+                    let new_array = new_series.struct_().expect("Series renaming: Expected a StructArray for a Series with DataType::Struct");
+                    let new_children = new_array
+                        .children
+                        .iter()
+                        .map(|child| resolve_series_recursively(child, field_id_mapping))
+                        .collect::<Vec<_>>();
+
+                    if new_children.iter().all(|maybe_child| maybe_child.is_none()) {
+                        Some(new_series)
+                    } else {
+                        Some(
+                            StructArray::new(
+                                new_field,
+                                new_children
+                                    .into_iter()
+                                    .zip(new_array.children.iter())
+                                    .map(|(renamed, original)| renamed.unwrap_or(original.clone()))
+                                    .collect(),
+                                new_array.validity().cloned(),
+                            )
+                            .into_series(),
+                        )
+                    }
+                }
+                DataType::Map(_) => {
+                    use daft_core::array::ListArray;
+                    use daft_core::datatypes::logical::MapArray;
+
+                    let new_array = new_series.map().expect(
+                        "Series renaming: Expected a MapArray for a Series with DataType::Map",
+                    );
+                    let new_array_child_flat_struct = resolve_series_recursively(
+                        &new_array.physical.flat_child,
+                        field_id_mapping,
+                    );
+
+                    match new_array_child_flat_struct {
+                        Some(new_array_child_flat_struct) => Some(
+                            MapArray::new(
+                                new_field,
+                                ListArray::new(
+                                    resolve_field_recursively(
+                                        &new_array.physical.field,
+                                        field_id_mapping,
+                                    )
+                                    .unwrap_or_else(|| new_array.physical.field.as_ref().clone()),
+                                    new_array_child_flat_struct,
+                                    new_array.physical.offsets().clone(),
+                                    new_array.physical.validity().cloned(),
+                                ),
+                            )
+                            .into_series(),
+                        ),
+                        None => Some(new_series),
+                    }
+                }
+                _ => Some(new_series),
+            }
+        }
+        _ => None,
+    }
+}
+
+/// Resolves a schema that was retrieved from Parquet -> arrow2 -> Daft:
+/// 1. Renames any fields using the provided `field_id_mapping`
+/// 2. Sanitizes the schema to remove any provided field metadata
+fn resolve_schema_recursively(
     daft_schema: Schema,
-    field_id_mapping: &BTreeMap<i32, Field>,
+    field_id_mapping: Option<&BTreeMap<i32, Field>>,
 ) -> DaftResult<Schema> {
-    // TODO: perform this recursively
     Schema::new(
         daft_schema
             .fields
             .into_iter()
-            .map(|(_, field)| {
-                if let Some(field_id) = field.metadata.get("field_id") {
-                    let field_id = str::parse::<i32>(field_id).unwrap();
-                    let mapped_field = field_id_mapping.get(&field_id);
-                    match mapped_field {
-                        None => field,
-                        Some(mapped_field) => field.rename(&mapped_field.name),
-                    }
-                } else {
-                    field
-                }
-            })
+            .map(
+                |(_, field)| match resolve_field_recursively(&field, field_id_mapping) {
+                    None => field,
+                    Some(new_field) => new_field,
+                },
+            )
             .collect(),
     )
 }
@@ -456,6 +639,10 @@ impl ParquetFileReader {
         self,
         ranges: Arc<RangesContainer>,
     ) -> DaftResult<Table> {
+        // Retrieve an Option<&BTreeMap<i32, Field>> handle to the field_id_mapping
+        let field_id_mapping = self.field_id_mapping.clone();
+        let field_id_mapping = field_id_mapping.as_ref().map(|m| m.as_ref());
+
         let metadata = self.metadata;
         let all_handles = self
             .arrow_schema_from_pq
@@ -586,8 +773,6 @@
                                         all_arrays
                                             .into_iter()
                                             .map(|a| {
-                                                // TODO: Need to perform recursive renaming of Series here. Hopefully arrow array
-                                                // has the correct metadata and that was correctly transferred to the Series...
                                                 Series::try_from((
                                                     cloned_target_field_name.as_str(),
                                                     cast_array_for_daft_if_needed(a),
@@ -632,21 +817,21 @@
                 Ok(concated_handle)
             })
             .collect::<DaftResult<Vec<_>>>()?;
-
         let all_series = try_join_all(all_handles)
             .await
             .context(JoinSnafu {
                 path: self.uri.to_string(),
             })?
             .into_iter()
+            .map(|series| {
+                series.map(|series| {
+                    resolve_series_recursively(&series, field_id_mapping).unwrap_or(series)
+                })
+            })
             .collect::<DaftResult<Vec<_>>>()?;
 
         let daft_schema = daft_core::schema::Schema::try_from(self.arrow_schema_from_pq.as_ref())?;
-        let daft_schema = if let Some(field_id_mapping) = self.field_id_mapping {
-            rename_schema_recursively(daft_schema, field_id_mapping.as_ref())?
-        } else {
-            daft_schema
-        };
+        let daft_schema = resolve_schema_recursively(daft_schema, field_id_mapping)?;
 
         Table::new(daft_schema, all_series)
     }
diff --git a/tests/integration/iceberg/docker-compose/provision.py b/tests/integration/iceberg/docker-compose/provision.py
index db67b3be27..c8589621a3 100644
--- a/tests/integration/iceberg/docker-compose/provision.py
+++ b/tests/integration/iceberg/docker-compose/provision.py
@@ -21,7 +21,7 @@
 from pyiceberg.schema import Schema
 from pyiceberg.types import FixedType, NestedField, UUIDType
 from pyspark.sql import SparkSession
-from pyspark.sql.functions import current_date, date_add, expr
+from pyspark.sql.functions import col, current_date, date_add, expr, struct
 
 spark = SparkSession.builder.getOrCreate()
 
@@ -357,17 +357,19 @@
 spark.sql("ALTER TABLE default.test_new_column_with_no_data ADD COLUMN name STRING")
 
-spark.sql(
-    """
-    CREATE OR REPLACE TABLE default.test_table_rename
-    USING iceberg
-    AS SELECT
-        1 AS idx, 10 AS data
-    UNION ALL SELECT
-        2 AS idx, 20 AS data
-    UNION ALL SELECT
-        3 AS idx, 30 AS data
-"""
-)
+###
+# Renaming columns test table
+###
+
+renaming_columns_dataframe = (
+    spark.range(1, 2, 3)
+    .withColumnRenamed("id", "idx")
+    .withColumn("data", col("idx") * 10)
+    .withColumn("structcol", struct("idx"))
+    .withColumn("structcol_oldname", struct("idx"))
+)
+renaming_columns_dataframe.writeTo("default.test_table_rename").tableProperty("format-version", "2").createOrReplace()
 
 spark.sql("ALTER TABLE default.test_table_rename RENAME COLUMN idx TO pos")
+spark.sql("ALTER TABLE default.test_table_rename RENAME COLUMN structcol.idx TO pos")
+spark.sql("ALTER TABLE default.test_table_rename RENAME COLUMN structcol_oldname TO structcol_2")
diff --git a/tests/integration/iceberg/test_table_load.py b/tests/integration/iceberg/test_table_load.py
index 225f1ecb8d..1d1032d041 100644
--- a/tests/integration/iceberg/test_table_load.py
+++ b/tests/integration/iceberg/test_table_load.py
@@ -63,7 +63,7 @@ def test_daft_iceberg_table_collect_correct(table_name, local_iceberg_catalog):
 
 
 @pytest.mark.integration()
-def test_daft_iceberg_table_filtered_collect_correct(local_iceberg_catalog):
+def test_daft_iceberg_table_renamed_filtered_collect_correct(local_iceberg_catalog):
     tab = local_iceberg_catalog.load_table(f"default.test_table_rename")
     df = daft.read_iceberg(tab)
     df = df.where(df["pos"] <= 1)
@@ -74,7 +74,7 @@
 
 
 @pytest.mark.integration()
-def test_daft_iceberg_table_column_pushdown_collect_correct(local_iceberg_catalog):
+def test_daft_iceberg_table_renamed_column_pushdown_collect_correct(local_iceberg_catalog):
     tab = local_iceberg_catalog.load_table(f"default.test_table_rename")
     df = daft.read_iceberg(tab)
     df = df.select("pos")