diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
index 83bd79d2fee0..9d6a11a9e09f 100644
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ -39,6 +39,7 @@
     Iterable,
     Iterator,
     List,
+    Literal as _Literal,
     Optional,
     Set,
     Tuple,
@@ -575,9 +576,27 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:
     return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
 
 
-def pyarrow_to_schema(schema: pa.Schema) -> Schema:
-    visitor = _ConvertToIceberg()
-    return visit_pyarrow(schema, visitor)
+def pyarrow_to_schema(
+    schema: pa.Schema,
+    projected_schema: Optional[Schema] = None,
+    match_with_field_name: bool = False,
+    ignore_unprojectable_fields: bool = False,
+) -> Schema:
+    """Convert a PyArrow schema into an Iceberg schema.
+
+    Args:
+        schema: The PyArrow schema to convert.
+        projected_schema: Iceberg schema used to resolve ids of fields that carry no field-id metadata.
+        match_with_field_name: Look up missing field ids by field name in `projected_schema`.
+        ignore_unprojectable_fields: Skip fields that cannot be found in `projected_schema` instead of raising.
+
+    Returns:
+        An Iceberg schema built from the converted top-level fields.
+    """
+    visitor = _ConvertToIceberg(projected_schema, match_with_field_name, ignore_unprojectable_fields)
+    ib_schema = visit_pyarrow(schema, visitor)
+    assert isinstance(ib_schema, StructType)
+    return Schema(*ib_schema.fields)
 
 
 @singledispatch
@@ -675,32 +694,124 @@ def primitive(self, primitive: pa.DataType) -> Optional[T]:
 
 
 def _get_field_id(field: pa.Field) -> Optional[int]:
-    for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
-        if field_id_str := field.metadata.get(pyarrow_field_id_key):
-            return int(field_id_str.decode())
+    if field.metadata is not None:
+        for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
+            if field_id_str := field.metadata.get(pyarrow_field_id_key):
+                return int(field_id_str.decode())
     return None
 
 
 def _get_field_doc(field: pa.Field) -> Optional[str]:
-    for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
-        if doc_str := field.metadata.get(pyarrow_doc_key):
-            return doc_str.decode()
+    if field.metadata is not None:
+        for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
+            if doc_str := field.metadata.get(pyarrow_doc_key):
+                return doc_str.decode()
     return None
 
 
-class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
+class _ConvertToIceberg(PyArrowSchemaVisitor[IcebergType]):
+    """Convert PyArrow types to Iceberg types, optionally resolving missing field ids from a projected schema."""
+    projected_schema: Union[Schema, ListType, MapType, None]
+    match_with_field_name: bool
+    ignore_unprojectable_fields: bool
+    projected_schema_stack: List[Tuple[Union[Schema, ListType, MapType, None], Optional[_Literal["key", "value"]]]]
+    next: Optional[_Literal["key", "value"]]
+
     def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
         fields = []
         for i, field in enumerate(arrow_fields):
             field_id = _get_field_id(field)
             field_doc = _get_field_doc(field)
             field_type = field_results[i]
-            if field_type is not None and field_id is not None:
-                fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc))
+            ib_field: Optional[NestedField] = None
+            if field_type is not None:
+                if field_id is not None:
+                    ib_field = NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc)
+                elif self.match_with_field_name:
+                    # Inside a nested type that has no projected counterpart: drop the field when allowed to.
+                    if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields:
+                        continue
+                    if not isinstance(self.projected_schema, Schema):
+                        raise ValueError("projected_schema must be provided if match_with_field_name is set to True")
+                    try:
+                        projected_field = self.projected_schema.find_field(field.name)
+                    except ValueError as e:
+                        if self.ignore_unprojectable_fields:
+                            continue
+                        raise ValueError(
+                            f"could not find a field that corresponds to {field.name} in projected schema {self.projected_schema}"
+                        ) from e
+                    ib_field = NestedField(
+                        projected_field.field_id, field.name, field_type, required=not field.nullable, doc=field_doc
+                    )
+            if ib_field is not None:
+                fields.append(ib_field)
         return fields
 
