# Prototypes for the bulk ("_v") query entry points.
#
# Everything lives in a single try/except: if looking one of the symbols up
# on ``rt`` raises AttributeError, the error is swallowed and the remaining
# assignments in the group are skipped.
# NOTE(review): presumably this covers native libraries that do not export
# the bulk functions -- confirm which library versions provide them.
try:
    # 13 arguments: handle, two int64, one uint32, three uint64,
    # five raw buffer pointers and an int64 out-parameter.
    rt.Index_NearestNeighbors_id_v.argtypes = (
        [ctypes.c_void_p]
        + [ctypes.c_int64] * 2
        + [ctypes.c_uint32]
        + [ctypes.c_uint64] * 3
        + [ctypes.c_void_p] * 5
        + [ctypes.POINTER(ctypes.c_int64)]
    )
    rt.Index_NearestNeighbors_id_v.restype = ctypes.c_int
    rt.Index_NearestNeighbors_id_v.errcheck = check_return  # type: ignore

    # 11 arguments: handle, one int64, one uint32, three uint64,
    # four raw buffer pointers and an int64 out-parameter.
    rt.Index_Intersects_id_v.argtypes = (
        [ctypes.c_void_p]
        + [ctypes.c_int64]
        + [ctypes.c_uint32]
        + [ctypes.c_uint64] * 3
        + [ctypes.c_void_p] * 4
        + [ctypes.POINTER(ctypes.c_int64)]
    )
    rt.Index_Intersects_id_v.restype = ctypes.c_int
    rt.Index_Intersects_id_v.errcheck = check_return  # type: ignore
except AttributeError:
    pass
def intersection_v(self, mins, maxs):
    """Run a batch of intersection queries through one native entry point.

    :param mins: 2-D NumPy array of shape ``(n, d)``. It is cast to
        ``float64`` and handed to ``Index_Intersects_id_v`` together with
        its row and column strides (in elements).
        NOTE(review): presumably row ``i`` is the lower corner of query
        ``i`` -- confirm against the native API.
    :param maxs: NumPy array with the same shape and strides as ``mins``.
        NOTE(review): presumably the upper corners -- confirm.

    :returns: ``(ids, counts)``. ``counts`` is a ``uint64`` array of
        length ``n`` filled in by the native call, and ``ids`` is the
        ``int64`` id buffer truncated to ``counts.sum()`` entries.
    """
    import numpy as np

    assert mins.shape == maxs.shape
    assert mins.strides == maxs.strides

    # The native routine reads the buffers as float64.
    mins = mins.astype(np.float64)
    maxs = maxs.astype(np.float64)

    n, d = mins.shape

    # Strides measured in elements rather than bytes.
    d_i_stri = mins.strides[0] // mins.itemsize
    d_j_stri = mins.strides[1] // mins.itemsize

    # Initial guess for the id buffer; it is grown below when needed.
    ids = np.empty(2 * n, dtype=np.int64)
    counts = np.empty(n, dtype=np.uint64)
    nr = ctypes.c_int64(0)

    # offn: queries already answered; offi: id slots already filled.
    offn, offi = 0, 0

    while True:
        # Start every call from the same ``nr`` state as the first one so
        # a value left over from the previous round cannot be re-counted.
        nr.value = 0

        core.rt.Index_Intersects_id_v(
            self.handle,
            n - offn,
            d,
            # The id pointer below starts at ``offi``, so only the tail of
            # the buffer is available to this call.
            # NOTE(review): assumes this argument is the capacity of the
            # id buffer -- confirm against the native API.
            len(ids) - offi,
            d_i_stri,
            d_j_stri,
            mins[offn:].ctypes.data,
            maxs[offn:].ctypes.data,
            ids[offi:].ctypes.data,
            counts[offn:].ctypes.data,
            ctypes.byref(nr),
        )

        # If we got the expected number of results then return
        if nr.value == n - offn:
            return ids[: counts.sum()], counts

        # Otherwise the id buffer was too small: record the progress made
        # and retry the remaining queries with a bigger buffer. Plain
        # Python ints keep the offsets out of NumPy uint64 arithmetic.
        offi += int(counts[offn : offn + nr.value].sum())
        offn += nr.value

        # ndarray.resize works in place and returns None, so the result
        # must not be assigned back to ``ids``.
        ids.resize(2 * len(ids), refcheck=False)
