From 0c68607fa9d51b9a7e9e8eb3e281c3de7560072a Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 5 Feb 2022 11:40:44 -0600 Subject: [PATCH 1/4] Index/Item: add intersection/union support --- rtree/index.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/rtree/index.py b/rtree/index.py index dc614cb8..8c55e1f6 100644 --- a/rtree/index.py +++ b/rtree/index.py @@ -663,6 +663,31 @@ def contains( ) return self._get_ids(it, p_num_results.value) + def __and__(self, other: Index) -> Index: + """Take the intersection of two Index objects. + + :param other: another index + :return: a new index + """ + new_idx = Index(properties=self.properties) + for item1 in self.intersection(self.bounds, objects=True): + for item2 in other.intersection(item1.bounds, objects=True): + item3 = item1 & item2 + new_idx.insert(item3.id, item3.bounds, item3.object) + return new_idx + + def __or__(self, other: Index) -> Index: + """Take the union of two Index objects. + + :param other: another index + :return: a new index + """ + new_idx = Index(properties=self.properties) + for old_idx in [self, other]: + for item in old_idx.intersection(old_idx.bounds, objects=True): + new_idx.insert(item.id, item.bounds, item.object) + return new_idx + @overload def intersection(self, coordinates: Any, objects: Literal[True]) -> Iterator[Item]: ... From c27cef21bfd00380fd24fb0267de5916d4b80246 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 23 Feb 2022 16:30:33 -0600 Subject: [PATCH 2/4] Compute intersection bounding box --- rtree/index.py | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/rtree/index.py b/rtree/index.py index 8c55e1f6..c124a51a 100644 --- a/rtree/index.py +++ b/rtree/index.py @@ -668,12 +668,35 @@ def __and__(self, other: Index) -> Index: :param other: another index :return: a new index + :raises AssertionError: if self and other have different interleave or dimension """ - new_idx = Index(properties=self.properties) + assert self.interleaved == other.interleaved + assert self.properties.dimension == other.properties.dimension + + i = 0 + new_idx = Index(interleaved=self.interleaved, properties=self.properties) + + # For each Item in self... for item1 in self.intersection(self.bounds, objects=True): + # For each Item in other that intersects... for item2 in other.intersection(item1.bounds, objects=True): - item3 = item1 & item2 - new_idx.insert(item3.id, item3.bounds, item3.object) + # Compute the intersection bounding box + bounds = [] + for j in range(len(item1.bounds)): + if self.interleaved: + if j < len(item1.bounds) // 2: + bounds.append(max(item1.bounds[j], item2.bounds[j])) + else: + bounds.append(min(item1.bounds[j], item2.bounds[j])) + else: + if j % 2 == 0: + bounds.append(max(item1.bounds[j], item2.bounds[j])) + else: + bounds.append(min(item1.bounds[j], item2.bounds[j])) + + new_idx.insert(i, bounds, (item1.object, item2.object)) + i += 1 + return new_idx def __or__(self, other: Index) -> Index: @@ -681,11 +704,19 @@ def __or__(self, other: Index) -> Index: :param other: another index :return: a new index + :raises AssertionError: if self and other have different interleave or dimension """ - new_idx = Index(properties=self.properties) + assert self.interleaved == other.interleaved + assert self.properties.dimension == other.properties.dimension + + new_idx = Index(interleaved=self.interleaved, properties=self.properties) + + # For each index... for old_idx in [self, other]: + # For each item... for item in old_idx.intersection(old_idx.bounds, objects=True): new_idx.insert(item.id, item.bounds, item.object) + return new_idx @overload From d3991655b56bff3a07185d7a3dada9ffaf8f60a9 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 4 Mar 2022 15:57:27 -0600 Subject: [PATCH 3/4] Add tests, fix bugs --- rtree/index.py | 44 +++++++++++------ tests/test_index.py | 114 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+), 15 deletions(-) diff --git a/rtree/index.py b/rtree/index.py index c124a51a..f96fb8e5 100644 --- a/rtree/index.py +++ b/rtree/index.py @@ -676,26 +676,37 @@ def __and__(self, other: Index) -> Index: i = 0 new_idx = Index(interleaved=self.interleaved, properties=self.properties) - # For each Item in self... - for item1 in self.intersection(self.bounds, objects=True): - # For each Item in other that intersects... - for item2 in other.intersection(item1.bounds, objects=True): - # Compute the intersection bounding box - bounds = [] - for j in range(len(item1.bounds)): - if self.interleaved: - if j < len(item1.bounds) // 2: - bounds.append(max(item1.bounds[j], item2.bounds[j])) + if self.interleaved: + # For each Item in self... + for item1 in self.intersection(self.bounds, objects=True): + # For each Item in other that intersects... + for item2 in other.intersection(item1.bbox, objects=True): + # Compute the intersection bounding box + bbox = [] + for j in range(len(item1.bbox)): + if j < len(item1.bbox) // 2: + bbox.append(max(item1.bbox[j], item2.bbox[j])) else: - bounds.append(min(item1.bounds[j], item2.bounds[j])) - else: + bbox.append(min(item1.bbox[j], item2.bbox[j])) + + new_idx.insert(i, bbox, (item1.object, item2.object)) + i += 1 + + else: + # For each Item in self... + for item1 in self.intersection(self.bounds, objects=True): + # For each Item in other that intersects... + for item2 in other.intersection(item1.bounds, objects=True): + # Compute the intersection bounding box + bounds = [] + for j in range(len(item1.bounds)): if j % 2 == 0: bounds.append(max(item1.bounds[j], item2.bounds[j])) else: bounds.append(min(item1.bounds[j], item2.bounds[j])) - new_idx.insert(i, bounds, (item1.object, item2.object)) - i += 1 + new_idx.insert(i, bounds, (item1.object, item2.object)) + i += 1 return new_idx @@ -715,7 +726,10 @@ def __or__(self, other: Index) -> Index: for old_idx in [self, other]: # For each item... for item in old_idx.intersection(old_idx.bounds, objects=True): - new_idx.insert(item.id, item.bounds, item.object) + if self.interleaved: + new_idx.insert(item.id, item.bbox, item.object) + else: + new_idx.insert(item.id, item.bounds, item.object) return new_idx diff --git a/tests/test_index.py b/tests/test_index.py index 943b6828..e89d7b08 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -264,6 +264,120 @@ def test_double_insertion(self) -> None: self.assertEqual([1, 1], list(idx.intersection((0, 0, 5, 5)))) +class TestIndexIntersectionUnion: + @pytest.fixture(scope="class") + def index_a_interleaved(self) -> index.Index: + idx = index.Index(interleaved=True) + idx.insert(1, (3, 3, 5, 5), "a_1") + idx.insert(2, (4, 2, 6, 4), "a_2") + return idx + + @pytest.fixture(scope="class") + def index_a_uninterleaved(self) -> index.Index: + idx = index.Index(interleaved=False) + idx.insert(1, (3, 5, 3, 5), "a_1") + idx.insert(2, (4, 6, 2, 4), "a_2") + return idx + + @pytest.fixture(scope="class") + def index_b_interleaved(self) -> index.Index: + idx = index.Index(interleaved=True) + idx.insert(3, (2, 1, 7, 6), "b_3") + idx.insert(4, (8, 7, 9, 8), "b_4") + return idx + + @pytest.fixture(scope="class") + def index_b_uninterleaved(self) -> index.Index: + idx = index.Index(interleaved=False) + idx.insert(3, (2, 7, 1, 6), "b_3") + idx.insert(4, (8, 9, 7, 8), "b_4") + return idx + + def test_intersection_interleaved( + self, index_a_interleaved: index.Index, index_b_interleaved: index.Index + ) -> None: + index_c_interleaved = index_a_interleaved & index_b_interleaved + assert index_c_interleaved.interleaved + assert len(index_c_interleaved) == 2 + for hit in index_c_interleaved.intersection( + index_c_interleaved.bounds, objects=True + ): + if hit.bbox == [3.0, 3.0, 5.0, 5.0]: + assert hit.object == ("a_1", "b_3") + elif hit.bbox == [4.0, 2.0, 6.0, 4.0]: + assert hit.object == ("a_2", "b_3") + else: + assert False + + def test_intersection_uninterleaved( + self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index + ) -> None: + index_c_uninterleaved = index_a_uninterleaved & index_b_uninterleaved + assert not index_c_uninterleaved.interleaved + assert len(index_c_uninterleaved) == 2 + for hit in index_c_uninterleaved.intersection( + index_c_uninterleaved.bounds, objects=True + ): + if hit.bounds == [3.0, 5.0, 3.0, 5.0]: + assert hit.object == ("a_1", "b_3") + elif hit.bounds == [4.0, 6.0, 2.0, 4.0]: + assert hit.object == ("a_2", "b_3") + else: + assert False + + def test_intersection_mismatch( + self, index_a_interleaved: index.Index, index_b_uninterleaved: index.Index + ) -> None: + with pytest.raises(AssertionError): + index_a_interleaved & index_b_uninterleaved + + def test_union_interleaved( + self, index_a_interleaved: index.Index, index_b_interleaved: index.Index + ) -> None: + index_c_interleaved = index_a_interleaved | index_b_interleaved + assert index_c_interleaved.interleaved + assert len(index_c_interleaved) == 4 + for hit in index_c_interleaved.intersection( + index_c_interleaved.bounds, objects=True + ): + if hit.bbox == [3.0, 3.0, 5.0, 5.0]: + assert hit.object == "a_1" + elif hit.bbox == [4.0, 2.0, 6.0, 4.0]: + assert hit.object == "a_2" + elif hit.bbox == [2.0, 1.0, 7.0, 6.0]: + assert hit.object == "b_3" + elif hit.bbox == [8.0, 7.0, 9.0, 8.0]: + assert hit.object == "b_4" + else: + assert False + + def test_union_uninterleaved( + self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index + ) -> None: + index_c_uninterleaved = index_a_uninterleaved | index_b_uninterleaved + assert not index_c_uninterleaved.interleaved + assert len(index_c_uninterleaved) == 4 + for hit in index_c_uninterleaved.intersection( + index_c_uninterleaved.bounds, objects=True + ): + if hit.bounds == [3.0, 5.0, 3.0, 5.0]: + assert hit.object == "a_1" + elif hit.bounds == [4.0, 6.0, 2.0, 4.0]: + assert hit.object == "a_2" + elif hit.bounds == [2.0, 7.0, 1.0, 6.0]: + assert hit.object == "b_3" + elif hit.bounds == [8.0, 9.0, 7.0, 8.0]: + assert hit.object == "b_4" + else: + assert False + + def test_union_mismatch( + self, index_a_interleaved: index.Index, index_b_uninterleaved: index.Index + ) -> None: + with pytest.raises(AssertionError): + index_a_interleaved | index_b_uninterleaved + + class IndexSerialization(unittest.TestCase): def setUp(self) -> None: self.boxes15 = np.genfromtxt("boxes_15x15.data") From bcbb3b2804f08f2a48929141569d21f64166b4f6 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 4 Mar 2022 15:58:41 -0600 Subject: [PATCH 4/4] Simplify code --- rtree/index.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/rtree/index.py b/rtree/index.py index f96fb8e5..e2fc9de8 100644 --- a/rtree/index.py +++ b/rtree/index.py @@ -676,9 +676,9 @@ def __and__(self, other: Index) -> Index: i = 0 new_idx = Index(interleaved=self.interleaved, properties=self.properties) - if self.interleaved: - # For each Item in self... - for item1 in self.intersection(self.bounds, objects=True): + # For each Item in self... + for item1 in self.intersection(self.bounds, objects=True): + if self.interleaved: # For each Item in other that intersects... for item2 in other.intersection(item1.bbox, objects=True): # Compute the intersection bounding box @@ -692,9 +692,7 @@ def __and__(self, other: Index) -> Index: new_idx.insert(i, bbox, (item1.object, item2.object)) i += 1 - else: - # For each Item in self... - for item1 in self.intersection(self.bounds, objects=True): + else: # For each Item in other that intersects... for item2 in other.intersection(item1.bounds, objects=True): # Compute the intersection bounding box