Skip to content

Commit

Permalink
Add method to return edge indices from endpoints (#1055)
Browse files Browse the repository at this point in the history
* add edge_indices_from_endpoints for graph and digraph

* fix tests

* Add release notes

* linter

* black

* Fix text_signature for new method

---------

Co-authored-by: Matthew Treinish <[email protected]>
  • Loading branch information
eendebakpt and mtreinish authored Jan 19, 2024
1 parent 27c88f6 commit e8e380b
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added method :meth:`~rustworkx.PyGraph.edge_indices_from_endpoints` which returns the indices of all edges
between the specified endpoints. For :class:`~rustworkx.PyDiGraph` there is a corresponding method that returns the
directed edges.
1 change: 1 addition & 0 deletions rustworkx/digraph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class PyDiGraph(Generic[S, T]):
def copy(self) -> PyDiGraph[S, T]: ...
def edge_index_map(self) -> EdgeIndexMap[T]: ...
def edge_indices(self) -> EdgeIndices: ...
def edge_indices_from_endpoints(self, node_a: int, node_b: int) -> EdgeIndices: ...
def edge_list(self) -> EdgeList: ...
def edges(self) -> list[T]: ...
def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyDiGraph[S, T]: ...
Expand Down
1 change: 1 addition & 0 deletions rustworkx/graph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class PyGraph(Generic[S, T]):
def degree(self, node: int, /) -> int: ...
def edge_index_map(self) -> EdgeIndexMap[T]: ...
def edge_indices(self) -> EdgeIndices: ...
def edge_indices_from_endpoints(self, node_a: int, node_b: int) -> EdgeIndices: ...
def edge_list(self) -> EdgeList: ...
def edges(self) -> list[T]: ...
def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyGraph[S, T]: ...
Expand Down
18 changes: 18 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,24 @@ impl PyDiGraph {
}
}

/// Return a list of indices of all directed edges between specified nodes
///
/// :returns: A list of all the edge indices connecting the specified start and end node
/// :rtype: EdgeIndices
pub fn edge_indices_from_endpoints(&self, node_a: usize, node_b: usize) -> EdgeIndices {
let node_a_index = NodeIndex::new(node_a);
let node_b_index = NodeIndex::new(node_b);

EdgeIndices {
edges: self
.graph
.edges_directed(node_a_index, petgraph::Direction::Outgoing)
.filter(|edge| edge.target() == node_b_index)
.map(|edge| edge.id().index())
.collect(),
}
}

/// Return a list of all node data.
///
/// :returns: A list of all the node data objects in the graph
Expand Down
18 changes: 18 additions & 0 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,24 @@ impl PyGraph {
}
}

/// Return a list of indices of all edges between specified nodes
///
/// :returns: A list of all the edge indices connecting the specified start and end node
/// :rtype: EdgeIndices
pub fn edge_indices_from_endpoints(&self, node_a: usize, node_b: usize) -> EdgeIndices {
let node_a_index = NodeIndex::new(node_a);
let node_b_index = NodeIndex::new(node_b);

EdgeIndices {
edges: self
.graph
.edges_directed(node_a_index, petgraph::Direction::Outgoing)
.filter(|edge| edge.target() == node_b_index)
.map(|edge| edge.id().index())
.collect(),
}
}

/// Return a list of all node data.
///
/// :returns: A list of all the node data objects in the graph
Expand Down
19 changes: 19 additions & 0 deletions tests/rustworkx_tests/digraph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,25 @@ def test_weighted_edge_list_empty(self):
dag = rustworkx.PyDiGraph()
self.assertEqual([], dag.weighted_edge_list())

def test_edge_indices_from_endpoints(self):
dag = rustworkx.PyDiGraph()
dag.add_nodes_from(list(range(4)))
edge_list = [
(0, 1, None),
(1, 2, None),
(0, 2, None),
(2, 3, None),
(0, 3, None),
(0, 2, None),
]
dag.add_edges_from(edge_list)
indices = dag.edge_indices_from_endpoints(0, 0)
self.assertEqual(indices, [])
indices = dag.edge_indices_from_endpoints(0, 1)
self.assertEqual(indices, [0])
indices = dag.edge_indices_from_endpoints(0, 2)
self.assertEqual(set(indices), {2, 5})

def test_extend_from_edge_list(self):
dag = rustworkx.PyDAG()
edge_list = [(0, 1), (1, 2), (0, 2), (2, 3), (0, 3)]
Expand Down
20 changes: 20 additions & 0 deletions tests/rustworkx_tests/graph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,26 @@ def test_weighted_edge_list_empty(self):
graph = rustworkx.PyGraph()
self.assertEqual([], graph.weighted_edge_list())

def test_edge_indices_from_endpoints(self):
dag = rustworkx.PyGraph()
dag.add_nodes_from(list(range(4)))
edge_list = [
(0, 1, None),
(1, 2, None),
(0, 2, None),
(2, 3, None),
(0, 3, None),
(0, 2, None),
(2, 0, None),
]
dag.add_edges_from(edge_list)
indices = dag.edge_indices_from_endpoints(0, 0)
self.assertEqual(indices, [])
indices = dag.edge_indices_from_endpoints(0, 1)
self.assertEqual(set(indices), {0})
indices = dag.edge_indices_from_endpoints(0, 2)
self.assertEqual(set(indices), {2, 5, 6})

def test_extend_from_edge_list(self):
graph = rustworkx.PyGraph()
edge_list = [(0, 1), (1, 2), (0, 2), (2, 3), (0, 3)]
Expand Down

0 comments on commit e8e380b

Please sign in to comment.