From 5039b5d70644bc06c98349090912c6e9066d3ea1 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 8 Apr 2024 15:59:33 +0200 Subject: [PATCH] Change DataScan to accept Metadata and io (#581) * Change DataScan to accept Metadata and io For the partial deletes I want to do a scan on in memory metadata. Changing this API allows this. * fix name-mapping issue --------- Co-authored-by: HonahX --- pyiceberg/io/pyarrow.py | 26 ++++---- pyiceberg/table/__init__.py | 70 +++++++++------------- pyiceberg/table/metadata.py | 14 +++++ tests/integration/test_add_files.py | 9 +++ tests/io/test_pyarrow.py | 93 +++++++++++------------------ 5 files changed, 102 insertions(+), 110 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 1848fba787..74692f85b8 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -159,7 +159,7 @@ from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string if TYPE_CHECKING: - from pyiceberg.table import FileScanTask, Table + from pyiceberg.table import FileScanTask logger = logging.getLogger(__name__) @@ -1046,7 +1046,8 @@ def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dic def project_table( tasks: Iterable[FileScanTask], - table: Table, + table_metadata: TableMetadata, + io: FileIO, row_filter: BooleanExpression, projected_schema: Schema, case_sensitive: bool = True, @@ -1056,7 +1057,8 @@ def project_table( Args: tasks (Iterable[FileScanTask]): A URI or a path to a local file. - table (Table): The table that's being queried. + table_metadata (TableMetadata): The table metadata of the table that's being queried + io (FileIO): A FileIO to open streams to the object store row_filter (BooleanExpression): The expression for filtering rows. projected_schema (Schema): The output schema. case_sensitive (bool): Case sensitivity when looking up column names. @@ -1065,24 +1067,24 @@ def project_table( Raises: ResolveError: When an incompatible query is done. """ - scheme, netloc, _ = PyArrowFileIO.parse_location(table.location()) - if isinstance(table.io, PyArrowFileIO): - fs = table.io.fs_by_scheme(scheme, netloc) + scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location) + if isinstance(io, PyArrowFileIO): + fs = io.fs_by_scheme(scheme, netloc) else: try: from pyiceberg.io.fsspec import FsspecFileIO - if isinstance(table.io, FsspecFileIO): + if isinstance(io, FsspecFileIO): from pyarrow.fs import PyFileSystem - fs = PyFileSystem(FSSpecHandler(table.io.get_fs(scheme))) + fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme))) else: - raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") + raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") except ModuleNotFoundError as e: # When FsSpec is not installed - raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") from e + raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e - bound_row_filter = bind(table.schema(), row_filter, case_sensitive=case_sensitive) + bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive) projected_field_ids = { id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType)) @@ -1101,7 +1103,7 @@ def project_table( deletes_per_file.get(task.file.file_path), case_sensitive, limit, - table.name_mapping(), + table_metadata.name_mapping(), ) for task in tasks ] diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index ac19c1a538..ea813176fc 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -103,7 +103,6 @@ ) from pyiceberg.table.name_mapping import ( NameMapping, - parse_mapping_from_json, update_mapping, ) from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef @@ -1215,7 +1214,8 @@ def scan( limit: Optional[int] = None, ) -> DataScan: return DataScan( - table=self, + table_metadata=self.metadata, + io=self.io, row_filter=row_filter, selected_fields=selected_fields, case_sensitive=case_sensitive, @@ -1312,10 +1312,7 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive def name_mapping(self) -> Optional[NameMapping]: """Return the table's field-id NameMapping.""" - if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING): - return parse_mapping_from_json(name_mapping_json) - else: - return None + return self.metadata.name_mapping() def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: """ @@ -1468,7 +1465,8 @@ def _parse_row_filter(expr: Union[str, BooleanExpression]) -> BooleanExpression: class TableScan(ABC): - table: Table + table_metadata: TableMetadata + io: FileIO row_filter: BooleanExpression selected_fields: Tuple[str, ...] case_sensitive: bool @@ -1478,7 +1476,8 @@ class TableScan(ABC): def __init__( self, - table: Table, + table_metadata: TableMetadata, + io: FileIO, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, selected_fields: Tuple[str, ...] = ("*",), case_sensitive: bool = True, @@ -1486,7 +1485,8 @@ def __init__( options: Properties = EMPTY_DICT, limit: Optional[int] = None, ): - self.table = table + self.table_metadata = table_metadata + self.io = io self.row_filter = _parse_row_filter(row_filter) self.selected_fields = selected_fields self.case_sensitive = case_sensitive @@ -1496,19 +1496,20 @@ def __init__( def snapshot(self) -> Optional[Snapshot]: if self.snapshot_id: - return self.table.snapshot_by_id(self.snapshot_id) - return self.table.current_snapshot() + return self.table_metadata.snapshot_by_id(self.snapshot_id) + return self.table_metadata.current_snapshot() def projection(self) -> Schema: - current_schema = self.table.schema() + current_schema = self.table_metadata.schema() if self.snapshot_id is not None: - snapshot = self.table.snapshot_by_id(self.snapshot_id) + snapshot = self.table_metadata.snapshot_by_id(self.snapshot_id) if snapshot is not None: if snapshot.schema_id is not None: - snapshot_schema = self.table.schemas().get(snapshot.schema_id) - if snapshot_schema is not None: - current_schema = snapshot_schema - else: + try: + current_schema = next( + schema for schema in self.table_metadata.schemas if schema.schema_id == snapshot.schema_id + ) + except StopIteration: warnings.warn(f"Metadata does not contain schema with id: {snapshot.schema_id}") else: raise ValueError(f"Snapshot not found: {self.snapshot_id}") @@ -1534,7 +1535,7 @@ def update(self: S, **overrides: Any) -> S: def use_ref(self: S, name: str) -> S: if self.snapshot_id: raise ValueError(f"Cannot override ref, already set snapshot id={self.snapshot_id}") - if snapshot := self.table.snapshot_by_name(name): + if snapshot := self.table_metadata.snapshot_by_name(name): return self.update(snapshot_id=snapshot.snapshot_id) raise ValueError(f"Cannot scan unknown ref={name}") @@ -1626,20 +1627,8 @@ def _match_deletes_to_data_file(data_entry: ManifestEntry, positional_delete_ent class DataScan(TableScan): - def __init__( - self, - table: Table, - row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, - selected_fields: Tuple[str, ...] = ("*",), - case_sensitive: bool = True, - snapshot_id: Optional[int] = None, - options: Properties = EMPTY_DICT, - limit: Optional[int] = None, - ): - super().__init__(table, row_filter, selected_fields, case_sensitive, snapshot_id, options, limit) - def _build_partition_projection(self, spec_id: int) -> BooleanExpression: - project = inclusive_projection(self.table.schema(), self.table.specs()[spec_id]) + project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id]) return project(self.row_filter) @cached_property @@ -1647,12 +1636,12 @@ def partition_filters(self) -> KeyDefaultDict[int, BooleanExpression]: return KeyDefaultDict(self._build_partition_projection) def _build_manifest_evaluator(self, spec_id: int) -> Callable[[ManifestFile], bool]: - spec = self.table.specs()[spec_id] - return manifest_evaluator(spec, self.table.schema(), self.partition_filters[spec_id], self.case_sensitive) + spec = self.table_metadata.specs()[spec_id] + return manifest_evaluator(spec, self.table_metadata.schema(), self.partition_filters[spec_id], self.case_sensitive) def _build_partition_evaluator(self, spec_id: int) -> Callable[[DataFile], bool]: - spec = self.table.specs()[spec_id] - partition_type = spec.partition_type(self.table.schema()) + spec = self.table_metadata.specs()[spec_id] + partition_type = spec.partition_type(self.table_metadata.schema()) partition_schema = Schema(*partition_type.fields) partition_expr = self.partition_filters[spec_id] @@ -1687,8 +1676,6 @@ def plan_files(self) -> Iterable[FileScanTask]: if not snapshot: return iter([]) - io = self.table.io - # step 1: filter manifests using partition summaries # the filter depends on the partition spec used to write the manifest file, so create a cache of filters for each spec id @@ -1696,7 +1683,7 @@ def plan_files(self) -> Iterable[FileScanTask]: manifests = [ manifest_file - for manifest_file in snapshot.manifests(io) + for manifest_file in snapshot.manifests(self.io) if manifest_evaluators[manifest_file.partition_spec_id](manifest_file) ] @@ -1705,7 +1692,7 @@ def plan_files(self) -> Iterable[FileScanTask]: partition_evaluators: Dict[int, Callable[[DataFile], bool]] = KeyDefaultDict(self._build_partition_evaluator) metrics_evaluator = _InclusiveMetricsEvaluator( - self.table.schema(), self.row_filter, self.case_sensitive, self.options.get("include_empty_files") == "true" + self.table_metadata.schema(), self.row_filter, self.case_sensitive, self.options.get("include_empty_files") == "true" ).eval min_data_sequence_number = _min_data_file_sequence_number(manifests) @@ -1719,7 +1706,7 @@ def plan_files(self) -> Iterable[FileScanTask]: lambda args: _open_manifest(*args), [ ( - io, + self.io, manifest, partition_evaluators[manifest.partition_spec_id], metrics_evaluator, @@ -1755,7 +1742,8 @@ def to_arrow(self) -> pa.Table: return project_table( self.plan_files(), - self.table, + self.table_metadata, + self.io, self.row_filter, self.projection(), case_sensitive=self.case_sensitive, diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 21ed144784..ba0c885758 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -35,6 +35,7 @@ from pyiceberg.exceptions import ValidationError from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec, assign_fresh_partition_spec_ids from pyiceberg.schema import Schema, assign_fresh_schema_ids +from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import MetadataLogEntry, Snapshot, SnapshotLogEntry from pyiceberg.table.sorting import ( @@ -237,6 +238,13 @@ def schema(self) -> Schema: """Return the schema for this table.""" return next(schema for schema in self.schemas if schema.schema_id == self.current_schema_id) + def name_mapping(self) -> Optional[NameMapping]: + """Return the table's field-id NameMapping.""" + if name_mapping_json := self.properties.get("schema.name-mapping.default"): + return parse_mapping_from_json(name_mapping_json) + else: + return None + def spec(self) -> PartitionSpec: """Return the partition spec of this table.""" return next(spec for spec in self.partition_specs if spec.spec_id == self.default_spec_id) @@ -278,6 +286,12 @@ def new_snapshot_id(self) -> int: return snapshot_id + def snapshot_by_name(self, name: str) -> Optional[Snapshot]: + """Return the snapshot referenced by the given name or null if no such reference exists.""" + if ref := self.refs.get(name): + return self.snapshot_by_id(ref.snapshot_id) + return None + def current_snapshot(self) -> Optional[Snapshot]: """Get the current snapshot for this table, or None if there is no current snapshot.""" if self.current_snapshot_id is not None: diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index 7c17618280..0de5d5f4ce 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -158,6 +158,9 @@ def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: for col in df.columns: assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null" + # check that the table can be read by pyiceberg + assert len(tbl.scan().to_arrow()) == 5, "Expected 5 rows" + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) @@ -255,6 +258,9 @@ def test_add_files_to_unpartitioned_table_with_schema_updates( value_count = 1 if col == "quux" else 6 assert df.filter(df[col].isNotNull()).count() == value_count, f"Expected {value_count} rows to be non-null" + # check that the table can be read by pyiceberg + assert len(tbl.scan().to_arrow()) == 6, "Expected 6 rows" + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) @@ -324,6 +330,9 @@ def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Ca assert [row.file_count for row in partition_rows] == [5] assert [(row.partition.baz, row.partition.qux_month) for row in partition_rows] == [(123, 650)] + # check that the table can be read by pyiceberg + assert len(tbl.scan().to_arrow()) == 5, "Expected 5 rows" + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index b99febd6e2..46ece77880 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -28,7 +28,6 @@ import pytest from pyarrow.fs import FileType, LocalFileSystem -from pyiceberg.catalog.noop import NoopCatalog from pyiceberg.exceptions import ResolveError from pyiceberg.expressions import ( AlwaysFalse, @@ -72,7 +71,7 @@ from pyiceberg.manifest import DataFile, DataFileContent, FileFormat from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema, make_compatible_name, visit -from pyiceberg.table import FileScanTask, Table, TableProperties +from pyiceberg.table import FileScanTask, TableProperties from pyiceberg.table.metadata import TableMetadataV2 from pyiceberg.typedef import UTF8 from pyiceberg.types import ( @@ -876,7 +875,7 @@ def project( schema: Schema, files: List[str], expr: Optional[BooleanExpression] = None, table_schema: Optional[Schema] = None ) -> pa.Table: return project_table( - [ + tasks=[ FileScanTask( DataFile( content=DataFileContent.DATA, @@ -889,21 +888,16 @@ def project( ) for file in files ], - Table( - ("namespace", "table"), - metadata=TableMetadataV2( - location="file://a/b/", - last_column_id=1, - format_version=2, - schemas=[table_schema or schema], - partition_specs=[PartitionSpec()], - ), - metadata_location="file://a/b/c.json", - io=PyArrowFileIO(), - catalog=NoopCatalog("NoopCatalog"), + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[table_schema or schema], + partition_specs=[PartitionSpec()], ), - expr or AlwaysTrue(), - schema, + io=PyArrowFileIO(), + row_filter=expr or AlwaysTrue(), + projected_schema=schema, case_sensitive=True, ) @@ -1362,20 +1356,15 @@ def test_delete(deletes_file: str, example_task: FileScanTask, table_schema_simp with_deletes = project_table( tasks=[example_task_with_delete], - table=Table( - ("namespace", "table"), - metadata=TableMetadataV2( - location=metadata_location, - last_column_id=1, - format_version=2, - current_schema_id=1, - schemas=[table_schema_simple], - partition_specs=[PartitionSpec()], - ), - metadata_location=metadata_location, - io=load_file_io(), - catalog=NoopCatalog("noop"), + table_metadata=TableMetadataV2( + location=metadata_location, + last_column_id=1, + format_version=2, + current_schema_id=1, + schemas=[table_schema_simple], + partition_specs=[PartitionSpec()], ), + io=load_file_io(), row_filter=AlwaysTrue(), projected_schema=table_schema_simple, ) @@ -1405,20 +1394,15 @@ def test_delete_duplicates(deletes_file: str, example_task: FileScanTask, table_ with_deletes = project_table( tasks=[example_task_with_delete], - table=Table( - ("namespace", "table"), - metadata=TableMetadataV2( - location=metadata_location, - last_column_id=1, - format_version=2, - current_schema_id=1, - schemas=[table_schema_simple], - partition_specs=[PartitionSpec()], - ), - metadata_location=metadata_location, - io=load_file_io(), - catalog=NoopCatalog("noop"), + table_metadata=TableMetadataV2( + location=metadata_location, + last_column_id=1, + format_version=2, + current_schema_id=1, + schemas=[table_schema_simple], + partition_specs=[PartitionSpec()], ), + io=load_file_io(), row_filter=AlwaysTrue(), projected_schema=table_schema_simple, ) @@ -1439,21 +1423,16 @@ def test_delete_duplicates(deletes_file: str, example_task: FileScanTask, table_ def test_pyarrow_wrap_fsspec(example_task: FileScanTask, table_schema_simple: Schema) -> None: metadata_location = "file://a/b/c.json" projection = project_table( - [example_task], - Table( - ("namespace", "table"), - metadata=TableMetadataV2( - location=metadata_location, - last_column_id=1, - format_version=2, - current_schema_id=1, - schemas=[table_schema_simple], - partition_specs=[PartitionSpec()], - ), - metadata_location=metadata_location, - io=load_file_io(properties={"py-io-impl": "pyiceberg.io.fsspec.FsspecFileIO"}, location=metadata_location), - catalog=NoopCatalog("NoopCatalog"), + tasks=[example_task], + table_metadata=TableMetadataV2( + location=metadata_location, + last_column_id=1, + format_version=2, + current_schema_id=1, + schemas=[table_schema_simple], + partition_specs=[PartitionSpec()], ), + io=load_file_io(properties={"py-io-impl": "pyiceberg.io.fsspec.FsspecFileIO"}, location=metadata_location), case_sensitive=True, projected_schema=table_schema_simple, row_filter=AlwaysTrue(),