def nearest_v(
    self, mins, maxs, num_results=1, strict=False, return_max_dists=False
):
    """Run a batch of nearest-neighbour queries through one native call.

    :param mins: 2-D NumPy array of shape ``(n, d)``. It is cast to
        ``float64`` and handed to ``Index_NearestNeighbors_id_v``
        together with its row and column strides (in elements).
        NOTE(review): presumably row ``i`` is the lower corner of query
        ``i`` -- confirm against the native API.
    :param maxs: NumPy array with the same shape and strides as ``mins``.
        NOTE(review): presumably the upper corners -- confirm.
    :param num_results: Value forwarded to the native routine; it also
        sizes the initial id buffer (``n * num_results`` entries).
    :param strict: When true, ``num_results`` is forwarded negated.
        NOTE(review): the meaning of a negative count is defined by the
        native library -- confirm.
    :param return_max_dists: When true, a length-``n`` float array is
        passed to the native routine and returned as a third element.

    :returns: ``(ids, counts)`` or, with ``return_max_dists``,
        ``(ids, counts, dists)``. ``counts`` is a ``uint64`` array of
        length ``n`` and ``ids`` is the ``int64`` id buffer truncated to
        ``counts.sum()`` entries.
    """
    import numpy as np

    assert mins.shape == maxs.shape
    assert mins.strides == maxs.strides

    # The native routine reads the buffers as float64.
    mins = mins.astype(np.float64)
    maxs = maxs.astype(np.float64)

    n, d = mins.shape

    # Strides measured in elements rather than bytes.
    d_i_stri = mins.strides[0] // mins.itemsize
    d_j_stri = mins.strides[1] // mins.itemsize

    # Initial guess for the id buffer; it is grown below when needed.
    ids = np.empty(n * num_results, dtype=np.int64)
    counts = np.empty(n, dtype=np.uint64)
    dists = np.empty(n) if return_max_dists else None
    nr = ctypes.c_int64(0)

    # offn: queries already answered; offi: id slots already filled.
    offn, offi = 0, 0

    while True:
        # Start every call from the same ``nr`` state as the first one so
        # a value left over from the previous round cannot be re-counted.
        nr.value = 0

        core.rt.Index_NearestNeighbors_id_v(
            self.handle,
            -num_results if strict else num_results,
            n - offn,
            d,
            # The id pointer below starts at ``offi``, so only the tail of
            # the buffer is available to this call.
            # NOTE(review): assumes this argument is the capacity of the
            # id buffer -- confirm against the native API.
            len(ids) - offi,
            d_i_stri,
            d_j_stri,
            mins[offn:].ctypes.data,
            maxs[offn:].ctypes.data,
            ids[offi:].ctypes.data,
            counts[offn:].ctypes.data,
            dists[offn:].ctypes.data if return_max_dists else None,
            ctypes.byref(nr),
        )

        # If we got the expected number of results then return
        if nr.value == n - offn:
            if return_max_dists:
                return ids[: counts.sum()], counts, dists
            return ids[: counts.sum()], counts

        # Otherwise the id buffer was too small: record the progress made
        # and retry the remaining queries with a bigger buffer. Plain
        # Python ints keep the offsets out of NumPy uint64 arithmetic.
        offi += int(counts[offn : offn + nr.value].sum())
        offn += nr.value

        # ndarray.resize works in place and returns None, so the result
        # must not be assigned back to ``ids``.
        ids.resize(2 * len(ids), refcheck=False)