-    def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> Schema:
-        return Schema(*self._convert_fields(schema, field_results))
+    def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> StructType:
+        return StructType(*self._convert_fields(schema, field_results))
+
+    def before_field(self, field: pa.Field) -> None:
+        """Save the current projection context and switch to the projected counterpart of a nested field."""
+        if not isinstance(field.type, (pa.StructType, pa.ListType, pa.MapType)):
+            return
+
+        self.projected_schema_stack.append((self.projected_schema, self.next))
+
+        projected_field: Optional[NestedField] = None
+
+        if isinstance(self.projected_schema, Schema):
+            field_id = _get_field_id(field)
+            if field_id is not None:
+                try:
+                    projected_field = self.projected_schema.find_field(field_id)
+                except ValueError:
+                    if not self.match_with_field_name:
+                        raise
+            if projected_field is None and self.match_with_field_name:
+                try:
+                    projected_field = self.projected_schema.find_field(field.name)
+                except ValueError:
+                    if not self.ignore_unprojectable_fields:
+                        raise
+        elif isinstance(self.projected_schema, ListType):
+            projected_field = self.projected_schema.element_field
+        elif isinstance(self.projected_schema, MapType):
+            if self.next == "key":
+                projected_field = self.projected_schema.key_field
+            elif self.next == "value":
+                projected_field = self.projected_schema.value_field
+            else:
+                raise AssertionError("should never get here")
+
+        inner_schema: Union[Schema, ListType, MapType, None] = None
+        next_: Optional[_Literal["key", "value"]] = None
+
+        if projected_field is not None:
+            field_type = projected_field.field_type
+            if isinstance(field_type, StructType):
+                inner_schema = Schema(*field_type.fields)
+            else:
+                if isinstance(field_type, ListType):
+                    inner_schema = field_type
+                elif isinstance(field_type, MapType):
+                    inner_schema = field_type
+                    next_ = "key"
+        else:
+            if isinstance(field.type, pa.MapType):
+                next_ = "key"
+
+        self.projected_schema = inner_schema
+        self.next = next_
+
+    def after_field(self, field: pa.Field) -> None:
+        """Restore the saved projection context for nested fields and advance the map key/value cursor."""
+        if isinstance(field.type, (pa.StructType, pa.ListType, pa.MapType)):
+            (self.projected_schema, self.next) = self.projected_schema_stack.pop()
+        if self.next == "key":
+            self.next = "value"
+        elif self.next == "value":
+            self.next = None
 
     def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType]]) -> IcebergType:
         return StructType(*self._convert_fields(struct, field_results))
@@ -708,8 +819,20 @@ def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType
     def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
         element_field = list_type.value_field
         element_id = _get_field_id(element_field)
