Skip to content

Commit

Permalink
Add bulk query APIs (#342)
Browse files Browse the repository at this point in the history
* Reduce overhead.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Reduce overhead.

* Add support for array-based bulk insertion.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Consistency checks.

* Add support for bulk-query APIs.

* Formatting.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
FreddieWitherden and pre-commit-ci[bot] authored Dec 16, 2024
1 parent 7775fdd commit 2ca2ea1
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 0 deletions.
38 changes: 38 additions & 0 deletions rtree/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,44 @@ def free_error_msg_ptr(result, func, cargs):
rt.Index_NearestNeighbors_id.restype = ctypes.c_int
rt.Index_NearestNeighbors_id.errcheck = check_return # type: ignore

try:
rt.Index_NearestNeighbors_id_v.argtypes = [
ctypes.c_void_p,
ctypes.c_int64,
ctypes.c_int64,
ctypes.c_uint32,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_int64),
]
rt.Index_NearestNeighbors_id_v.restype = ctypes.c_int
rt.Index_NearestNeighbors_id_v.errcheck = check_return # type: ignore

rt.Index_Intersects_id_v.argtypes = [
ctypes.c_void_p,
ctypes.c_int64,
ctypes.c_uint32,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_int64),
]
rt.Index_Intersects_id_v.restype = ctypes.c_int
rt.Index_Intersects_id_v.errcheck = check_return # type: ignore
except AttributeError:
pass


rt.Index_GetLeaves.argtypes = [
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_uint32),
Expand Down
109 changes: 109 additions & 0 deletions rtree/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,108 @@ def nearest(

return self._get_ids(it, p_num_results.contents.value)

def intersection_v(self, mins, maxs):
import numpy as np

assert mins.shape == maxs.shape
assert mins.strides == maxs.strides

# Cast
mins = mins.astype(np.float64)
maxs = maxs.astype(np.float64)

# Extract counts
n, d = mins.shape

# Compute strides
d_i_stri = mins.strides[0] // mins.itemsize
d_j_stri = mins.strides[1] // mins.itemsize

ids = np.empty(2 * n, dtype=np.int64)
counts = np.empty(n, dtype=np.uint64)
nr = ctypes.c_int64(0)
offn, offi = 0, 0

while True:
core.rt.Index_Intersects_id_v(
self.handle,
n - offn,
d,
len(ids),
d_i_stri,
d_j_stri,
mins[offn:].ctypes.data,
maxs[offn:].ctypes.data,
ids[offi:].ctypes.data,
counts[offn:].ctypes.data,
ctypes.byref(nr),
)

# If we got the expected nuber of results then return
if nr.value == n - offn:
return ids[: counts.sum()], counts
# Otherwise, if our array is too small then resize
else:
offi += counts[offn : offn + nr.value].sum()
offn += nr.value

ids = ids.resize(2 * len(ids), refcheck=False)

def nearest_v(
self, mins, maxs, num_results=1, strict=False, return_max_dists=False
):
import numpy as np

assert mins.shape == maxs.shape
assert mins.strides == maxs.strides

# Cast
mins = mins.astype(np.float64)
maxs = maxs.astype(np.float64)

# Extract counts
n, d = mins.shape

# Compute strides
d_i_stri = mins.strides[0] // mins.itemsize
d_j_stri = mins.strides[1] // mins.itemsize

ids = np.empty(n * num_results, dtype=np.int64)
counts = np.empty(n, dtype=np.uint64)
dists = np.empty(n) if return_max_dists else None
nr = ctypes.c_int64(0)
offn, offi = 0, 0

while True:
core.rt.Index_NearestNeighbors_id_v(
self.handle,
num_results if not strict else -num_results,
n - offn,
d,
len(ids),
d_i_stri,
d_j_stri,
mins[offn:].ctypes.data,
maxs[offn:].ctypes.data,
ids[offi:].ctypes.data,
counts[offn:].ctypes.data,
dists[offn:].ctypes.data if return_max_dists else None,
ctypes.byref(nr),
)

# If we got the expected nuber of results then return
if nr.value == n - offn:
if return_max_dists:
return ids[: counts.sum()], counts, dists
else:
return ids[: counts.sum()], counts
# Otherwise, if our array is too small then resize
else:
offi += counts[offn : offn + nr.value].sum()
offn += nr.value

ids = ids.resize(2 * len(ids), refcheck=False)

def _nearestTP(self, coordinates, velocities, times, num_results=1, objects=False):
p_mins, p_maxs = self.get_coordinate_pointers(coordinates)
pv_mins, pv_maxs = self.get_coordinate_pointers(velocities)
Expand Down Expand Up @@ -1538,6 +1640,13 @@ def initialize_from_dict(self, state: dict[str, Any]) -> None:
if v is not None:
setattr(self, k, v)

# Consistency checks
if "near_minimum_overlap_factor" not in state:
nmof = self.near_minimum_overlap_factor
ilc = min(self.index_capacity, self.leaf_capacity)
if nmof >= ilc:
self.near_minimum_overlap_factor = ilc // 3 + 1

def __getstate__(self) -> dict[Any, Any]:
return self.as_dict()

Expand Down

0 comments on commit 2ca2ea1

Please sign in to comment.