Fix GeoDataset for non-extension type column (e.g., Overture) #44

Merged 2 commits on Aug 22, 2023
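For context (not part of the PR itself): "non-extension type column" here means a geometry column whose Arrow type carries no geoarrow extension metadata, such as the plain WKB/WKT geometry columns found in Overture data dumps. The sketch below reuses only calls already exercised by this repository's tests (ga.array, .storage, pa.table); treat the printed type names as illustrative.

import pyarrow as pa
import geoarrow.pyarrow as ga

arr = ga.array(["POINT (0.5 1.5)"])
print(arr.type)          # geoarrow extension type: the array knows it is a geometry
print(arr.storage.type)  # plain storage type (e.g. string), no geometry metadata

# A table built from the storage array looks like any other Arrow table, so a
# bounding-box kernel cannot be chosen from the fragment schema alone; the
# dataset has to be told (or guess) the geometry type for each column instead.
table = pa.table([arr.storage], ["geometry"])
print(table.schema.field("geometry").type)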
python/geoarrow/pyarrow/_dataset.py: 29 changes (18 additions, 11 deletions)
@@ -145,13 +145,18 @@ def index_fragments(self, num_threads=None):
         [{'_fragment_index': 0, 'geometry': {'xmin': 0.5, 'xmax': 0.5, 'ymin': 1.5, 'ymax': 1.5}}]
         """
         if self._index is None:
-            self._index = self._build_index(self.geometry_columns, num_threads)
+            self._index = self._build_index(
+                self.geometry_columns, self.geometry_types, num_threads
+            )

         return self._index

-    def _build_index(self, geometry_columns, num_threads=None):
+    def _build_index(self, geometry_columns, geometry_types, num_threads=None):
         return GeoDataset._index_fragments(
-            self.get_fragments(), geometry_columns, num_threads=num_threads
+            self.get_fragments(),
+            geometry_columns,
+            geometry_types,
+            num_threads=num_threads,
         )

     def filter_fragments(self, target):
@@ -221,16 +226,16 @@ def _filter_parent_fragment_indices(self, fragment_indices):
         return _ds.InMemoryDataset(tables, schema=self.schema)

     @staticmethod
-    def _index_fragment(fragment, column):
+    def _index_fragment(fragment, column, type):
         scanner = fragment.scanner(columns=[column])
         reader = scanner.to_reader()
-        kernel = Kernel.box_agg(reader.schema.types[0])
+        kernel = Kernel.box_agg(type)
         for batch in reader:
             kernel.push(batch.column(0))
         return kernel.finish()

     @staticmethod
-    def _index_fragments(fragments, columns, num_threads=None):
+    def _index_fragments(fragments, columns, types, num_threads=None):
         columns = list(columns)
         if num_threads is None:
             num_threads = _pa.cpu_count()
@@ -242,10 +247,10 @@ def _index_fragments(fragments, columns, num_threads=None):

         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             futures = []
-            for column in columns:
+            for column, type in zip(columns, types):
                 for fragment in fragments:
                     future = executor.submit(
-                        GeoDataset._index_fragment, fragment, column
+                        GeoDataset._index_fragment, fragment, column, type
                     )
                     futures.append(future)

@@ -355,13 +360,13 @@ def _wrap_parent(self, filtered_parent, fragment_indices):
         new_wrapped._index = base_wrapped._index
         return new_wrapped

-    def _build_index(self, geometry_columns, num_threads=None):
+    def _build_index(self, geometry_columns, geometry_types, num_threads=None):
         can_use_statistics = [
             type.coord_type == CoordType.SEPARATE for type in self.geometry_types
         ]

         if not self._use_column_statistics or not any(can_use_statistics):
-            return super()._build_index(geometry_columns, num_threads)
+            return super()._build_index(geometry_columns, geometry_types, num_threads)

         # Build a list of columns that will work with column stats
         bbox_stats_cols = []
@@ -380,7 +385,9 @@ def _build_index(self, geometry_columns, num_threads=None):
                 normal_stats_cols.remove(col)

         # Compute any remaining statistics
-        normal_stats = super()._build_index(normal_stats_cols, num_threads)
+        normal_stats = super()._build_index(
+            normal_stats_cols, geometry_types, num_threads
+        )
         for col in normal_stats_cols:
             stats_by_name[col] = normal_stats.column(col)

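To summarize the change above in isolation: the resolved geometry type now travels with its column into every per-fragment indexing task instead of being re-derived from each fragment's schema. The standalone sketch below uses hypothetical stand-ins (index_one, string fragments) rather than geoarrow internals; only the zip/ThreadPoolExecutor fan-out mirrors the diff.

from concurrent.futures import ThreadPoolExecutor

def index_one(fragment, column, geom_type):
    # Placeholder for GeoDataset._index_fragment(fragment, column, type):
    # with the type passed in, the kernel no longer depends on the fragment schema.
    return (fragment, column, geom_type)

columns = ["geometry"]       # one entry per geometry column
types = ["geoarrow.wkt"]     # the dataset's resolved/guessed type for that column
fragments = ["fragment-0", "fragment-1"]

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [
        executor.submit(index_one, fragment, column, geom_type)
        for column, geom_type in zip(columns, types)   # pair each column with its type
        for fragment in fragments                      # one task per (column, fragment)
    ]
    results = [future.result() for future in futures]

print(results)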
python/tests/test_geoarrow_dataset.py: 9 changes (9 additions, 0 deletions)
@@ -62,6 +62,15 @@ def test_geodataset_in_memory():
     ga.dataset([table1], use_row_groups=True)


+def test_geodataset_in_memory_guessed_type():
+    table1 = pa.table([ga.array(["POINT (0.5 1.5)"]).storage], ["geometry"])
+    table2 = pa.table([ga.array(["POINT (2.5 3.5)"]).storage], ["geometry"])
+    geods = ga.dataset([table1, table2], geometry_columns=["geometry"])
+
+    filtered1 = geods.filter_fragments("POLYGON ((2 3, 3 3, 3 4, 2 4, 2 3))")
+    assert filtered1.to_table().num_rows == 1
+
+
 def test_geodataset_multiple_geometry_columns():
     table1 = pa.table(
         [ga.array(["POINT (0.5 1.5)"]), ga.array(["POINT (2.5 3.5)"])],
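For contrast with the new test, the previously supported path: when the columns are genuine geoarrow extension arrays, the geometry columns can be discovered from the schema and geometry_columns does not need to be passed. The sketch below mirrors the surrounding tests; the exact inference behaviour is an assumption, not something shown in this diff.

import pyarrow as pa
import geoarrow.pyarrow as ga

# Extension-typed columns: the schema itself identifies the geometry column.
table1 = pa.table([ga.array(["POINT (0.5 1.5)"])], ["geometry"])
table2 = pa.table([ga.array(["POINT (2.5 3.5)"])], ["geometry"])

geods = ga.dataset([table1, table2])  # no geometry_columns= needed (assumed)
filtered = geods.filter_fragments("POLYGON ((2 3, 3 3, 3 4, 2 4, 2 3))")
assert filtered.to_table().num_rows == 1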