-        if element_result is not None and element_id is not None:
-            return ListType(element_id, element_result, element_required=not element_field.nullable)
+        if element_result is not None:
+            ib_type: ListType
+            if element_id is not None:
+                ib_type = ListType(element_id, element_result, element_required=not element_field.nullable)
+            elif self.match_with_field_name:
+                if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields:
+                    return None
+                if not isinstance(self.projected_schema, ListType):
+                    raise ValueError("projected_schema must be provided if match_with_field_name is set to True")
+                ib_type = ListType(self.projected_schema.element_id, element_result, element_required=not element_field.nullable)
+            else:
+                raise ValueError("match_with_field_name is set to False and element_id is unknown")
+            return ib_type
+
         return None
 
     def map(
@@ -719,8 +842,26 @@ def map(
         key_id = _get_field_id(key_field)
         value_field = map_type.item_field
         value_id = _get_field_id(value_field)
-        if key_result is not None and value_result is not None and key_id is not None and value_id is not None:
-            return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
+        if key_result is not None and value_result is not None:
+            ib_type: MapType
+            if key_id is not None and value_id is not None:
+                ib_type = MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
+            elif self.match_with_field_name:
+                if self.projected_schema is None and self.projected_schema_stack and self.ignore_unprojectable_fields:
+                    return None
+                if not isinstance(self.projected_schema, MapType):
+                    raise ValueError("projected_schema must be provided if match_with_field_name is set to True")
+                ib_type = MapType(
+                    self.projected_schema.key_id,
+                    key_result,
+                    self.projected_schema.value_id,
+                    value_result,
+                    value_required=not value_field.nullable,
+                )
+            else:
+                raise ValueError("match_with_field_name is set to False and either key_id or value_id is unknown")
+            return ib_type
+
         return None
 
     def primitive(self, primitive: pa.DataType) -> IcebergType:
@@ -758,6 +899,16 @@ def primitive(self, primitive: pa.DataType) -> IcebergType:
 
         raise TypeError(f"Unsupported type: {primitive}")
 
+    def __init__(
+        self, projected_schema: Optional[Schema] = None, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False
+    ) -> None:
+        """Store the projection options and start with an empty projection-context stack."""
+        self.projected_schema = projected_schema
+        self.match_with_field_name = match_with_field_name
+        self.ignore_unprojectable_fields = ignore_unprojectable_fields
+        self.projected_schema_stack = []
+        self.next = None
+
 
 def _task_to_table(
     fs: FileSystem,
@@ -769,6 +920,9 @@ def _task_to_table(
     case_sensitive: bool,
     row_counts: List[int],
     limit: Optional[int] = None,
+    *,
+    match_with_field_name: bool = False,
+    ignore_unprojectable_fields: bool = False,
 ) -> Optional[pa.Table]:
     if limit and sum(row_counts) >= limit:
         return None
@@ -781,9 +935,16 @@ def _task_to_table(
         schema_raw = None
         if metadata := physical_schema.metadata:
             schema_raw = metadata.get(ICEBERG_SCHEMA)
-        # TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema,
-        # see https://github.com/apache/iceberg/issues/7451
-        file_schema = Schema.model_validate_json(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)
+        file_schema = (
+            Schema.model_validate_json(schema_raw)
+            if schema_raw is not None
+            else pyarrow_to_schema(
+                physical_schema,
+                projected_schema,
+                match_with_field_name=match_with_field_name,
+                ignore_unprojectable_fields=ignore_unprojectable_fields,
+            )
+        )
 
         pyarrow_filter = None
         if bound_row_filter is not AlwaysTrue():
@@ -840,7 +1001,9 @@ def _task_to_table(
 
         row_counts.append(len(arrow_table))
 
-        return to_requested_schema(projected_schema, file_project_schema, arrow_table)
+        return to_requested_schema(
+            projected_schema, file_project_schema, arrow_table, match_with_field_name=match_with_field_name
+        )
 
 
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -868,6 +1031,9 @@ def project_table(
     projected_schema: Schema,
     case_sensitive: bool = True,
     limit: Optional[int] = None,
+    *,
+    match_with_field_name: bool = False,
+    ignore_unprojectable_fields: bool = False,
 ) -> pa.Table:
     """Resolve the right columns based on the identifier.
 
@@ -918,6 +1084,8 @@ def project_table(
             case_sensitive,
             row_counts,
             limit,
+            match_with_field_name=match_with_field_name,
+            ignore_unprojectable_fields=ignore_unprojectable_fields,
         )
         for task in tasks
     ]
@@ -949,8 +1117,12 @@ def project_table(
     return result
 
 
-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
-    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def to_requested_schema(
+    requested_schema: Schema, file_schema: Schema, table: pa.Table, *, match_with_field_name: bool = False
+) -> pa.Table:
+    struct_array = visit_with_partner(
+        requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema, match_with_field_name)
+    )
 
     arrays = []
     fields = []
@@ -1027,19 +1199,26 @@ def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa.
 class ArrowAccessor(PartnerAccessor[pa.Array]):
     file_schema: Schema
 
-    def __init__(self, file_schema: Schema):
+    def __init__(self, file_schema: Schema, match_with_field_name: bool = False):
         self.file_schema = file_schema
+        self.match_with_field_name = match_with_field_name
 
     def schema_partner(self, partner: Optional[pa.Array]) -> Optional[pa.Array]:
         return partner
 
-    def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]:
+    def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, field_name: str) -> Optional[pa.Array]:
         if partner_struct:
             # use the field name from the file schema
             try:
                 name = self.file_schema.find_field(field_id).name
             except ValueError:
-                return None
+                if self.match_with_field_name:
+                    try:
+                        name = self.file_schema.find_field(field_name).name
+                    except ValueError:
+                        return None
+                else:
+                    return None
 
             if isinstance(partner_struct, pa.StructArray):
                 return partner_struct.field(name)
diff --git a/python/pyiceberg/table/__init__.py b/python/pyiceberg/table/__init__.py
index d24550cb7e27..70da11ae3016 100644
--- a/python/pyiceberg/table/__init__.py
+++ b/python/pyiceberg/table/__init__.py
@@ -860,7 +860,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
             for data_entry in data_entries
         ]
 
-    def to_arrow(self) -> pa.Table:
+    def to_arrow(self, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False) -> pa.Table:
         from pyiceberg.io.pyarrow import project_table
 
         return project_table(
@@ -870,23 +870,25 @@ def to_arrow(self) -> pa.Table:
             self.projection(),
             case_sensitive=self.case_sensitive,
             limit=self.limit,
+            match_with_field_name=match_with_field_name,
+            ignore_unprojectable_fields=ignore_unprojectable_fields,
         )
 
-    def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
-        return self.to_arrow().to_pandas(**kwargs)
+    def to_pandas(self, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False, **kwargs: Any) -> pd.DataFrame:
+        return self.to_arrow(match_with_field_name=match_with_field_name, ignore_unprojectable_fields=ignore_unprojectable_fields).to_pandas(**kwargs)
 
-    def to_duckdb(self, table_name: str, connection: Optional[DuckDBPyConnection] = None) -> DuckDBPyConnection:
+    def to_duckdb(self, table_name: str, connection: Optional[DuckDBPyConnection] = None, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False) -> DuckDBPyConnection:
         import duckdb
 
         con = connection or duckdb.connect(database=":memory:")
-        con.register(table_name, self.to_arrow())
+        con.register(table_name, self.to_arrow(match_with_field_name=match_with_field_name, ignore_unprojectable_fields=ignore_unprojectable_fields))
 
         return con
 
-    def to_ray(self) -> ray.data.dataset.Dataset:
+    def to_ray(self, *, match_with_field_name: bool = False, ignore_unprojectable_fields: bool = False) -> ray.data.dataset.Dataset:
         import ray
 
-        return ray.data.from_arrow(self.to_arrow())
+        return ray.data.from_arrow(self.to_arrow(match_with_field_name=match_with_field_name, ignore_unprojectable_fields=ignore_unprojectable_fields))
 
 
 class UpdateSchema:
diff --git a/python/tests/io/test_pyarrow.py b/python/tests/io/test_pyarrow.py
index 366eda53f576..04a46a033938 100644
--- a/python/tests/io/test_pyarrow.py
+++ b/python/tests/io/test_pyarrow.py
@@ -62,6 +62,7 @@
     _read_deletes,
     expression_to_pyarrow,
     project_table,
+    pyarrow_to_schema,
     schema_to_pyarrow,
 )
 from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
@@ -376,6 +377,65 @@ def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
     assert repr(actual) == expected
 
 
+def test_pyarrow_to_schema(table_schema_simple: Schema, table_schema_nested: Schema) -> None:
+    pa_schema_simple = pa.schema([
+        pa.field("foo", pa.string(), nullable=True, metadata={"field_id": "1"}),
+        pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": "2"}),
+        pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}),
+    ])
+    projected = pyarrow_to_schema(pa_schema_simple, table_schema_simple)
+    assert projected == table_schema_simple.copy(update=dict(schema_id=0, identifier_field_ids=[]))
+
+
+def test_pyarrow_to_schema_match_with_field_name(table_schema_simple: Schema, table_schema_nested: Schema) -> None:
+    pa_schema_simple = pa.schema([
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ])
+    projected = pyarrow_to_schema(pa_schema_simple, table_schema_simple)
+    assert repr(projected) == repr(Schema())
+
+    projected = pyarrow_to_schema(pa_schema_simple, table_schema_simple, match_with_field_name=True)
+    assert projected == table_schema_simple.copy(update=dict(schema_id=0, identifier_field_ids=[]))
+
+    pa_schema_nested = pa.schema([
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+        pa.field("qux", pa.list_(pa.field("item", pa.string(), nullable=False)), nullable=False),
+        pa.field("quux", pa.map_(pa.field("key", pa.string(), nullable=False), pa.field("value", pa.map_(pa.string(), pa.field("value", pa.int32(), nullable=False)), nullable=False)), nullable=False),
+        pa.field(
+            "location",
+            pa.list_(
+                pa.field(
+                    "item",
+                    pa.struct(
+                        [
+                            pa.field("latitude", pa.float32(), nullable=True),
+                            pa.field("longitude", pa.float32(), nullable=True),
+                        ],
+                    ),
+                    nullable=False,
+                )
+            ),
+            nullable=False,
+        ),
+        pa.field(
+            "person",
+            pa.struct(
+                [
+                    pa.field("name", pa.string(), nullable=True),
+                    pa.field("age", pa.int32(), nullable=False),
+                ],
+            ),
+            nullable=True,
+        ),
+    ])
+    projected = pyarrow_to_schema(pa_schema_nested, table_schema_nested, match_with_field_name=True)
+    assert projected == table_schema_nested.copy(update=dict(schema_id=0, identifier_field_ids=[]))
+
+
 def test_fixed_type_to_pyarrow() -> None:
     length = 22
     iceberg_type = FixedType(length)