From ef59855ebe9af8fc0309d1e77b752f64ee2d98b2 Mon Sep 17 00:00:00 2001
From: Dewey Dunnington
Date: Mon, 21 Aug 2023 23:06:37 -0300
Subject: [PATCH 1/2] make sure a dataset with a plain col actually works

---
 python/geoarrow/pyarrow/_dataset.py   | 34 +++++++++++++++++----------
 python/tests/test_geoarrow_dataset.py | 10 ++++++++
 2 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/python/geoarrow/pyarrow/_dataset.py b/python/geoarrow/pyarrow/_dataset.py
index c1efe2f0..ab9c2252 100644
--- a/python/geoarrow/pyarrow/_dataset.py
+++ b/python/geoarrow/pyarrow/_dataset.py
@@ -145,13 +145,18 @@ def index_fragments(self, num_threads=None):
         [{'_fragment_index': 0, 'geometry': {'xmin': 0.5, 'xmax': 0.5, 'ymin': 1.5, 'ymax': 1.5}}]
         """
         if self._index is None:
-            self._index = self._build_index(self.geometry_columns, num_threads)
+            self._index = self._build_index(
+                self.geometry_columns, self.geometry_types, num_threads
+            )

         return self._index

-    def _build_index(self, geometry_columns, num_threads=None):
+    def _build_index(self, geometry_columns, geometry_types, num_threads=None):
         return GeoDataset._index_fragments(
-            self.get_fragments(), geometry_columns, num_threads=num_threads
+            self.get_fragments(),
+            geometry_columns,
+            geometry_types,
+            num_threads=num_threads,
         )

     def filter_fragments(self, target):
@@ -221,16 +226,19 @@ def _filter_parent_fragment_indices(self, fragment_indices):
         return _ds.InMemoryDataset(tables, schema=self.schema)

     @staticmethod
-    def _index_fragment(fragment, column):
+    def _index_fragment(fragment, column, type):
         scanner = fragment.scanner(columns=[column])
         reader = scanner.to_reader()
-        kernel = Kernel.box_agg(reader.schema.types[0])
+        kernel = Kernel.box_agg(type)
         for batch in reader:
-            kernel.push(batch.column(0))
+            if isinstance(type, VectorType):
+                kernel.push(batch.column(0))
+            else:
+                kernel.push(type.wrap_array(batch.column(0)))
         return kernel.finish()

     @staticmethod
-    def _index_fragments(fragments, columns, num_threads=None):
+    def _index_fragments(fragments, columns, types, num_threads=None):
         columns = list(columns)
         if num_threads is None:
             num_threads = _pa.cpu_count()
@@ -242,10 +250,10 @@ def _index_fragments(fragments, columns, num_threads=None):

         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             futures = []
-            for column in columns:
+            for column, type in zip(columns, types):
                 for fragment in fragments:
                     future = executor.submit(
-                        GeoDataset._index_fragment, fragment, column
+                        GeoDataset._index_fragment, fragment, column, type
                     )
                     futures.append(future)

@@ -355,13 +363,13 @@ def _wrap_parent(self, filtered_parent, fragment_indices):
         new_wrapped._index = base_wrapped._index
         return new_wrapped

-    def _build_index(self, geometry_columns, num_threads=None):
+    def _build_index(self, geometry_columns, geometry_types, num_threads=None):
         can_use_statistics = [
             type.coord_type == CoordType.SEPARATE for type in self.geometry_types
         ]

         if not self._use_column_statistics or not any(can_use_statistics):
-            return super()._build_index(geometry_columns, num_threads)
+            return super()._build_index(geometry_columns, geometry_types, num_threads)

         # Build a list of columns that will work with column stats
         bbox_stats_cols = []
@@ -380,7 +388,9 @@
             normal_stats_cols.remove(col)

         # Compute any remaining statistics
-        normal_stats = super()._build_index(normal_stats_cols, num_threads)
+        normal_stats = super()._build_index(
+            normal_stats_cols, geometry_types, num_threads
+        )

         for col in normal_stats_cols:
             stats_by_name[col] = normal_stats.column(col)
diff --git a/python/tests/test_geoarrow_dataset.py b/python/tests/test_geoarrow_dataset.py
index 35a22cce..f76dec99 100644
--- a/python/tests/test_geoarrow_dataset.py
+++ b/python/tests/test_geoarrow_dataset.py
@@ -62,6 +62,16 @@ def test_geodataset_in_memory():
         ga.dataset([table1], use_row_groups=True)


+def test_geodataset_in_memory_guessed_type():
+    table1 = pa.table([ga.array(["POINT (0.5 1.5)"]).storage], ["geometry"])
+    table2 = pa.table([ga.array(["POINT (2.5 3.5)"]).storage], ["geometry"])
+    geods = ga.dataset([table1, table2], geometry_columns=["geometry"])
+
+    filtered1 = geods.filter_fragments("POLYGON ((2 3, 3 3, 3 4, 2 4, 2 3))")
+    assert isinstance(filtered1, GeoDataset)
+    assert filtered1.to_table().num_rows == 1
+
+
 def test_geodataset_multiple_geometry_columns():
     table1 = pa.table(
         [ga.array(["POINT (0.5 1.5)"]), ga.array(["POINT (2.5 3.5)"])],

From 1f8f952175bbca8c117840fbf38635b65e8b28c8 Mon Sep 17 00:00:00 2001
From: Dewey Dunnington
Date: Mon, 21 Aug 2023 23:40:13 -0300
Subject: [PATCH 2/2] remove unused branch

---
 python/geoarrow/pyarrow/_dataset.py   | 5 +----
 python/tests/test_geoarrow_dataset.py | 1 -
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/python/geoarrow/pyarrow/_dataset.py b/python/geoarrow/pyarrow/_dataset.py
index ab9c2252..50a6e6cc 100644
--- a/python/geoarrow/pyarrow/_dataset.py
+++ b/python/geoarrow/pyarrow/_dataset.py
@@ -231,10 +231,7 @@ def _index_fragment(fragment, column, type):
         reader = scanner.to_reader()
         kernel = Kernel.box_agg(type)
         for batch in reader:
-            if isinstance(type, VectorType):
-                kernel.push(batch.column(0))
-            else:
-                kernel.push(type.wrap_array(batch.column(0)))
+            kernel.push(batch.column(0))
         return kernel.finish()

     @staticmethod
diff --git a/python/tests/test_geoarrow_dataset.py b/python/tests/test_geoarrow_dataset.py
index f76dec99..70ef03aa 100644
--- a/python/tests/test_geoarrow_dataset.py
+++ b/python/tests/test_geoarrow_dataset.py
@@ -68,7 +68,6 @@ def test_geodataset_in_memory_guessed_type():
     geods = ga.dataset([table1, table2], geometry_columns=["geometry"])

     filtered1 = geods.filter_fragments("POLYGON ((2 3, 3 3, 3 4, 2 4, 2 3))")
-    assert isinstance(filtered1, GeoDataset)
     assert filtered1.to_table().num_rows == 1
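A minimal usage sketch of what this series enables, mirroring the test added in PATCH 1/2. It assumes pyarrow and geoarrow.pyarrow are importable as in the test file; the filter is expected to keep only fragments whose bounding boxes intersect the target polygon.

import pyarrow as pa
import geoarrow.pyarrow as ga

# Geometry columns stored as plain storage arrays rather than
# geoarrow extension arrays (the "plain col" case from PATCH 1/2).
table1 = pa.table([ga.array(["POINT (0.5 1.5)"]).storage], ["geometry"])
table2 = pa.table([ga.array(["POINT (2.5 3.5)"]).storage], ["geometry"])

# Naming the geometry column explicitly lets the dataset guess its type
# and build the fragment index from the plain column.
geods = ga.dataset([table1, table2], geometry_columns=["geometry"])

# Only the second fragment's bounding box intersects this polygon,
# so a single row survives the filter.
filtered = geods.filter_fragments("POLYGON ((2 3, 3 3, 3 4, 2 4, 2 3))")
assert filtered.to_table().num_rows == 1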