Python: allow projection of Iceberg fields to pyarrow table schema with names #8144

Closed (wants to merge 1 commit).
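Before the diff, a rough sketch of what this change enables. This is not part of the PR; the schema and field names below are made up for illustration, but the pyarrow_to_schema signature matches the one added in the diff.

import pyarrow as pa

from pyiceberg.io.pyarrow import pyarrow_to_schema
from pyiceberg.schema import Schema
from pyiceberg.types import LongType, NestedField, StringType

# An Arrow schema produced by a plain Parquet writer: no Iceberg field-id metadata.
arrow_schema = pa.schema(
    [
        pa.field("id", pa.int64(), nullable=False),
        pa.field("name", pa.string()),
    ]
)

# The projected Iceberg schema whose field ids are authoritative.
projected = Schema(
    NestedField(1, "id", LongType(), required=True),
    NestedField(2, "name", StringType(), required=False),
)

# With match_with_field_name=True, the missing ids are resolved by name against `projected`.
iceberg_schema = pyarrow_to_schema(arrow_schema, projected, match_with_field_name=True)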
217 changes: 190 additions & 27 deletions python/pyiceberg/io/pyarrow.py
@@ -39,6 +39,7 @@
     Iterable,
     Iterator,
     List,
+    Literal as _Literal,
     Optional,
     Set,
     Tuple,
@@ -575,9 +576,16 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:
     return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)


-def pyarrow_to_schema(schema: pa.Schema) -> Schema:
-    visitor = _ConvertToIceberg()
-    return visit_pyarrow(schema, visitor)
+def pyarrow_to_schema(
+    schema: pa.Schema,
+    projected_schema: Optional[Schema] = None,
+    match_with_field_name: bool = False,
+    ignore_unprojectable_fields: bool = False,
+) -> Schema:
+    visitor = _ConvertToIceberg(projected_schema, match_with_field_name, ignore_unprojectable_fields)
+    ib_schema = visit_pyarrow(schema, visitor)
+    assert isinstance(ib_schema, StructType)
Contributor: Could you please replace this with a ValueError or something similar? According to comments in other PRs, we try to avoid assert outside tests/.

Author: This is not intended to warn the user about wrong usage. It is a type guard for mypy, and I believe it's valid.

+    return Schema(*ib_schema.fields)


@singledispatch
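As an aside on the assert discussion above: both styles narrow the type for mypy, but an explicit check also runs under python -O, which strips asserts. A minimal sketch of the reviewer's suggestion (the helper name is hypothetical, not part of the PR):

from pyiceberg.schema import Schema
from pyiceberg.types import IcebergType, StructType

def _as_schema(ib_schema: IcebergType) -> Schema:
    # Unlike `assert isinstance(...)`, this check is never stripped by `python -O`,
    # and mypy narrows `ib_schema` to StructType after it either way.
    if not isinstance(ib_schema, StructType):
        raise ValueError(f"Expected a StructType, got: {ib_schema}")
    return Schema(*ib_schema.fields)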
@@ -675,41 +683,141 @@ def primitive(self, primitive: pa.DataType) -> Optional[T]:


 def _get_field_id(field: pa.Field) -> Optional[int]:
-    for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
-        if field_id_str := field.metadata.get(pyarrow_field_id_key):
-            return int(field_id_str.decode())
+    if field.metadata is not None:
+        for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
+            if field_id_str := field.metadata.get(pyarrow_field_id_key):
+                return int(field_id_str.decode())
     return None


 def _get_field_doc(field: pa.Field) -> Optional[str]:
-    for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
-        if doc_str := field.metadata.get(pyarrow_doc_key):
-            return doc_str.decode()
+    if field.metadata is not None:
+        for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
+            if doc_str := field.metadata.get(pyarrow_doc_key):
+                return doc_str.decode()
     return None

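The `field.metadata is not None` guards above exist because Arrow only carries Iceberg field ids as optional key/value metadata on each field. A small sketch of both cases, assuming b"PARQUET:field_id" is among the keys in PYARROW_FIELD_ID_KEYS:

import pyarrow as pa

# A field written by an Iceberg-aware writer carries the id in its metadata ...
with_id = pa.field("id", pa.int64(), metadata={b"PARQUET:field_id": b"1"})
assert with_id.metadata.get(b"PARQUET:field_id") == b"1"

# ... while a field from a plain Arrow/Parquet writer has no metadata at all,
# which previously made `field.metadata.get(...)` raise AttributeError.
without_id = pa.field("id", pa.int64())
assert without_id.metadata is None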

-class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
+class _ConvertToIceberg(PyArrowSchemaVisitor[IcebergType]):
+    projected_schema: Union[Schema, ListType, MapType, None]
+    match_with_field_name: bool
+    ignore_unprojectable_fields: bool
+    projected_schema_stack: List[Tuple[Schema | ListType | MapType | None, Optional[_Literal["key", "value"]]]]
+    next: Optional[_Literal["key", "value"]]

     def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
         fields = []
         for i, field in enumerate(arrow_fields):
             field_id = _get_field_id(field)
             field_doc = _get_field_doc(field)
             field_type = field_results[i]
-            if field_type is not None and field_id is not None:
-                fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc))
+            ib_field: Optional[NestedField] = None
+            if field_type is not None:
+                if field_id is not None:
+                    ib_field = NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc)
+                elif self.match_with_field_name:
+                    if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields:
+                        continue
+                    if not isinstance(self.projected_schema, Schema):
+                        raise ValueError("projected_schema must be provided if match_with_field_name is set to True")
+                    try:
+                        projected_field = self.projected_schema.find_field(field.name)
+                    except ValueError as e:
+                        if self.ignore_unprojectable_fields:
+                            continue
+                        raise ValueError(
+                            f"could not find a field that corresponds to {field.name} in projected schema {self.projected_schema}"
+                        ) from e
+                    ib_field = NestedField(
+                        projected_field.field_id, field.name, field_type, required=not field.nullable, doc=field_doc
+                    )
+            if ib_field is not None:
+                fields.append(ib_field)
         return fields

-    def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> Schema:
-        return Schema(*self._convert_fields(schema, field_results))
+    def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> StructType:
+        return StructType(*self._convert_fields(schema, field_results))

+    def before_field(self, field: pa.Field) -> None:
+        if not isinstance(field.type, (pa.StructType, pa.ListType, pa.MapType)):
+            return
+
+        self.projected_schema_stack.append((self.projected_schema, self.next))
+
+        projected_field: Optional[NestedField] = None
+
+        if isinstance(self.projected_schema, Schema):
+            field_id = _get_field_id(field)
+            if field_id is not None:
+                try:
+                    projected_field = self.projected_schema.find_field(field_id)
+                except ValueError:
+                    if not self.match_with_field_name:
+                        raise
+            if projected_field is None and self.match_with_field_name:
+                try:
+                    projected_field = self.projected_schema.find_field(field.name)
+                except ValueError:
+                    if not self.ignore_unprojectable_fields:
+                        raise
+        elif isinstance(self.projected_schema, ListType):
+            projected_field = self.projected_schema.element_field
+        elif isinstance(self.projected_schema, MapType):
+            if self.next == "key":
+                projected_field = self.projected_schema.key_field
+            elif self.next == "value":
+                projected_field = self.projected_schema.value_field
+            else:
+                raise AssertionError("should never get here")
+
+        inner_schema: Schema | ListType | MapType | None = None
+        next_: Optional[str] = None
+
+        if projected_field is not None:
+            field_type = projected_field.field_type
+            if isinstance(field_type, StructType):
+                inner_schema = Schema(*field_type.fields)
+            else:
+                if isinstance(field_type, ListType):
+                    inner_schema = field_type
+                elif isinstance(field_type, MapType):
+                    inner_schema = field_type
+                    next_ = "key"
+        else:
+            if isinstance(field.type, pa.MapType):
+                next_ = "key"
+
+        self.projected_schema = inner_schema
+        self.next = next_
+
+    def after_field(self, field: pa.Field) -> None:
+        if isinstance(field.type, (pa.StructType, pa.ListType, pa.MapType)):
+            (self.projected_schema, self.next) = self.projected_schema_stack.pop()
+        if self.next == "key":
+            self.next = "value"
+        elif self.next == "value":
+            self.next = None

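To make the bookkeeping in before_field/after_field concrete, here is a hand trace, comments only, of one visit over a hypothetical map column while a projection is active:

# Trace for a column `props: map<string, int>`:
#
# before_field(props)  pa.MapType -> push (projected_schema, next) onto the stack,
#                      set projected_schema to the projected MapType, next = "key"
# before_field(key)    primitive type -> early return, nothing pushed
# after_field(key)     no pop (primitive); next toggles "key" -> "value"
# before_field(value)  primitive type -> early return
# after_field(value)   next toggles "value" -> None
# after_field(props)   pa.MapType -> pop restores the outer (projected_schema, next)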
     def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType]]) -> IcebergType:
         return StructType(*self._convert_fields(struct, field_results))

     def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
         element_field = list_type.value_field
         element_id = _get_field_id(element_field)
-        if element_result is not None and element_id is not None:
-            return ListType(element_id, element_result, element_required=not element_field.nullable)
+        if element_result is not None:
+            ib_type: ListType
+            if element_id is not None:
+                ib_type = ListType(element_id, element_result, element_required=not element_field.nullable)
+            elif self.match_with_field_name:
+                if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields:
+                    return None
+                if not isinstance(self.projected_schema, ListType):
+                    raise ValueError("projected_schema must be provided if match_with_field_name is set to True")
+                ib_type = ListType(self.projected_schema.element_id, element_result, element_required=not element_field.nullable)
+            else:
+                raise ValueError("match_with_field_name is set to False and element_id is unknown")
+            return ib_type
+
         return None

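The list branch above borrows the element id from the projected ListType when the file carries none. A small illustration; the id and types are made up:

from pyiceberg.types import IntegerType, ListType

# Element id 5 comes from the table's projected schema, not from the data file.
projected_list = ListType(5, IntegerType(), element_required=True)

# Inside _ConvertToIceberg.list(), with self.projected_schema set to projected_list
# and element_id missing from the file, the fallback returns:
#     ListType(projected_list.element_id, element_result, element_required=...)
# i.e. the element keeps its file-side type but adopts field id 5.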
     def map(
@@ -719,8 +827,26 @@ def map(
         key_id = _get_field_id(key_field)
         value_field = map_type.item_field
         value_id = _get_field_id(value_field)
-        if key_result is not None and value_result is not None and key_id is not None and value_id is not None:
-            return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
+        if key_result is not None and value_result is not None:
+            ib_type: MapType
+            if key_id is not None and value_id is not None:
+                ib_type = MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
+            elif self.match_with_field_name:
+                if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields:
+                    return None
+                if not isinstance(self.projected_schema, MapType):
+                    raise ValueError("projected_schema must be provided if match_with_field_name is set to True")
+                ib_type = MapType(
+                    self.projected_schema.key_id,
+                    key_result,
+                    self.projected_schema.value_id,
+                    value_result,
+                    value_required=not value_field.nullable,
+                )
+            else:
+                raise ValueError("match_with_field_name is set to False and either key_id or value_id is unknown")
+            return ib_type
+
         return None

     def primitive(self, primitive: pa.DataType) -> IcebergType:
@@ -758,6 +884,15 @@ def primitive(self, primitive: pa.DataType) -> IcebergType:

         raise TypeError(f"Unsupported type: {primitive}")

+    def __init__(
+        self, projected_schema: Optional[Schema] = None, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False
+    ) -> None:
+        self.projected_schema = projected_schema
+        self.match_with_field_name = match_with_field_name
+        self.ignore_unprojectable_fields = ignore_unprojectable_fields
+        self.projected_schema_stack = []
+        self.next = None


 def _task_to_table(
     fs: FileSystem,
@@ -769,6 +904,9 @@ def _task_to_table(
     case_sensitive: bool,
     row_counts: List[int],
     limit: Optional[int] = None,
+    *,
+    match_with_field_name: bool = False,
+    ignore_unprojectable_fields: bool = False,
 ) -> Optional[pa.Table]:
     if limit and sum(row_counts) >= limit:
         return None
@@ -781,9 +919,16 @@ def _task_to_table(
     schema_raw = None
     if metadata := physical_schema.metadata:
         schema_raw = metadata.get(ICEBERG_SCHEMA)
-    # TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema,
-    # see https://github.com/apache/iceberg/issues/7451
-    file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)
+    file_schema = (
+        Schema.model_validate_json(schema_raw)
+        if schema_raw is not None
+        else pyarrow_to_schema(
+            physical_schema,
+            projected_schema,
Contributor: I think we want the full table schema here. pyarrow_to_schema is supposed to simply convert the physical schema to an Iceberg schema, without handling column pruning. However, projected_schema only contains the columns selected in a table scan. If we use it during the conversion, we have to ignore the unselected columns, which I think is unnecessary and tricky to implement (and also inconsistent with the behavior when field_id is present).

Author: It doesn't have much to do with pruning. What we need to achieve with ignore_unprojectable_fields here is simply to ignore redundant columns in the actual data, whereas the purpose of pruning is to take away fields that are already known according to the catalog. Those are similar, but they have different semantics.

+            match_with_field_name=match_with_field_name,
+            ignore_unprojectable_fields=ignore_unprojectable_fields,
+        )
+    )

     pyarrow_filter = None
     if bound_row_filter is not AlwaysTrue():
@@ -840,7 +985,9 @@ def _task_to_table(

     row_counts.append(len(arrow_table))

-    return to_requested_schema(projected_schema, file_project_schema, arrow_table)
+    return to_requested_schema(
+        projected_schema, file_project_schema, arrow_table, match_with_field_name=match_with_field_name
+    )


 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -868,6 +1015,9 @@ def project_table(
     projected_schema: Schema,
     case_sensitive: bool = True,
     limit: Optional[int] = None,
+    *,
+    match_with_field_name: bool = False,
+    ignore_unprojectable_fields: bool = False,
 ) -> pa.Table:
     """Resolve the right columns based on the identifier.

@@ -918,6 +1068,8 @@ def project_table(
                case_sensitive,
                row_counts,
                limit,
+               match_with_field_name=match_with_field_name,
+               ignore_unprojectable_fields=ignore_unprojectable_fields,
            )
            for task in tasks
        ]
@@ -949,8 +1101,12 @@ def project_table(
     return result


-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
-    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def to_requested_schema(
+    requested_schema: Schema, file_schema: Schema, table: pa.Table, *, match_with_field_name: bool = False
+) -> pa.Table:
+    struct_array = visit_with_partner(
+        requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema, match_with_field_name)
+    )

     arrays = []
     fields = []
@@ -1027,19 +1183,26 @@ def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa.
 class ArrowAccessor(PartnerAccessor[pa.Array]):
     file_schema: Schema

-    def __init__(self, file_schema: Schema):
+    def __init__(self, file_schema: Schema, match_with_field_name: bool = False):
         self.file_schema = file_schema
+        self.match_with_field_name = match_with_field_name

     def schema_partner(self, partner: Optional[pa.Array]) -> Optional[pa.Array]:
         return partner

-    def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]:
+    def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, field_name: str) -> Optional[pa.Array]:
         if partner_struct:
             # use the field name from the file schema
             try:
                 name = self.file_schema.find_field(field_id).name
             except ValueError:
-                return None
+                if self.match_with_field_name:
+                    try:
+                        name = self.file_schema.find_field(field_name).name
+                    except ValueError:
+                        return None
+                else:
+                    return None

             if isinstance(partner_struct, pa.StructArray):
                 return partner_struct.field(name)
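The fallback in field_partner relies on Schema.find_field accepting either a field id or a field name. A quick illustration with a one-column schema:

from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType

file_schema = Schema(NestedField(1, "city", StringType(), required=False))

assert file_schema.find_field(1).name == "city"  # lookup by field id
assert file_schema.find_field("city").field_id == 1  # lookup by name, the new fallback
# A miss raises ValueError, which field_partner() above converts into `return None`.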
16 changes: 9 additions & 7 deletions python/pyiceberg/table/__init__.py
@@ -860,7 +860,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
             for data_entry in data_entries
         ]

-    def to_arrow(self) -> pa.Table:
+    def to_arrow(self, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False) -> pa.Table:
         from pyiceberg.io.pyarrow import project_table

         return project_table(
@@ -870,23 +870,25 @@ def to_arrow(self) -> pa.Table:
             self.projection(),
             case_sensitive=self.case_sensitive,
             limit=self.limit,
+            match_with_field_name=match_with_field_name,
+            ignore_unprojectable_fields=ignore_unprojectable_fields,
         )

-    def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
-        return self.to_arrow().to_pandas(**kwargs)
+    def to_pandas(self, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False, **kwargs: Any) -> pd.DataFrame:
+        return self.to_arrow(match_with_field_name=match_with_field_name, ignore_unprojectable_fields=ignore_unprojectable_fields).to_pandas(**kwargs)

-    def to_duckdb(self, table_name: str, connection: Optional[DuckDBPyConnection] = None) -> DuckDBPyConnection:
+    def to_duckdb(self, table_name: str, connection: Optional[DuckDBPyConnection] = None, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False) -> DuckDBPyConnection:
         import duckdb

         con = connection or duckdb.connect(database=":memory:")
-        con.register(table_name, self.to_arrow())
+        con.register(table_name, self.to_arrow(match_with_field_name=match_with_field_name, ignore_unprojectable_fields=ignore_unprojectable_fields))

         return con

-    def to_ray(self) -> ray.data.dataset.Dataset:
+    def to_ray(self, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False) -> ray.data.dataset.Dataset:
         import ray

-        return ray.data.from_arrow(self.to_arrow())
+        return ray.data.from_arrow(self.to_arrow(match_with_field_name=match_with_field_name, ignore_unprojectable_fields=ignore_unprojectable_fields))


class UpdateSchema:
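Taken together, the new keyword-only flags surface all the way up to the scan API. A hypothetical end-to-end read of a table whose data files were written without field ids; the catalog and table names below are made up:

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")
table = catalog.load_table("examples.events")

# Match file columns to the projected Iceberg schema by name, and silently
# drop file columns that have no counterpart in the projection.
df = table.scan(selected_fields=("id", "name")).to_pandas(
    match_with_field_name=True,
    ignore_unprojectable_fields=True,
)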