From 91d2ab937b4c221f5665b4088aa26dda64d1a15d Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 21 Aug 2023 22:35:16 -0300 Subject: [PATCH] Add `GeoDataset` interface (#43) * some initial failing tests * actually working * use row groups * error * subclass for parquet rowgroup dataset * keep row group ids * move indexing methods into the GeoDataset class * factor out dataset constructor * prepare for row group metadata * scaffold the row group metadata scan * theoretical row group stats working * handle parquet files with no stats * basic docs * reorganize * maybe docs * maybe fix docstrings * fix a type-o (har) * test + fix parquet field counting * format, remove unneeded change * document + test filtering on multiple geometry columns * more explicit name + type tests * fix type test * theoretically pass index through filter * test wrapping * fix filtering to empty * fix docs --- docs/source/python/pyarrow.rst | 11 + python/geoarrow/pyarrow/__init__.py | 31 ++ python/geoarrow/pyarrow/_dataset.py | 465 ++++++++++++++++++++++++++ python/tests/test_geoarrow_dataset.py | 212 ++++++++++++ 4 files changed, 719 insertions(+) create mode 100644 python/geoarrow/pyarrow/_dataset.py create mode 100644 python/tests/test_geoarrow_dataset.py diff --git a/docs/source/python/pyarrow.rst b/docs/source/python/pyarrow.rst index 6905ca7d..64b09259 100644 --- a/docs/source/python/pyarrow.rst +++ b/docs/source/python/pyarrow.rst @@ -9,6 +9,11 @@ Integration with pyarrow .. autofunction:: array + Dataset constructors + -------------------- + + .. autofunction:: dataset + Type Constructors ----------------- @@ -94,3 +99,9 @@ Integration with pyarrow .. autoclass:: MultiPolygonType :members: + + .. autoclass:: geoarrow.pyarrow._dataset.GeoDataset + :members: + + .. autoclass:: geoarrow.pyarrow._dataset.ParquetRowGroupGeoDataset + :members: diff --git a/python/geoarrow/pyarrow/__init__.py b/python/geoarrow/pyarrow/__init__.py index bb06861a..08380ba9 100644 --- a/python/geoarrow/pyarrow/__init__.py +++ b/python/geoarrow/pyarrow/__init__.py @@ -60,4 +60,35 @@ point_coords, ) + +# Use a lazy import here to avoid requiring pyarrow.dataset +def dataset(*args, geometry_columns=None, use_row_groups=None, **kwargs): + """Construct a GeoDataset + + This constructor is intended to mirror `pyarrow.dataset()`, adding + geo-specific arguments. See :class:`geoarrow.pyarrow._dataset.GeoDataset` for + details. 
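+
+    If `geometry_columns` is `None`, columns whose type inherits from
+    `geoarrow.pyarrow.VectorType` are detected automatically; text and binary
+    columns may also be listed explicitly and are interpreted as WKT and WKB,
+    respectively. If `use_row_groups` is `None`, Parquet datasets are split
+    into one fragment per row group by default.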
+ + >>> import geoarrow.pyarrow as ga + >>> import pyarrow as pa + >>> table = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + >>> dataset = ga.dataset(table) + """ + from pyarrow import dataset as _ds + from ._dataset import GeoDataset, ParquetRowGroupGeoDataset + + parent = _ds.dataset(*args, **kwargs) + + if use_row_groups is None: + use_row_groups = isinstance(parent, _ds.FileSystemDataset) and isinstance( + parent.format, _ds.ParquetFileFormat + ) + if use_row_groups: + return ParquetRowGroupGeoDataset.create( + parent, geometry_columns=geometry_columns + ) + else: + return GeoDataset(parent, geometry_columns=geometry_columns) + + register_extension_types() diff --git a/python/geoarrow/pyarrow/_dataset.py b/python/geoarrow/pyarrow/_dataset.py new file mode 100644 index 00000000..c1efe2f0 --- /dev/null +++ b/python/geoarrow/pyarrow/_dataset.py @@ -0,0 +1,465 @@ +from concurrent.futures import ThreadPoolExecutor, wait + +import pyarrow as _pa +import pyarrow.types as _types +import pyarrow.dataset as _ds +import pyarrow.compute as _compute +import pyarrow.parquet as _pq +from ..lib import CoordType +from ._type import wkt, wkb, VectorType +from ._kernel import Kernel + + +class GeoDataset: + """Geospatial-augmented Dataset + + The GeoDataset wraps a pyarrow.Dataset containing one or more geometry columns + and provides indexing and IO capability. If `geometry_columns` is `None`, + it will include all columns that inherit from `geoarrow.pyarrow.VectorType`. + The `geometry_columns` are not required to be geoarrow extension type columns: + text columns will be parsed as WKT; binary columns will be parsed as WKB + (but are not detected automatically). + """ + + def __init__(self, parent, geometry_columns=None): + self._index = None + self._geometry_columns = geometry_columns + self._geometry_types = None + self._fragments = None + + if not isinstance(parent, _ds.Dataset): + raise TypeError("parent must be a pyarrow.dataset.Dataset") + self._parent = parent + + @property + def parent(self): + """Get the parent Dataset + + Returns the (non geo-aware) parent pyarrow.Dataset. + + >>> import geoarrow.pyarrow as ga + >>> import pyarrow as pa + >>> table = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + >>> dataset = ga.dataset(table) + >>> type(dataset.parent) + + """ + return self._parent + + def to_table(self): + return self.parent.to_table() + + @property + def schema(self): + """Get the dataset schema + + The schema of a GeoDataset is identical to that of its parent. 
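+        Note that geometry columns appear here with whatever type they have in
+        the parent dataset; use `geometry_types` to obtain the resolved
+        geoarrow type of each geometry column.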
+ + >>> import geoarrow.pyarrow as ga + >>> import pyarrow as pa + >>> table = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + >>> dataset = ga.dataset(table) + >>> dataset.schema + geometry: extension> + """ + return self._parent.schema + + def get_fragments(self): + """Resolve the list of fragments in the dataset + + This is identical to the list of fragments of its parent.""" + if self._fragments is None: + self._fragments = tuple(self._parent.get_fragments()) + + return self._fragments + + @property + def geometry_columns(self): + """Get a tuple of geometry column names + + >>> import geoarrow.pyarrow as ga + >>> import pyarrow as pa + >>> table = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + >>> dataset = ga.dataset(table) + >>> dataset.geometry_columns + ('geometry',) + """ + if self._geometry_columns is None: + schema = self.schema + geometry_columns = [] + for name, type in zip(schema.names, schema.types): + if isinstance(type, VectorType): + geometry_columns.append(name) + self._geometry_columns = tuple(geometry_columns) + + return self._geometry_columns + + @property + def geometry_types(self): + """Resolve a tuple of geometry column types + + This will convert any primitive types to the corresponding + geo-enabled type (e.g., binary to wkb) and check that geometry + columns actually refer a field that can be interpreted as + geometry. + + >>> import geoarrow.pyarrow as ga + >>> import pyarrow as pa + >>> table = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + >>> dataset = ga.dataset(table) + >>> dataset.geometry_types + (WktType(geoarrow.wkt),) + """ + if self._geometry_types is None: + geometry_types = [] + for col in self.geometry_columns: + type = self.schema.field(col).type + if isinstance(type, VectorType): + geometry_types.append(type) + elif _types.is_binary(type): + geometry_types.append(wkb()) + elif _types.is_string(type): + geometry_types.append(wkt()) + else: + raise TypeError(f"Unsupported type for geometry column: {type}") + + self._geometry_types = tuple(geometry_types) + + return self._geometry_types + + def index_fragments(self, num_threads=None): + """Resolve a simplified geometry for each fragment + + Currently the simplified geometry is a box in the form of a + struct array with fields xmin, xmax, ymin, and ymax. The + fragment index is curently a table whose first column is the fragment + index and whose subsequent columns are named with the geometry column + name. A future implementation may handle spherical edges using a type + of simplified geometry more suitable to a spherical comparison. + + >>> import geoarrow.pyarrow as ga + >>> import pyarrow as pa + >>> table = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + >>> dataset = ga.dataset(table) + >>> dataset.index_fragments().to_pylist() + [{'_fragment_index': 0, 'geometry': {'xmin': 0.5, 'xmax': 0.5, 'ymin': 1.5, 'ymax': 1.5}}] + """ + if self._index is None: + self._index = self._build_index(self.geometry_columns, num_threads) + + return self._index + + def _build_index(self, geometry_columns, num_threads=None): + return GeoDataset._index_fragments( + self.get_fragments(), geometry_columns, num_threads=num_threads + ) + + def filter_fragments(self, target): + """Push down a spatial query into a GeoDataset + + Returns a potentially simplified dataset based on the geometry of + target. Currently this uses `geoarrow.pyarrow.box_agg()` on `target` + and performs a simple envelope comparison with each fragment. 
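+        Fragments whose bounding box does not intersect the bounding box of
+        `target` are dropped from the returned dataset.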
A future + implementation may handle spherical edges using a type of simplified + geometry more suitable to a spherical comparison. For datasets with + more than one geometry column, the filter will be applied to all columns + and include fragments that intersect the simplified geometry from any + of the columns. + + >>> import geoarrow.pyarrow as ga + >>> import pyarrow as pa + >>> table = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + >>> dataset = ga.dataset(table) + >>> dataset.filter_fragments("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))").to_table() + pyarrow.Table + geometry: extension> + ---- + geometry: [] + >>> dataset.filter_fragments("POLYGON ((0 1, 0 2, 1 2, 1 1, 0 1))").to_table() + pyarrow.Table + geometry: extension> + ---- + geometry: [["POINT (0.5 1.5)"]] + """ + from ._compute import box_agg + + if isinstance(target, str): + target = [target] + target_box = box_agg(target) + maybe_intersects = GeoDataset._index_box_intersects( + self.index_fragments(), target_box, self.geometry_columns + ) + fragment_indices = [scalar.as_py() for scalar in maybe_intersects] + filtered_parent = self._filter_parent_fragment_indices(fragment_indices) + return self._wrap_parent(filtered_parent, fragment_indices) + + def _wrap_parent(self, filtered_parent, fragment_indices): + new_wrapped = GeoDataset( + filtered_parent, geometry_columns=self._geometry_columns + ) + new_wrapped._geometry_types = self.geometry_types + + new_index = self.index_fragments().take( + _pa.array(fragment_indices, type=_pa.int64()) + ) + new_wrapped._index = new_index.set_column( + 0, "_fragment_index", _pa.array(range(new_index.num_rows)) + ) + + return new_wrapped + + def _filter_parent_fragment_indices(self, fragment_indices): + fragments = self.get_fragments() + fragments_filtered = [fragments[i] for i in fragment_indices] + + if isinstance(self._parent, _ds.FileSystemDataset): + return _ds.FileSystemDataset( + fragments_filtered, self.schema, self._parent.format + ) + else: + tables = [fragment.to_table() for fragment in fragments_filtered] + return _ds.InMemoryDataset(tables, schema=self.schema) + + @staticmethod + def _index_fragment(fragment, column): + scanner = fragment.scanner(columns=[column]) + reader = scanner.to_reader() + kernel = Kernel.box_agg(reader.schema.types[0]) + for batch in reader: + kernel.push(batch.column(0)) + return kernel.finish() + + @staticmethod + def _index_fragments(fragments, columns, num_threads=None): + columns = list(columns) + if num_threads is None: + num_threads = _pa.cpu_count() + + num_fragments = len(fragments) + metadata = [_pa.array(range(num_fragments))] + if not columns: + return _pa.table(metadata, names=["_fragment_index"]) + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + for column in columns: + for fragment in fragments: + future = executor.submit( + GeoDataset._index_fragment, fragment, column + ) + futures.append(future) + + wait(futures) + + results = [] + for i, column in enumerate(columns): + results.append( + [ + futures[i * num_fragments + j].result() + for j in range(num_fragments) + ] + ) + + result_arrays = [_pa.concat_arrays(result) for result in results] + return _pa.table( + metadata + result_arrays, names=["_fragment_index"] + columns + ) + + @staticmethod + def _index_box_intersects(index, box, columns): + xmin, xmax, ymin, ymax = box.as_py().values() + expressions = [] + for col in columns: + expr = ( + (_compute.field(col, "xmin") <= xmax) + & (_compute.field(col, "xmax") >= xmin) + & (_compute.field(col, 
"ymin") <= ymax) + & (_compute.field(col, "ymax") >= ymin) + ) + expressions.append(expr) + + expr = expressions[0] + for i in range(1, len(expressions)): + expr = expr | expressions[i] + + result = _ds.dataset(index).filter(expr).to_table() + return result.column(0) + + +class ParquetRowGroupGeoDataset(GeoDataset): + """Geospatial-augmented Parquet dataset using row groups + + An implementation of the GeoDataset that can leverage potentially + more efficient indexing and more specific filtering. Notably, this + implementation can (1) split a Parquet dataset into potentially more + smaller fragments and (2) use column statistics added by most Parquet + writers to more efficiently build the fragment index for types that support + this capability. + """ + + def __init__( + self, + parent, + row_group_fragments, + row_group_ids, + geometry_columns=None, + use_column_statistics=True, + ): + super().__init__(parent, geometry_columns=geometry_columns) + self._fragments = row_group_fragments + self._row_group_ids = row_group_ids + self._use_column_statistics = use_column_statistics + + @staticmethod + def create(parent, geometry_columns=None, use_column_statistics=True): + if not isinstance(parent, _ds.FileSystemDataset) or not isinstance( + parent.format, _ds.ParquetFileFormat + ): + raise TypeError( + "ParquetRowGroupGeoDataset() is only supported for Parquet datasets" + ) + + row_group_fragments = [] + row_group_ids = [] + + for file_fragment in parent.get_fragments(): + for i, row_group_fragment in enumerate(file_fragment.split_by_row_group()): + row_group_fragments.append(row_group_fragment) + # Keep track of the row group IDs so we can accellerate + # building an index later where column statistics are supported + row_group_ids.append(i) + + parent = _ds.FileSystemDataset( + row_group_fragments, parent.schema, parent.format + ) + return ParquetRowGroupGeoDataset( + parent, + row_group_fragments, + row_group_ids, + geometry_columns=geometry_columns, + use_column_statistics=use_column_statistics, + ) + + def _wrap_parent(self, filtered_parent, fragment_indices): + base_wrapped = super()._wrap_parent(filtered_parent, fragment_indices) + + new_row_group_fragments = [self._fragments[i] for i in fragment_indices] + new_row_group_ids = [self._row_group_ids[i] for i in fragment_indices] + new_wrapped = ParquetRowGroupGeoDataset( + base_wrapped._parent, + new_row_group_fragments, + new_row_group_ids, + geometry_columns=base_wrapped.geometry_columns, + use_column_statistics=self._use_column_statistics, + ) + new_wrapped._index = base_wrapped._index + return new_wrapped + + def _build_index(self, geometry_columns, num_threads=None): + can_use_statistics = [ + type.coord_type == CoordType.SEPARATE for type in self.geometry_types + ] + + if not self._use_column_statistics or not any(can_use_statistics): + return super()._build_index(geometry_columns, num_threads) + + # Build a list of columns that will work with column stats + bbox_stats_cols = [] + for col, use_stats in zip(geometry_columns, can_use_statistics): + if use_stats: + bbox_stats_cols.append(col) + + # Compute the column stats + bbox_stats = self._build_index_using_stats(bbox_stats_cols) + normal_stats_cols = list(geometry_columns) + stats_by_name = {} + for col, stat in zip(bbox_stats_cols, bbox_stats): + # stat will contain nulls if any statistics were missing: + if stat.null_count == 0: + stats_by_name[col] = stat + normal_stats_cols.remove(col) + + # Compute any remaining statistics + normal_stats = 
super()._build_index(normal_stats_cols, num_threads) + for col in normal_stats_cols: + stats_by_name[col] = normal_stats.column(col) + + # Reorder stats to match the order of geometry_columns + stat_cols = [stats_by_name[col] for col in geometry_columns] + return _pa.table( + [normal_stats.column(0)] + stat_cols, + names=["_fragment_index"] + list(geometry_columns), + ) + + def _build_index_using_stats(self, geometry_columns): + parquet_fields_before = ParquetRowGroupGeoDataset._count_fields_before( + self.schema + ) + parquet_fields_before = {k: v for k, v in parquet_fields_before} + parquet_fields_before = [ + parquet_fields_before[(col,)] for col in geometry_columns + ] + return self._parquet_field_boxes(parquet_fields_before) + + def _parquet_field_boxes(self, parquet_indices): + boxes = [[]] * len(parquet_indices) + pq_file = None + last_row_group = None + + # Note: probably worth parallelizing by file + for row_group, fragment in zip(self._row_group_ids, self.get_fragments()): + if pq_file is None or row_group < last_row_group: + pq_file = _pq.ParquetFile( + fragment.path, filesystem=self._parent.filesystem + ) + + metadata = pq_file.metadata.row_group(row_group) + for i, parquet_index in enumerate(parquet_indices): + stats_x = metadata.column(parquet_index).statistics + stats_y = metadata.column(parquet_index + 1).statistics + + if stats_x is None or stats_y is None: + boxes[i].append(None) + else: + boxes[i].append( + { + "xmin": stats_x.min, + "xmax": stats_x.max, + "ymin": stats_y.min, + "ymax": stats_y.max, + } + ) + + last_row_group = row_group + + type_field_names = ["xmin", "xmax", "ymin", "ymax"] + type_fields = [_pa.field(name, _pa.float64()) for name in type_field_names] + type = _pa.struct(type_fields) + return [_pa.array(box, type=type) for box in boxes] + + @staticmethod + def _count_fields_before(field, fields_before=None, path=(), count=0): + """Helper to find the parquet column index of a given field path""" + + if isinstance(field, _pa.Schema): + fields_before = [] + for i in range(len(field.types)): + count = ParquetRowGroupGeoDataset._count_fields_before( + field.field(i), fields_before, path, count + ) + return fields_before + + if isinstance(field.type, _pa.ExtensionType): + field = _pa.field(field.name, field.type.storage_type) + + if _types.is_nested(field.type): + path = path + (field.name,) + fields_before.append((path, count)) + for i in range(field.type.num_fields): + count = ParquetRowGroupGeoDataset._count_fields_before( + field.type.field(i), fields_before, path, count + ) + return count + else: + fields_before.append((path + (field.name,), count)) + return count + 1 diff --git a/python/tests/test_geoarrow_dataset.py b/python/tests/test_geoarrow_dataset.py new file mode 100644 index 00000000..35a22cce --- /dev/null +++ b/python/tests/test_geoarrow_dataset.py @@ -0,0 +1,212 @@ +from tempfile import TemporaryDirectory + +import pyarrow as pa +import pyarrow.dataset as ds +import pyarrow.parquet as pq +import pytest + +import geoarrow.pyarrow as ga +from geoarrow.pyarrow._dataset import GeoDataset, ParquetRowGroupGeoDataset + + +def test_geodataset_column_name_guessing(): + table = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + geods = ga.dataset(table) + assert geods.geometry_columns == ("geometry",) + + +def test_geodataset_column_type_guessing(): + # Already a geoarrow type + table = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + geods = ga.dataset(table, geometry_columns=["geometry"]) + assert geods.geometry_types == (ga.wkt(),) + 
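+    # Auto-detected geometry columns should resolve to the same geoarrow type
+    geods = ga.dataset(table)
+    assert geods.geometry_types == (ga.wkt(),)
+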
+ # utf8 maps to wkt + table = pa.table([ga.array(["POINT (0.5 1.5)"]).storage], ["geometry"]) + geods = ga.dataset(table, geometry_columns=["geometry"]) + assert geods.geometry_types == (ga.wkt(),) + + # binary maps to wkb + table = pa.table([ga.as_wkb(["POINT (0.5 1.5)"]).storage], ["geometry"]) + geods = ga.dataset(table, geometry_columns=["geometry"]) + assert geods.geometry_types == (ga.wkb(),) + + # Error for other types + with pytest.raises(TypeError): + table = pa.table([[123]], ["geometry"]) + geods = ga.dataset(table, geometry_columns=["geometry"]) + geods.geometry_types + + +def test_geodataset_in_memory(): + table1 = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + table2 = pa.table([ga.array(["POINT (2.5 3.5)"])], ["geometry"]) + + geods = ga.dataset([table1, table2]) + assert isinstance(geods._parent, ds.InMemoryDataset) + assert len(list(geods._parent.get_fragments())) == 2 + + filtered1 = geods.filter_fragments("POLYGON ((2 3, 3 3, 3 4, 2 4, 2 3))") + assert isinstance(filtered1, GeoDataset) + assert filtered1.to_table().num_rows == 1 + assert filtered1._index.column("_fragment_index") == pa.chunked_array([[0]]) + assert filtered1._index.column("geometry") == geods._index.column("geometry").take( + [1] + ) + + # Make sure we can filter to empty + filtered0 = geods.filter_fragments("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))") + assert filtered0.to_table().num_rows == 0 + + with pytest.raises(TypeError): + ga.dataset([table1], use_row_groups=True) + + +def test_geodataset_multiple_geometry_columns(): + table1 = pa.table( + [ga.array(["POINT (0.5 1.5)"]), ga.array(["POINT (2.5 3.5)"])], + ["geometry1", "geometry2"], + ) + table2 = pa.table( + [ga.array(["POINT (4.5 5.5)"]), ga.array(["POINT (6.5 7.5)"])], + ["geometry1", "geometry2"], + ) + + geods = ga.dataset([table1, table2]) + assert isinstance(geods._parent, ds.InMemoryDataset) + assert len(list(geods._parent.get_fragments())) == 2 + + filtered1 = geods.filter_fragments("POLYGON ((0 1, 1 1, 1 2, 0 2, 0 1))").to_table() + assert filtered1.num_rows == 1 + + filtered2 = geods.filter_fragments("POLYGON ((2 3, 3 3, 3 4, 2 4, 2 3))").to_table() + assert filtered2.num_rows == 1 + + +def test_geodataset_parquet(): + table1 = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]) + table2 = pa.table([ga.array(["POINT (2.5 3.5)"])], ["geometry"]) + with TemporaryDirectory() as td: + pq.write_table(table1, f"{td}/table1.parquet") + pq.write_table(table2, f"{td}/table2.parquet") + geods = ga.dataset( + [f"{td}/table1.parquet", f"{td}/table2.parquet"], use_row_groups=False + ) + + filtered1 = geods.filter_fragments( + "POLYGON ((0 1, 1 1, 1 2, 0 2, 0 1))" + ).to_table() + assert filtered1.num_rows == 1 + + +def test_geodataset_parquet_rowgroups(): + table = pa.table([ga.array(["POINT (0.5 1.5)", "POINT (2.5 3.5)"])], ["geometry"]) + with TemporaryDirectory() as td: + pq.write_table(table, f"{td}/table.parquet", row_group_size=1) + + geods = ga.dataset(f"{td}/table.parquet") + assert isinstance(geods, ParquetRowGroupGeoDataset) + assert len(geods.get_fragments()) == 2 + + filtered1 = geods.filter_fragments("POLYGON ((2 3, 3 3, 3 4, 2 4, 2 3))") + assert isinstance(filtered1, ParquetRowGroupGeoDataset) + assert filtered1.to_table().num_rows == 1 + assert filtered1._index.column("_fragment_index") == pa.chunked_array([[0]]) + assert filtered1._index.column("geometry") == geods._index.column( + "geometry" + ).take([1]) + + assert filtered1._row_group_ids == [1] + + +def test_geodataset_parquet_index_rowgroups(): + array_wkt = 
ga.array( + ["LINESTRING (0.5 1.5, 2.5 3.5)", "LINESTRING (4.5 5.5, 6.5 7.5)"] + ) + array_geoarrow = ga.as_geoarrow( + ["LINESTRING (8.5 9.5, 10.5 11.5)", "LINESTRING (12.5 13.5, 14.5 15.5)"] + ) + + table_wkt = pa.table([array_wkt], ["geometry"]) + table_geoarrow = pa.table([array_geoarrow], ["geometry"]) + table_both = pa.table( + [array_wkt, array_geoarrow], ["geometry_wkt", "geometry_geoarrow"] + ) + + with TemporaryDirectory() as td: + pq.write_table(table_wkt, f"{td}/table_wkt.parquet", row_group_size=1) + pq.write_table(table_geoarrow, f"{td}/table_geoarrow.parquet", row_group_size=1) + pq.write_table( + table_geoarrow, + f"{td}/table_geoarrow_nostats.parquet", + row_group_size=1, + write_statistics=False, + ) + pq.write_table(table_both, f"{td}/table_both.parquet", row_group_size=1) + + ds_wkt = ga.dataset(f"{td}/table_wkt.parquet") + ds_geoarrow = ga.dataset(f"{td}/table_geoarrow.parquet") + ds_geoarrow_nostats = ga.dataset(f"{td}/table_geoarrow_nostats.parquet") + ds_both = ga.dataset(f"{td}/table_both.parquet") + + index_wkt = ds_wkt.index_fragments() + index_geoarrow = ds_geoarrow.index_fragments() + index_geoarrow_nostats = ds_geoarrow_nostats.index_fragments() + index_both = ds_both.index_fragments() + + # All the fragment indices should be the same + assert index_geoarrow.column(0) == index_wkt.column(0) + assert index_geoarrow_nostats.column(0) == index_wkt.column(0) + assert index_both.column(0) == index_wkt.column(0) + + # The wkt index should be the same in index_both and index_wkt + assert index_both.column("geometry_wkt") == index_wkt.column("geometry") + + # The geoarrow index should be the same everywhere + assert index_geoarrow_nostats.column("geometry") == index_geoarrow.column( + "geometry" + ) + assert index_both.column("geometry_geoarrow") == index_geoarrow.column( + "geometry" + ) + + +def test_geodataset_parquet_filter_rowgroups_with_stats(): + arr = ga.as_geoarrow(["POINT (0.5 1.5)", "POINT (2.5 3.5)"]) + table = pa.table([arr], ["geometry"]) + with TemporaryDirectory() as td: + pq.write_table(table, f"{td}/table.parquet", row_group_size=1) + + geods = ga.dataset(f"{td}/table.parquet") + assert len(geods.get_fragments()) == 2 + + geods._build_index_using_stats(["geometry"]) + + filtered1 = geods.filter_fragments( + "POLYGON ((0 1, 1 1, 1 2, 0 2, 0 1))" + ).to_table() + assert filtered1.num_rows == 1 + + +def test_parquet_fields_before(): + schema = pa.schema([pa.field("col1", pa.int32()), pa.field("col2", pa.int32())]) + fields_before = ParquetRowGroupGeoDataset._count_fields_before(schema) + assert fields_before == [(("col1",), 0), (("col2",), 1)] + + schema = pa.schema( + [pa.field("col1", pa.list_(pa.int32())), pa.field("col2", pa.int32())] + ) + fields_before = ParquetRowGroupGeoDataset._count_fields_before(schema) + assert fields_before == [(("col1",), 0), (("col1", "item"), 0), (("col2",), 1)] + + schema = pa.schema( + [pa.field("col1", ga.linestring()), pa.field("col2", pa.int32())] + ) + fields_before = ParquetRowGroupGeoDataset._count_fields_before(schema) + assert fields_before == [ + (("col1",), 0), + (("col1", "vertices"), 0), + (("col1", "vertices", "x"), 0), + (("col1", "vertices", "y"), 1), + (("col2",), 2), + ]
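+
+
+def test_geodataset_requires_dataset_parent():
+    # GeoDataset wraps a pyarrow.dataset.Dataset; other parents raise TypeError
+    with pytest.raises(TypeError):
+        GeoDataset(pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"]))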