Skip to content

Commit

Permalink
Make IDViews respect order (#482)
Browse files Browse the repository at this point in the history
* make views respect order

* Fix diviews

* fix issue

* improved tests
  • Loading branch information
nwlandry authored Oct 26, 2023
1 parent 84187d6 commit 5e194bc
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
4 changes: 3 additions & 1 deletion tests/core/test_diviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,15 @@ def test_view_len(diedgelist2):

def test_bunch_view(diedgelist2):
H = xgi.DiHypergraph(diedgelist2)
bunch_view = H.edges.from_view(H.edges, bunch=[1, 2])
bunch_view = H.edges.from_view(H.edges, bunch=[2, 1])
assert len(bunch_view) == 2
assert (1 in bunch_view) and (2 in bunch_view)
assert 0 not in bunch_view
assert bunch_view.members(dtype=dict) == {1: {1, 2, 4}, 2: {2, 3, 4, 5}}
with pytest.raises(IDNotFound):
bunch_view.members(0)
# test ID order
assert list(bunch_view) == [1, 2]


def test_call_wrong_bunch():
Expand Down
4 changes: 3 additions & 1 deletion tests/core/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,16 @@ def test_view_len(edgelist2):

def test_bunch_view(edgelist1):
H = xgi.Hypergraph(edgelist1)
bunch_view = H.edges.from_view(H.edges, bunch=[1, 2])
bunch_view = H.edges.from_view(H.edges, bunch=[2, 1])
assert len(bunch_view) == 2
assert (1 in bunch_view) and (2 in bunch_view)
assert 0 not in bunch_view
assert bunch_view.members(dtype=dict) == {1: {4}, 2: {5, 6}}
with pytest.raises(IDNotFound):
bunch_view.members(0)

assert list(bunch_view) == [1, 2]


def test_call_wrong_bunch():
H = xgi.Hypergraph()
Expand Down
2 changes: 1 addition & 1 deletion xgi/core/diviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def from_view(cls, view, bunch=None):
wrong = bunch - all_ids
if wrong:
raise IDNotFound(f"IDs {wrong} not in the hypergraph")
newview._ids = bunch
newview._ids = [i for i in view._in_id_dict if i in bunch]
return newview

def _from_iterable(self, it):
Expand Down
2 changes: 1 addition & 1 deletion xgi/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def from_view(cls, view, bunch=None):
wrong = bunch - all_ids
if wrong:
raise IDNotFound(f"IDs {wrong} not in the hypergraph")
newview._ids = bunch
newview._ids = [i for i in view._id_dict if i in bunch]
return newview

def _from_iterable(self, it):
Expand Down

0 comments on commit 5e194bc

Please sign in to comment.