Python: allow projection of Iceberg fields to pyarrow table schema with field names when field ids are not available in data files.

moriyoshi committed Aug 29, 2023
1 parent eb0a535 commit 9eeacbf
Showing 3 changed files with 251 additions and 34 deletions.
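The change is opt-in through new keyword-only flags threaded from DataScan (to_arrow, to_pandas, to_duckdb, to_ray) down into the PyArrow projection. A minimal usage sketch, assuming a catalog named "default" and a table "db.events" whose data files were written without Iceberg field-id metadata (the catalog and table names are illustrative):

from pyiceberg.catalog import load_catalog

# Hypothetical catalog/table names; any table whose data files lack
# field-id metadata (e.g. Parquet written by plain pyarrow) applies.
catalog = load_catalog("default")
table = catalog.load_table("db.events")

# Fall back to matching file columns against the projected Iceberg schema
# by name when field ids are absent; silently drop columns that have no
# counterpart in the projected schema.
arrow_table = table.scan().to_arrow(
    match_with_field_name=True,
    ignore_unprojectable_fields=True,
)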
209 changes: 182 additions & 27 deletions python/pyiceberg/io/pyarrow.py
@@ -39,6 +39,7 @@
Iterable,
Iterator,
List,
Literal as _Literal,
Optional,
Set,
Tuple,
@@ -575,9 +576,16 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:
return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)


def pyarrow_to_schema(schema: pa.Schema) -> Schema:
visitor = _ConvertToIceberg()
return visit_pyarrow(schema, visitor)
def pyarrow_to_schema(
schema: pa.Schema,
projected_schema: Optional[Schema] = None,
match_with_field_name: bool = False,
ignore_unprojectable_fields: bool = False,
) -> Schema:
visitor = _ConvertToIceberg(projected_schema, match_with_field_name, ignore_unprojectable_fields)
ib_schema = visit_pyarrow(schema, visitor)
assert isinstance(ib_schema, StructType)
return Schema(*ib_schema.fields)


@singledispatch
@@ -675,41 +683,133 @@ def primitive(self, primitive: pa.DataType) -> Optional[T]:


def _get_field_id(field: pa.Field) -> Optional[int]:
for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
if field_id_str := field.metadata.get(pyarrow_field_id_key):
return int(field_id_str.decode())
if field.metadata is not None:
for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
if field_id_str := field.metadata.get(pyarrow_field_id_key):
return int(field_id_str.decode())
return None


def _get_field_doc(field: pa.Field) -> Optional[str]:
for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
if doc_str := field.metadata.get(pyarrow_doc_key):
return doc_str.decode()
if field.metadata is not None:
for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
if doc_str := field.metadata.get(pyarrow_doc_key):
return doc_str.decode()
return None


class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
class _ConvertToIceberg(PyArrowSchemaVisitor[IcebergType]):
projected_schema: Union[Schema, ListType, MapType, None]
match_with_field_name: bool
ignore_unprojectable_fields: bool
projected_schema_stack: List[Tuple[Union[Schema, ListType, MapType, None], Optional[_Literal["key", "value"]]]]
next: Optional[_Literal["key", "value"]]

def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
fields = []
for i, field in enumerate(arrow_fields):
field_id = _get_field_id(field)
field_doc = _get_field_doc(field)
field_type = field_results[i]
if field_type is not None and field_id is not None:
fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc))
ib_field: Optional[NestedField] = None
if field_type is not None:
if field_id is not None:
ib_field = NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc)
elif self.match_with_field_name:
if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields:
continue
if not isinstance(self.projected_schema, Schema):
raise ValueError("projected_schema must be provided if match_with_field_name is set to True")
try:
projected_field = self.projected_schema.find_field(field.name)
except ValueError as e:
if self.ignore_unprojectable_fields:
continue
raise ValueError(
f"could not find a field that corresponds to {field.name} in projected schema {self.projected_schema}"
) from e
ib_field = NestedField(
projected_field.field_id, field.name, field_type, required=not field.nullable, doc=field_doc
)
if ib_field is not None:
fields.append(ib_field)
return fields

def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> Schema:
return Schema(*self._convert_fields(schema, field_results))
def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> StructType:
return StructType(*self._convert_fields(schema, field_results))

def before_field(self, field: pa.Field) -> None:
if not isinstance(field.type, (pa.StructType, pa.ListType, pa.MapType)):
return

projected_field: Optional[NestedField] = None

if isinstance(self.projected_schema, Schema):
field_id = _get_field_id(field)
if field_id is not None:
try:
projected_field = self.projected_schema.find_field(field_id)
except ValueError:
if not self.match_with_field_name:
raise
if projected_field is None and self.match_with_field_name:
try:
projected_field = self.projected_schema.find_field(field.name)
except ValueError:
if not self.ignore_unprojectable_fields:
raise
elif isinstance(self.projected_schema, ListType):
projected_field = self.projected_schema.element_field
elif isinstance(self.projected_schema, MapType):
if self.next == "key":
projected_field = self.projected_schema.key_field
elif self.next == "value":
projected_field = self.projected_schema.value_field
else:
raise AssertionError("should never get here")

self.projected_schema_stack.append((self.projected_schema, self.next))
inner_schema: Optional[Union[Schema, ListType, MapType]] = None
if projected_field is not None:
field_type = projected_field.field_type
if isinstance(field_type, StructType):
inner_schema = Schema(*field_type.fields)
else:
if isinstance(field_type, (ListType, MapType)):
inner_schema = field_type

self.projected_schema = inner_schema
self.next = "key" if isinstance(field.type, pa.MapType) else None

def after_field(self, field: pa.Field) -> None:
if self.next == "key":
self.next = "value"
elif self.next == "value":
self.next = None
if not isinstance(field.type, (pa.StructType, pa.ListType, pa.MapType)):
return
(self.projected_schema, self.next) = self.projected_schema_stack.pop()

def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType]]) -> IcebergType:
return StructType(*self._convert_fields(struct, field_results))

def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
element_field = list_type.value_field
element_id = _get_field_id(element_field)
if element_result is not None and element_id is not None:
return ListType(element_id, element_result, element_required=not element_field.nullable)
if element_result is not None:
ib_type: ListType
if element_id is not None:
ib_type = ListType(element_id, element_result, element_required=not element_field.nullable)
elif self.match_with_field_name:
if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields:
return None
if not isinstance(self.projected_schema, ListType):
raise ValueError("projected_schema must be provided if match_with_field_name is set to True")
ib_type = ListType(self.projected_schema.element_id, element_result, element_required=not element_field.nullable)
else:
raise ValueError("match_with_field_name is set to False and elemnt_id is unknown")
return ib_type

return None

def map(
@@ -719,8 +819,26 @@ def map(
key_id = _get_field_id(key_field)
value_field = map_type.item_field
value_id = _get_field_id(value_field)
if key_result is not None and value_result is not None and key_id is not None and value_id is not None:
return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
if key_result is not None and value_result is not None:
ib_type: MapType
if key_id is not None and value_id is not None:
ib_type = MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
elif self.match_with_field_name:
if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields:
return None
if not isinstance(self.projected_schema, MapType):
raise ValueError("projected_schema must be provided if match_with_field_name is set to True")
ib_type = MapType(
self.projected_schema.key_id,
key_result,
self.projected_schema.value_id,
value_result,
value_required=not value_field.nullable,
)
else:
raise ValueError("match_with_field_name is set to False and either key_id or value_id is unknown")
return ib_type

return None

def primitive(self, primitive: pa.DataType) -> IcebergType:
@@ -758,6 +876,15 @@ def primitive(self, primitive: pa.DataType) -> IcebergType:

raise TypeError(f"Unsupported type: {primitive}")

def __init__(
self, projected_schema: Optional[Schema] = None, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False
) -> None:
self.projected_schema = projected_schema
self.match_with_field_name = match_with_field_name
self.ignore_unprojectable_fields = ignore_unprojectable_fields
self.projected_schema_stack = []
self.next = None


def _task_to_table(
fs: FileSystem,
Expand All @@ -769,6 +896,9 @@ def _task_to_table(
case_sensitive: bool,
row_counts: List[int],
limit: Optional[int] = None,
*,
match_with_field_name: bool = False,
ignore_unprojectable_fields: bool = False,
) -> Optional[pa.Table]:
if limit and sum(row_counts) >= limit:
return None
Expand All @@ -781,9 +911,16 @@ def _task_to_table(
schema_raw = None
if metadata := physical_schema.metadata:
schema_raw = metadata.get(ICEBERG_SCHEMA)
# TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema,
# see https://github.com/apache/iceberg/issues/7451
file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)
file_schema = (
Schema.model_validate_json(schema_raw)
if schema_raw is not None
else pyarrow_to_schema(
physical_schema,
projected_schema,
match_with_field_name=match_with_field_name,
ignore_unprojectable_fields=ignore_unprojectable_fields,
)
)

pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
@@ -840,7 +977,9 @@ def _task_to_table(

row_counts.append(len(arrow_table))

return to_requested_schema(projected_schema, file_project_schema, arrow_table)
return to_requested_schema(
projected_schema, file_project_schema, arrow_table, match_with_field_name=match_with_field_name
)


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -868,6 +1007,9 @@ def project_table(
projected_schema: Schema,
case_sensitive: bool = True,
limit: Optional[int] = None,
*,
match_with_field_name: bool = False,
ignore_unprojectable_fields: bool = False,
) -> pa.Table:
"""Resolve the right columns based on the identifier.
@@ -918,6 +1060,8 @@
case_sensitive,
row_counts,
limit,
match_with_field_name=match_with_field_name,
ignore_unprojectable_fields=ignore_unprojectable_fields,
)
for task in tasks
]
@@ -949,8 +1093,12 @@ def project_table(
return result


def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
def to_requested_schema(
requested_schema: Schema, file_schema: Schema, table: pa.Table, *, match_with_field_name: bool = False
) -> pa.Table:
struct_array = visit_with_partner(
requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema, match_with_field_name)
)

arrays = []
fields = []
@@ -1027,19 +1175,26 @@ def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa.
class ArrowAccessor(PartnerAccessor[pa.Array]):
file_schema: Schema

def __init__(self, file_schema: Schema):
def __init__(self, file_schema: Schema, match_with_field_name: bool = False):
self.file_schema = file_schema
self.match_with_field_name = match_with_field_name

def schema_partner(self, partner: Optional[pa.Array]) -> Optional[pa.Array]:
return partner

def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]:
def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, field_name: str) -> Optional[pa.Array]:
if partner_struct:
# use the field name from the file schema
try:
name = self.file_schema.find_field(field_id).name
except ValueError:
return None
if self.match_with_field_name:
try:
name = self.file_schema.find_field(field_name).name
except ValueError:
return None
else:
return None

if isinstance(partner_struct, pa.StructArray):
return partner_struct.field(name)
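To illustrate the name-matching path added to pyarrow_to_schema in this file, a sketch assuming a file schema that carries no field-id metadata; the field names and ids here are illustrative:

import pyarrow as pa

from pyiceberg.io.pyarrow import pyarrow_to_schema
from pyiceberg.schema import Schema
from pyiceberg.types import LongType, NestedField, StringType

# A physical schema written without Iceberg field-id metadata.
physical_schema = pa.schema([
    pa.field("id", pa.int64(), nullable=False),
    pa.field("name", pa.string()),
])

# The projected table schema carries the authoritative field ids.
projected = Schema(
    NestedField(1, "id", LongType(), required=True),
    NestedField(2, "name", StringType(), required=False),
)

# Ids 1 and 2 are recovered by matching on the names "id" and "name".
iceberg_schema = pyarrow_to_schema(physical_schema, projected, match_with_field_name=True)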
16 changes: 9 additions & 7 deletions python/pyiceberg/table/__init__.py
@@ -860,7 +860,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
for data_entry in data_entries
]

def to_arrow(self) -> pa.Table:
def to_arrow(self, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False) -> pa.Table:
from pyiceberg.io.pyarrow import project_table

return project_table(
Expand All @@ -870,23 +870,25 @@ def to_arrow(self) -> pa.Table:
self.projection(),
case_sensitive=self.case_sensitive,
limit=self.limit,
match_with_field_name=match_with_field_name,
ignore_unprojectable_fields=ignore_unprojectable_fields,
)

def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
return self.to_arrow().to_pandas(**kwargs)
def to_pandas(self, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False, **kwargs: Any) -> pd.DataFrame:
return self.to_arrow(match_with_field_name=match_with_field_name, ignore_unprojectable_fields=ignore_unprojectable_fields).to_pandas(**kwargs)

def to_duckdb(self, table_name: str, connection: Optional[DuckDBPyConnection] = None) -> DuckDBPyConnection:
def to_duckdb(self, table_name: str, connection: Optional[DuckDBPyConnection] = None, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False) -> DuckDBPyConnection:
import duckdb

con = connection or duckdb.connect(database=":memory:")
con.register(table_name, self.to_arrow())
con.register(table_name, self.to_arrow(match_with_field_name=match_with_field_name, ignore_unprojectable_fields=ignore_unprojectable_fields))

return con

def to_ray(self) -> ray.data.dataset.Dataset:
def to_ray(self, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False) -> ray.data.dataset.Dataset:
import ray

return ray.data.from_arrow(self.to_arrow())
return ray.data.from_arrow(self.to_arrow(match_with_field_name=match_with_field_name, ignore_unprojectable_fields=ignore_unprojectable_fields))


class UpdateSchema:
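A short sketch of the widened DuckDB path, reusing the table object from the first example; the view name is illustrative:

import duckdb

# Register the scan result as a DuckDB view, matching columns by name
# where field ids are missing from the data files.
con = table.scan().to_duckdb("events", match_with_field_name=True)
print(con.execute("SELECT count(*) FROM events").fetchall())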
