[FEAT][2/2] Support Iceberg renaming of **nested** columns (#1956)
Adds support for renaming nested columns (columns renamed under structs and lists).

**Reviewers to note: this is a follow-on PR to #1937**
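For a concrete picture of what this enables, here is a minimal usage sketch based on the integration tests in this PR (the `local_iceberg_catalog` fixture and the `default.test_table_rename` table come from the provisioned test suite below; treat any other setup as assumed):

```python
import daft

# Table whose columns were renamed after data was written, including a field
# renamed *inside* a struct (structcol.idx -> structcol.pos) and a renamed
# top-level struct column (structcol_oldname -> structcol_2).
tab = local_iceberg_catalog.load_table("default.test_table_rename")
df = daft.read_iceberg(tab)

# Daft resolves the Parquet field IDs against the current Iceberg schema,
# so nested data is readable under the new names.
df = df.select("pos", "structcol", "structcol_2")
df.collect()
```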

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
jaychia and Jay Chia authored Mar 5, 2024
1 parent 570764a commit 715a9e6
Showing 6 changed files with 234 additions and 53 deletions.
15 changes: 1 addition & 14 deletions src/daft-core/src/datatypes/field.rs
@@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize};

pub type Metadata = std::collections::BTreeMap<String, String>;

#[derive(Clone, Debug, Eq, Deserialize, Serialize)]
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)]
pub struct Field {
pub name: String,
pub dtype: DataType,
@@ -129,16 +129,3 @@ impl Display for Field {
write!(f, "{}#{}", self.name, self.dtype)
}
}

impl PartialEq for Field {
fn eq(&self, other: &Self) -> bool {
self.dtype == other.dtype && self.name == other.name
}
}

impl std::hash::Hash for Field {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.dtype.hash(state);
}
}
1 change: 1 addition & 0 deletions src/daft-core/src/series/mod.rs
@@ -82,6 +82,7 @@ impl Series {
pub fn field(&self) -> &Field {
self.inner.field()
}

pub fn as_physical(&self) -> DaftResult<Series> {
let physical_dtype = self.data_type().to_physical();
if &physical_dtype == self.data_type() {
6 changes: 6 additions & 0 deletions src/daft-core/src/series/ops/downcast.rs
@@ -7,6 +7,8 @@ use crate::series::array_impl::ArrayWrapper;
use crate::series::Series;
use common_error::DaftResult;

use self::logical::MapArray;

impl Series {
pub fn downcast<Arr: DaftArrayType>(&self) -> DaftResult<&Arr> {
match self.inner.as_any().downcast_ref() {
@@ -81,6 +83,10 @@ impl Series {
self.downcast()
}

pub fn map(&self) -> DaftResult<&MapArray> {
self.downcast()
}

pub fn fixed_size_list(&self) -> DaftResult<&FixedSizeListArray> {
self.downcast()
}
233 changes: 209 additions & 24 deletions src/daft-parquet/src/file.rs
@@ -6,7 +6,8 @@ use std::{
use arrow2::io::parquet::read::schema::infer_schema_with_options;
use common_error::DaftResult;
use daft_core::{
datatypes::Field, schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series,
datatypes::Field, schema::Schema, utils::arrow::cast_array_for_daft_if_needed, DataType,
IntoSeries, Series,
};
use daft_dsl::ExprRef;
use daft_io::{IOClient, IOStatsRef};
@@ -101,27 +102,209 @@
}
}

fn rename_schema_recursively(
fn resolve_dtype_recursively(
dtype: &DataType,
field_id_mapping: Option<&BTreeMap<i32, Field>>,
) -> Option<DataType> {
match dtype {
// Ensure recursive renaming for nested types
DataType::List(child) => resolve_dtype_recursively(child.as_ref(), field_id_mapping)
.map(|new_dtype| DataType::List(Box::new(new_dtype))),
DataType::FixedSizeList(child, size) => {
resolve_dtype_recursively(child.as_ref(), field_id_mapping)
.map(|new_dtype| DataType::FixedSizeList(Box::new(new_dtype), *size))
}
DataType::Struct(original_children) => {
let new_fields = original_children
.iter()
.map(|field| resolve_field_recursively(field, field_id_mapping))
.collect::<Vec<_>>();
if new_fields.iter().all(|f| f.is_none()) {
None
} else {
Some(DataType::Struct(
new_fields
.into_iter()
.zip(original_children.iter())
.map(|(maybe_new_field, old_field)| {
maybe_new_field
.unwrap_or_else(|| old_field.clone().with_metadata(BTreeMap::new()))
})
.collect(),
))
}
}
DataType::Map(list_child) => resolve_dtype_recursively(list_child, field_id_mapping)
.map(|new_dtype| DataType::Map(Box::new(new_dtype))),
// All other types are renamed only at the top-level
_ => None,
}
}

fn resolve_field_recursively(
field: &Field,
field_id_mapping: Option<&BTreeMap<i32, Field>>,
) -> Option<Field> {
let new_name = if let (Some(field_id), Some(field_id_mapping)) =
(field.metadata.get("field_id"), field_id_mapping)
{
let field_id = str::parse::<i32>(field_id).unwrap();
let mapped_field = field_id_mapping.get(&field_id);
mapped_field.map(|mapped_field| &mapped_field.name)
} else {
None
};
let new_dtype = resolve_dtype_recursively(&field.dtype, field_id_mapping);

match (new_name, new_dtype, &field.metadata) {
(None, None, meta) => {
if meta.is_empty() {
None
} else {
Some(Field::new(&field.name, field.dtype.clone()))
}
}
(new_name, new_dtype, _) => Some(
Field::new(
new_name.unwrap_or(&field.name),
new_dtype.unwrap_or(field.dtype.clone()),
)
.with_metadata(BTreeMap::new()),
),
}
}

/// Resolves a Series that was retrieved from Parquet -> arrow2 -> Daft:
/// 1. Renames any Series' names using the provided `field_id_mapping`
/// 2. Sanitizes any Series' fields to remove any provided field metadata
fn resolve_series_recursively(
daft_series: &Series,
field_id_mapping: Option<&BTreeMap<i32, Field>>,
) -> Option<Series> {
let new_field = resolve_field_recursively(daft_series.field(), field_id_mapping);
match new_field {
Some(new_field) => {
// Rename the current series then recursively rename child Series objects if a nested dtype is detected
let new_series = daft_series.rename(&new_field.name);
match &new_field.dtype {
DataType::List(..) => {
use daft_core::array::ListArray;

let new_array = new_series.list().expect(
"Series renaming: Expected a ListArray for a Series with DataType::List",
);
let child = &new_array.flat_child;
resolve_series_recursively(child, field_id_mapping)
.map(|new_child| {
ListArray::new(
new_field,
new_child,
new_array.offsets().clone(),
new_array.validity().cloned(),
)
.into_series()
})
.or(Some(new_series))
}
DataType::FixedSizeList(..) => {
use daft_core::array::FixedSizeListArray;

let new_array = new_series.fixed_size_list().expect("Series renaming: Expected a FixedSizeListArray for a Series with DataType::FixedSizeList");
let child = &new_array.flat_child;
resolve_series_recursively(child, field_id_mapping)
.map(|new_child| {
FixedSizeListArray::new(
new_field,
new_child,
new_array.validity().cloned(),
)
.into_series()
})
.or(Some(new_series))
}
DataType::Struct(..) => {
use daft_core::array::StructArray;

let new_array = new_series.struct_().expect("Series renaming: Expected a StructArray for a Series with DataType::Struct");
let new_children = new_array
.children
.iter()
.map(|child| resolve_series_recursively(child, field_id_mapping))
.collect::<Vec<_>>();

if new_children.iter().all(|maybe_child| maybe_child.is_none()) {
Some(new_series)
} else {
Some(
StructArray::new(
new_field,
new_children
.into_iter()
.zip(new_array.children.iter())
.map(|(renamed, original)| renamed.unwrap_or(original.clone()))
.collect(),
new_array.validity().cloned(),
)
.into_series(),
)
}
}
DataType::Map(_) => {
use daft_core::array::ListArray;
use daft_core::datatypes::logical::MapArray;

let new_array = new_series.map().expect(
"Series renaming: Expected a MapArray for a Series with DataType::Map",
);
let new_array_child_flat_struct = resolve_series_recursively(
&new_array.physical.flat_child,
field_id_mapping,
);

match new_array_child_flat_struct {
Some(new_array_child_flat_struct) => Some(
MapArray::new(
new_field,
ListArray::new(
resolve_field_recursively(
&new_array.physical.field,
field_id_mapping,
)
.unwrap_or_else(|| new_array.physical.field.as_ref().clone()),
new_array_child_flat_struct,
new_array.physical.offsets().clone(),
new_array.physical.validity().cloned(),
),
)
.into_series(),
),
None => Some(new_series),
}
}
_ => Some(new_series),
}
}
_ => None,
}
}

/// Resolves a schema that was retrieved from Parquet -> arrow2 -> Daft:
/// 1. Renames any fields using the provided `field_id_mapping`
/// 2. Sanitizes the schema to remove any provided field metadata
fn resolve_schema_recursively(
daft_schema: Schema,
field_id_mapping: &BTreeMap<i32, Field>,
field_id_mapping: Option<&BTreeMap<i32, Field>>,
) -> DaftResult<Schema> {
// TODO: perform this recursively
Schema::new(
daft_schema
.fields
.into_iter()
.map(|(_, field)| {
if let Some(field_id) = field.metadata.get("field_id") {
let field_id = str::parse::<i32>(field_id).unwrap();
let mapped_field = field_id_mapping.get(&field_id);
match mapped_field {
None => field,
Some(mapped_field) => field.rename(&mapped_field.name),
}
} else {
field
}
})
.map(
|(_, field)| match resolve_field_recursively(&field, field_id_mapping) {
None => field,
Some(new_field) => new_field,
},
)
.collect(),
)
}
@@ -456,6 +639,10 @@ impl ParquetFileReader {
self,
ranges: Arc<RangesContainer>,
) -> DaftResult<Table> {
// Retrieve an Option<&BTreeMap> handle to the field_id_mapping
let field_id_mapping = self.field_id_mapping.clone();
let field_id_mapping = field_id_mapping.as_ref().map(|m| m.as_ref());

let metadata = self.metadata;
let all_handles = self
.arrow_schema_from_pq
@@ -586,8 +773,6 @@ impl ParquetFileReader {
all_arrays
.into_iter()
.map(|a| {
// TODO: Need to perform recursive renaming of Series here. Hopefully arrow array
// has the correct metadata and that was correctly transferred to the Series...
Series::try_from((
cloned_target_field_name.as_str(),
cast_array_for_daft_if_needed(a),
@@ -632,21 +817,21 @@ impl ParquetFileReader {
Ok(concated_handle)
})
.collect::<DaftResult<Vec<_>>>()?;

let all_series = try_join_all(all_handles)
.await
.context(JoinSnafu {
path: self.uri.to_string(),
})?
.into_iter()
.map(|series| {
series.map(|series| {
resolve_series_recursively(&series, field_id_mapping).unwrap_or(series)
})
})
.collect::<DaftResult<Vec<_>>>()?;

let daft_schema = daft_core::schema::Schema::try_from(self.arrow_schema_from_pq.as_ref())?;
let daft_schema = if let Some(field_id_mapping) = self.field_id_mapping {
rename_schema_recursively(daft_schema, field_id_mapping.as_ref())?
} else {
daft_schema
};
let daft_schema = resolve_schema_recursively(daft_schema, field_id_mapping)?;

Table::new(daft_schema, all_series)
}
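To make the resolution logic above easier to follow, here is an illustrative Python analogue of `resolve_field_recursively` (this is not the Daft API: fields are modeled as plain dicts, and the mapping is simplified to `field_id -> new name`):

```python
def resolve_field(field: dict, field_id_mapping: dict) -> dict:
    """Rename a field (and its children) using Parquet `field_id` metadata."""
    # Rename this field if its "field_id" metadata has an entry in the mapping.
    field_id = field.get("metadata", {}).get("field_id")
    name = field["name"]
    if field_id is not None:
        name = field_id_mapping.get(int(field_id), name)
    # Recurse into the children of nested types (struct/list/map), then strip
    # the metadata, mirroring the sanitization step performed in this PR.
    children = [resolve_field(c, field_id_mapping) for c in field.get("children", [])]
    return {"name": name, "dtype": field["dtype"], "children": children, "metadata": {}}

# A struct whose child "idx" carries field ID 2; the mapping renames it to "pos".
schema = {
    "name": "structcol", "dtype": "struct", "metadata": {"field_id": "1"},
    "children": [{"name": "idx", "dtype": "int64", "metadata": {"field_id": "2"}}],
}
print(resolve_field(schema, {2: "pos"}))  # the child now appears as "pos"
```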
28 changes: 15 additions & 13 deletions tests/integration/iceberg/docker-compose/provision.py
@@ -21,7 +21,7 @@
from pyiceberg.schema import Schema
from pyiceberg.types import FixedType, NestedField, UUIDType
from pyspark.sql import SparkSession
from pyspark.sql.functions import current_date, date_add, expr
from pyspark.sql.functions import col, current_date, date_add, expr, struct

spark = SparkSession.builder.getOrCreate()

@@ -357,17 +357,19 @@

spark.sql("ALTER TABLE default.test_new_column_with_no_data ADD COLUMN name STRING")

spark.sql(
"""
CREATE OR REPLACE TABLE default.test_table_rename
USING iceberg
AS SELECT
1 AS idx, 10 AS data
UNION ALL SELECT
2 AS idx, 20 AS data
UNION ALL SELECT
3 AS idx, 30 AS data
"""
)

###
# Renaming columns test table
###

renaming_columns_dataframe = (
spark.range(1, 2, 3)
.withColumnRenamed("id", "idx")
.withColumn("data", col("idx") * 10)
.withColumn("structcol", struct("idx"))
.withColumn("structcol_oldname", struct("idx"))
)
renaming_columns_dataframe.writeTo("default.test_table_rename").tableProperty("format-version", "2").createOrReplace()
spark.sql("ALTER TABLE default.test_table_rename RENAME COLUMN idx TO pos")
spark.sql("ALTER TABLE default.test_table_rename RENAME COLUMN structcol.idx TO pos")
spark.sql("ALTER TABLE default.test_table_rename RENAME COLUMN structcol_oldname TO structcol_2")
4 changes: 2 additions & 2 deletions tests/integration/iceberg/test_table_load.py
@@ -63,7 +63,7 @@ def test_daft_iceberg_table_collect_correct(table_name, local_iceberg_catalog):


@pytest.mark.integration()
def test_daft_iceberg_table_filtered_collect_correct(local_iceberg_catalog):
def test_daft_iceberg_table_renamed_filtered_collect_correct(local_iceberg_catalog):
tab = local_iceberg_catalog.load_table(f"default.test_table_rename")
df = daft.read_iceberg(tab)
df = df.where(df["pos"] <= 1)
@@ -74,7 +74,7 @@ def test_daft_iceberg_table_filtered_collect_correct(local_iceberg_catalog):


@pytest.mark.integration()
def test_daft_iceberg_table_column_pushdown_collect_correct(local_iceberg_catalog):
def test_daft_iceberg_table_renamed_column_pushdown_collect_correct(local_iceberg_catalog):
tab = local_iceberg_catalog.load_table(f"default.test_table_rename")
df = daft.read_iceberg(tab)
df = df.select("pos")
