-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Python: allow projection of Iceberg fields to pyarrow table schema with names #8144
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ | |
Iterable, | ||
Iterator, | ||
List, | ||
Literal as _Literal, | ||
Optional, | ||
Set, | ||
Tuple, | ||
|
@@ -575,9 +576,16 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: | |
return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False) | ||
|
||
|
||
def pyarrow_to_schema(schema: pa.Schema) -> Schema: | ||
visitor = _ConvertToIceberg() | ||
return visit_pyarrow(schema, visitor) | ||
def pyarrow_to_schema( | ||
schema: pa.Schema, | ||
projected_schema: Optional[Schema] = None, | ||
match_with_field_name: bool = False, | ||
ignore_unprojectable_fields: bool = False, | ||
) -> Schema: | ||
visitor = _ConvertToIceberg(projected_schema, match_with_field_name, ignore_unprojectable_fields) | ||
ib_schema = visit_pyarrow(schema, visitor) | ||
assert isinstance(ib_schema, StructType) | ||
return Schema(*ib_schema.fields) | ||
|
||
|
||
@singledispatch | ||
|
@@ -675,41 +683,141 @@ def primitive(self, primitive: pa.DataType) -> Optional[T]: | |
|
||
|
||
def _get_field_id(field: pa.Field) -> Optional[int]: | ||
for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS: | ||
if field_id_str := field.metadata.get(pyarrow_field_id_key): | ||
return int(field_id_str.decode()) | ||
if field.metadata is not None: | ||
for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS: | ||
if field_id_str := field.metadata.get(pyarrow_field_id_key): | ||
return int(field_id_str.decode()) | ||
return None | ||
|
||
|
||
def _get_field_doc(field: pa.Field) -> Optional[str]: | ||
for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS: | ||
if doc_str := field.metadata.get(pyarrow_doc_key): | ||
return doc_str.decode() | ||
if field.metadata is not None: | ||
for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS: | ||
if doc_str := field.metadata.get(pyarrow_doc_key): | ||
return doc_str.decode() | ||
return None | ||
|
||
|
||
class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]): | ||
class _ConvertToIceberg(PyArrowSchemaVisitor[IcebergType]): | ||
projected_schema: Union[Schema, ListType, MapType, None] | ||
match_with_field_name: bool | ||
ignore_unprojectable_fields: bool | ||
projected_schema_stack: List[Tuple[Schema | ListType | MapType | None, Optional[_Literal["key", "value"]]]] | ||
next: Optional[_Literal["key", "value"]] | ||
|
||
def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]: | ||
fields = [] | ||
for i, field in enumerate(arrow_fields): | ||
field_id = _get_field_id(field) | ||
field_doc = _get_field_doc(field) | ||
field_type = field_results[i] | ||
if field_type is not None and field_id is not None: | ||
fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc)) | ||
ib_field: Optional[NestedField] = None | ||
if field_type is not None: | ||
if field_id is not None: | ||
ib_field = NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc) | ||
elif self.match_with_field_name: | ||
if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields: | ||
continue | ||
if not isinstance(self.projected_schema, Schema): | ||
raise ValueError("projected_schema must be provided if match_with_field_name is set to True") | ||
try: | ||
projected_field = self.projected_schema.find_field(field.name) | ||
except ValueError as e: | ||
if self.ignore_unprojectable_fields: | ||
continue | ||
raise ValueError( | ||
f"could not find a field that corresponds to {field.name} in projected schema {self.projected_schema}" | ||
) from e | ||
ib_field = NestedField( | ||
projected_field.field_id, field.name, field_type, required=not field.nullable, doc=field_doc | ||
) | ||
if ib_field is not None: | ||
fields.append(ib_field) | ||
return fields | ||
|
||
def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> Schema: | ||
return Schema(*self._convert_fields(schema, field_results)) | ||
def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> StructType: | ||
return StructType(*self._convert_fields(schema, field_results)) | ||
|
||
def before_field(self, field: pa.Field) -> None: | ||
if not isinstance(field.type, (pa.StructType, pa.ListType, pa.MapType)): | ||
return | ||
|
||
self.projected_schema_stack.append((self.projected_schema, self.next)) | ||
|
||
projected_field: Optional[NestedField] = None | ||
|
||
if isinstance(self.projected_schema, Schema): | ||
field_id = _get_field_id(field) | ||
if field_id is not None: | ||
try: | ||
projected_field = self.projected_schema.find_field(field_id) | ||
except ValueError: | ||
if not self.match_with_field_name: | ||
raise | ||
if projected_field is None and self.match_with_field_name: | ||
try: | ||
projected_field = self.projected_schema.find_field(field.name) | ||
except ValueError: | ||
if not self.ignore_unprojectable_fields: | ||
raise | ||
elif isinstance(self.projected_schema, ListType): | ||
projected_field = self.projected_schema.element_field | ||
elif isinstance(self.projected_schema, MapType): | ||
if self.next == "key": | ||
projected_field = self.projected_schema.key_field | ||
elif self.next == "value": | ||
projected_field = self.projected_schema.value_field | ||
else: | ||
raise AssertionError("should never get here") | ||
|
||
inner_schema: Schema | ListType | MapType | None = None | ||
next_: Optional[str] = None | ||
|
||
if projected_field is not None: | ||
field_type = projected_field.field_type | ||
if isinstance(field_type, StructType): | ||
inner_schema = Schema(*field_type.fields) | ||
else: | ||
if isinstance(field_type, ListType): | ||
inner_schema = field_type | ||
elif isinstance(field_type, MapType): | ||
inner_schema = field_type | ||
next_ = "key" | ||
else: | ||
if isinstance(field.type, pa.MapType): | ||
next_ = "key" | ||
|
||
self.projected_schema = inner_schema | ||
self.next = next_ | ||
|
||
def after_field(self, field: pa.Field) -> None: | ||
if isinstance(field.type, (pa.StructType, pa.ListType, pa.MapType)): | ||
(self.projected_schema, self.next) = self.projected_schema_stack.pop() | ||
if self.next == "key": | ||
self.next = "value" | ||
elif self.next == "value": | ||
self.next = None | ||
|
||
def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType]]) -> IcebergType: | ||
return StructType(*self._convert_fields(struct, field_results)) | ||
|
||
def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]: | ||
element_field = list_type.value_field | ||
element_id = _get_field_id(element_field) | ||
if element_result is not None and element_id is not None: | ||
return ListType(element_id, element_result, element_required=not element_field.nullable) | ||
if element_result is not None: | ||
ib_type: ListType | ||
if element_id is not None: | ||
ib_type = ListType(element_id, element_result, element_required=not element_field.nullable) | ||
elif self.match_with_field_name: | ||
if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields: | ||
return None | ||
if not isinstance(self.projected_schema, ListType): | ||
raise ValueError("projected_schema must be provided if match_with_field_name is set to True") | ||
ib_type = ListType(self.projected_schema.element_id, element_result, element_required=not element_field.nullable) | ||
else: | ||
raise ValueError("match_with_field_name is set to False and elemnt_id is unknown") | ||
return ib_type | ||
|
||
return None | ||
|
||
def map( | ||
|
@@ -719,8 +827,26 @@ def map( | |
key_id = _get_field_id(key_field) | ||
value_field = map_type.item_field | ||
value_id = _get_field_id(value_field) | ||
if key_result is not None and value_result is not None and key_id is not None and value_id is not None: | ||
return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable) | ||
if key_result is not None and value_result is not None: | ||
ib_type: MapType | ||
if key_id is not None and value_id is not None: | ||
ib_type = MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable) | ||
elif self.match_with_field_name: | ||
if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields: | ||
return None | ||
if not isinstance(self.projected_schema, MapType): | ||
raise ValueError("projected_schema must be provided if match_with_field_name is set to True") | ||
ib_type = MapType( | ||
self.projected_schema.key_id, | ||
key_result, | ||
self.projected_schema.value_id, | ||
value_result, | ||
value_required=not value_field.nullable, | ||
) | ||
else: | ||
raise ValueError("match_with_field_name is set to False and either key_id or value_id is unknown") | ||
return ib_type | ||
|
||
return None | ||
|
||
def primitive(self, primitive: pa.DataType) -> IcebergType: | ||
|
@@ -758,6 +884,15 @@ def primitive(self, primitive: pa.DataType) -> IcebergType: | |
|
||
raise TypeError(f"Unsupported type: {primitive}") | ||
|
||
def __init__( | ||
self, projected_schema: Optional[Schema] = None, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False | ||
) -> None: | ||
self.projected_schema = projected_schema | ||
self.match_with_field_name = match_with_field_name | ||
self.ignore_unprojectable_fields = ignore_unprojectable_fields | ||
self.projected_schema_stack = [] | ||
self.next = None | ||
|
||
|
||
def _task_to_table( | ||
fs: FileSystem, | ||
|
@@ -769,6 +904,9 @@ def _task_to_table( | |
case_sensitive: bool, | ||
row_counts: List[int], | ||
limit: Optional[int] = None, | ||
*, | ||
match_with_field_name: bool = False, | ||
ignore_unprojectable_fields: bool = False, | ||
) -> Optional[pa.Table]: | ||
if limit and sum(row_counts) >= limit: | ||
return None | ||
|
@@ -781,9 +919,16 @@ def _task_to_table( | |
schema_raw = None | ||
if metadata := physical_schema.metadata: | ||
schema_raw = metadata.get(ICEBERG_SCHEMA) | ||
# TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema, | ||
# see https://github.com/apache/iceberg/issues/7451 | ||
file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema) | ||
file_schema = ( | ||
Schema.model_validate_json(schema_raw) | ||
if schema_raw is not None | ||
else pyarrow_to_schema( | ||
physical_schema, | ||
projected_schema, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we may want the full table schema here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It doesn't have much to do with the pruning. What we need to achieve with |
||
match_with_field_name=match_with_field_name, | ||
ignore_unprojectable_fields=ignore_unprojectable_fields, | ||
) | ||
) | ||
|
||
pyarrow_filter = None | ||
if bound_row_filter is not AlwaysTrue(): | ||
|
@@ -840,7 +985,9 @@ def _task_to_table( | |
|
||
row_counts.append(len(arrow_table)) | ||
|
||
return to_requested_schema(projected_schema, file_project_schema, arrow_table) | ||
return to_requested_schema( | ||
projected_schema, file_project_schema, arrow_table, match_with_field_name=match_with_field_name | ||
) | ||
|
||
|
||
def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: | ||
|
@@ -868,6 +1015,9 @@ def project_table( | |
projected_schema: Schema, | ||
case_sensitive: bool = True, | ||
limit: Optional[int] = None, | ||
*, | ||
match_with_field_name: bool = False, | ||
ignore_unprojectable_fields: bool = False, | ||
) -> pa.Table: | ||
"""Resolve the right columns based on the identifier. | ||
|
||
|
@@ -918,6 +1068,8 @@ def project_table( | |
case_sensitive, | ||
row_counts, | ||
limit, | ||
match_with_field_name=match_with_field_name, | ||
ignore_unprojectable_fields=ignore_unprojectable_fields | ||
) | ||
for task in tasks | ||
] | ||
|
@@ -949,8 +1101,12 @@ def project_table( | |
return result | ||
|
||
|
||
def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table: | ||
struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) | ||
def to_requested_schema( | ||
requested_schema: Schema, file_schema: Schema, table: pa.Table, *, match_with_field_name: bool = False | ||
) -> pa.Table: | ||
struct_array = visit_with_partner( | ||
requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema, match_with_field_name) | ||
) | ||
|
||
arrays = [] | ||
fields = [] | ||
|
@@ -1027,19 +1183,26 @@ def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa. | |
class ArrowAccessor(PartnerAccessor[pa.Array]): | ||
file_schema: Schema | ||
|
||
def __init__(self, file_schema: Schema): | ||
def __init__(self, file_schema: Schema, match_with_field_name: bool = False): | ||
self.file_schema = file_schema | ||
self.match_with_field_name = match_with_field_name | ||
|
||
def schema_partner(self, partner: Optional[pa.Array]) -> Optional[pa.Array]: | ||
return partner | ||
|
||
def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]: | ||
def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, field_name: str) -> Optional[pa.Array]: | ||
if partner_struct: | ||
# use the field name from the file schema | ||
try: | ||
name = self.file_schema.find_field(field_id).name | ||
except ValueError: | ||
return None | ||
if self.match_with_field_name: | ||
try: | ||
name = self.file_schema.find_field(field_name).name | ||
except ValueError: | ||
return None | ||
else: | ||
return None | ||
|
||
if isinstance(partner_struct, pa.StructArray): | ||
return partner_struct.field(name) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please replace it with a `ValueError` or something similar? According to comments in other PRs, we try to avoid `assert` outside `tests/`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not intended to warn the user about incorrect usage. It is a type guard for mypy, and I believe it's valid.