Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Index: add intersection/union support #210

Merged
merged 4 commits into from
Mar 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions rtree/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,74 @@ def contains(
)
return self._get_ids(it, p_num_results.value)

def __and__(self, other: Index) -> Index:
"""Take the intersection of two Index objects.

:param other: another index
:return: a new index
:raises AssertionError: if self and other have different interleave or dimension
"""
assert self.interleaved == other.interleaved
assert self.properties.dimension == other.properties.dimension

i = 0
new_idx = Index(interleaved=self.interleaved, properties=self.properties)

# For each Item in self...
for item1 in self.intersection(self.bounds, objects=True):
if self.interleaved:
# For each Item in other that intersects...
for item2 in other.intersection(item1.bbox, objects=True):
# Compute the intersection bounding box
bbox = []
for j in range(len(item1.bbox)):
if j < len(item1.bbox) // 2:
bbox.append(max(item1.bbox[j], item2.bbox[j]))
else:
bbox.append(min(item1.bbox[j], item2.bbox[j]))

new_idx.insert(i, bbox, (item1.object, item2.object))
i += 1

else:
# For each Item in other that intersects...
for item2 in other.intersection(item1.bounds, objects=True):
# Compute the intersection bounding box
bounds = []
for j in range(len(item1.bounds)):
if j % 2 == 0:
bounds.append(max(item1.bounds[j], item2.bounds[j]))
else:
bounds.append(min(item1.bounds[j], item2.bounds[j]))

new_idx.insert(i, bounds, (item1.object, item2.object))
i += 1

return new_idx

def __or__(self, other: Index) -> Index:
"""Take the union of two Index objects.

:param other: another index
:return: a new index
:raises AssertionError: if self and other have different interleave or dimension
"""
assert self.interleaved == other.interleaved
assert self.properties.dimension == other.properties.dimension

new_idx = Index(interleaved=self.interleaved, properties=self.properties)

# For each index...
for old_idx in [self, other]:
# For each item...
for item in old_idx.intersection(old_idx.bounds, objects=True):
if self.interleaved:
new_idx.insert(item.id, item.bbox, item.object)
else:
new_idx.insert(item.id, item.bounds, item.object)

return new_idx

@overload
def intersection(self, coordinates: Any, objects: Literal[True]) -> Iterator[Item]:
...
Expand Down
114 changes: 114 additions & 0 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,120 @@ def test_double_insertion(self) -> None:
self.assertEqual([1, 1], list(idx.intersection((0, 0, 5, 5))))


class TestIndexIntersectionUnion:
@pytest.fixture(scope="class")
def index_a_interleaved(self) -> index.Index:
idx = index.Index(interleaved=True)
idx.insert(1, (3, 3, 5, 5), "a_1")
idx.insert(2, (4, 2, 6, 4), "a_2")
return idx

@pytest.fixture(scope="class")
def index_a_uninterleaved(self) -> index.Index:
idx = index.Index(interleaved=False)
idx.insert(1, (3, 5, 3, 5), "a_1")
idx.insert(2, (4, 6, 2, 4), "a_2")
return idx

@pytest.fixture(scope="class")
def index_b_interleaved(self) -> index.Index:
idx = index.Index(interleaved=True)
idx.insert(3, (2, 1, 7, 6), "b_3")
idx.insert(4, (8, 7, 9, 8), "b_4")
return idx

@pytest.fixture(scope="class")
def index_b_uninterleaved(self) -> index.Index:
idx = index.Index(interleaved=False)
idx.insert(3, (2, 7, 1, 6), "b_3")
idx.insert(4, (8, 9, 7, 8), "b_4")
return idx

def test_intersection_interleaved(
self, index_a_interleaved: index.Index, index_b_interleaved: index.Index
) -> None:
index_c_interleaved = index_a_interleaved & index_b_interleaved
assert index_c_interleaved.interleaved
assert len(index_c_interleaved) == 2
for hit in index_c_interleaved.intersection(
index_c_interleaved.bounds, objects=True
):
if hit.bbox == [3.0, 3.0, 5.0, 5.0]:
assert hit.object == ("a_1", "b_3")
elif hit.bbox == [4.0, 2.0, 6.0, 4.0]:
assert hit.object == ("a_2", "b_3")
else:
assert False

def test_intersection_uninterleaved(
self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index
) -> None:
index_c_uninterleaved = index_a_uninterleaved & index_b_uninterleaved
assert not index_c_uninterleaved.interleaved
assert len(index_c_uninterleaved) == 2
for hit in index_c_uninterleaved.intersection(
index_c_uninterleaved.bounds, objects=True
):
if hit.bounds == [3.0, 5.0, 3.0, 5.0]:
assert hit.object == ("a_1", "b_3")
elif hit.bounds == [4.0, 6.0, 2.0, 4.0]:
assert hit.object == ("a_2", "b_3")
else:
assert False

def test_intersection_mismatch(
self, index_a_interleaved: index.Index, index_b_uninterleaved: index.Index
) -> None:
with pytest.raises(AssertionError):
index_a_interleaved & index_b_uninterleaved

def test_union_interleaved(
self, index_a_interleaved: index.Index, index_b_interleaved: index.Index
) -> None:
index_c_interleaved = index_a_interleaved | index_b_interleaved
assert index_c_interleaved.interleaved
assert len(index_c_interleaved) == 4
for hit in index_c_interleaved.intersection(
index_c_interleaved.bounds, objects=True
):
if hit.bbox == [3.0, 3.0, 5.0, 5.0]:
assert hit.object == "a_1"
elif hit.bbox == [4.0, 2.0, 6.0, 4.0]:
assert hit.object == "a_2"
elif hit.bbox == [2.0, 1.0, 7.0, 6.0]:
assert hit.object == "b_3"
elif hit.bbox == [8.0, 7.0, 9.0, 8.0]:
assert hit.object == "b_4"
else:
assert False

def test_union_uninterleaved(
self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index
) -> None:
index_c_uninterleaved = index_a_uninterleaved | index_b_uninterleaved
assert not index_c_uninterleaved.interleaved
assert len(index_c_uninterleaved) == 4
for hit in index_c_uninterleaved.intersection(
index_c_uninterleaved.bounds, objects=True
):
if hit.bounds == [3.0, 5.0, 3.0, 5.0]:
assert hit.object == "a_1"
elif hit.bounds == [4.0, 6.0, 2.0, 4.0]:
assert hit.object == "a_2"
elif hit.bounds == [2.0, 7.0, 1.0, 6.0]:
assert hit.object == "b_3"
elif hit.bounds == [8.0, 9.0, 7.0, 8.0]:
assert hit.object == "b_4"
else:
assert False

def test_union_mismatch(
self, index_a_interleaved: index.Index, index_b_uninterleaved: index.Index
) -> None:
with pytest.raises(AssertionError):
index_a_interleaved | index_b_uninterleaved


class IndexSerialization(unittest.TestCase):
def setUp(self) -> None:
self.boxes15 = np.genfromtxt("boxes_15x15.data")
Expand Down