From dc16084ef9a9f93616e54280a719d9cf46a23f70 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Thu, 20 Oct 2022 12:50:54 -0400 Subject: [PATCH 1/2] Add transitive_closure_dag function This commit adds a new function transitive_closure_dag() which is an optimized method for computing the transitive closure for DAGs. In support of this, a new function descendants_at_distance(), for finding the nodes a fixed distance from a given source, is added to both rustworkx and rustworkx-core. Related to: #704 --- docs/source/api.rst | 4 ++ ...ansitive-closure-dag-3fb45113d552f007.yaml | 13 ++++ rustworkx-core/src/traversal/descendants.rs | 44 +++++++++++++ rustworkx-core/src/traversal/mod.rs | 2 + rustworkx/__init__.py | 34 ++++++++++ src/dag_algo/mod.rs | 49 ++++++++++++++ src/lib.rs | 3 + src/traversal/mod.rs | 66 ++++++++++++++++++- .../digraph/test_transitive_closure.py | 33 ++++++++++ 9 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml create mode 100644 rustworkx-core/src/traversal/descendants.rs create mode 100644 tests/rustworkx_tests/digraph/test_transitive_closure.py diff --git a/docs/source/api.rst b/docs/source/api.rst index 8368188c2..3c6853736 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -79,6 +79,7 @@ Traversal rustworkx.visit.BFSVisitor rustworkx.visit.DijkstraVisitor rustworkx.TopologicalSorter + rustworkx.descendants_at_distance .. _dag-algorithms: @@ -94,6 +95,7 @@ DAG Algorithms rustworkx.dag_weighted_longest_path_length rustworkx.is_directed_acyclic_graph rustworkx.layers + rustworkx.transitive_closure_dag .. _tree: @@ -325,6 +327,7 @@ the functions from the explicitly typed based on the data type. rustworkx.digraph_bfs_search rustworkx.digraph_dijkstra_search rustworkx.digraph_node_link_json + rustworkx.digraph_descendants_at_distance .. _api-functions-pygraph: @@ -379,6 +382,7 @@ typed API based on the data type. 
rustworkx.graph_bfs_search rustworkx.graph_dijkstra_search rustworkx.graph_node_link_json + rustworkx.graph_descendants_at_distance Exceptions ========== diff --git a/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml b/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml new file mode 100644 index 000000000..cc5099212 --- /dev/null +++ b/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml @@ -0,0 +1,13 @@ +--- +features: + - | + Added a new function ``descendants_at_distance`` to the rustworkx-core + crate under the ``traversal`` module + - | + Added a new function, :func:`~.transitive_closure_dag`, which provides + an optimize method for computing the transitive closure of an input + DAG. + - | + Added a new function :func:`~.descendants_at_distance` which provides + a method to find the nodes at a fixed distance from a source in + a graph object. diff --git a/rustworkx-core/src/traversal/descendants.rs b/rustworkx-core/src/traversal/descendants.rs new file mode 100644 index 000000000..67bf6bf60 --- /dev/null +++ b/rustworkx-core/src/traversal/descendants.rs @@ -0,0 +1,44 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use hashbrown::HashSet; +use petgraph::visit::{IntoNeighborsDirected, NodeCount, Visitable}; + +/// Returns all nodes at a fixed `distance` from `source` in `G`. 
+/// Args: +/// `graph`: +/// `source`: +/// `distance`: +pub fn descendants_at_distance(graph: G, source: G::NodeId, distance: usize) -> Vec +where + G: Visitable + IntoNeighborsDirected + NodeCount, + G::NodeId: std::cmp::Eq + std::hash::Hash, +{ + let mut current_layer: Vec = vec![source]; + let mut layers: usize = 0; + let mut visited: HashSet = HashSet::with_capacity(graph.node_count()); + visited.insert(source); + while !current_layer.is_empty() && layers < distance { + let mut next_layer: Vec = Vec::new(); + for node in current_layer { + for child in graph.neighbors_directed(node, petgraph::Outgoing) { + if !visited.contains(&child) { + visited.insert(child); + next_layer.push(child); + } + } + } + current_layer = next_layer; + layers += 1; + } + current_layer +} diff --git a/rustworkx-core/src/traversal/mod.rs b/rustworkx-core/src/traversal/mod.rs index 2a6254481..b531075da 100644 --- a/rustworkx-core/src/traversal/mod.rs +++ b/rustworkx-core/src/traversal/mod.rs @@ -13,11 +13,13 @@ //! Module for graph traversal algorithms. 
mod bfs_visit; +mod descendants; mod dfs_edges; mod dfs_visit; mod dijkstra_visit; pub use bfs_visit::{breadth_first_search, BfsEvent}; +pub use descendants::descendants_at_distance; pub use dfs_edges::dfs_edges; pub use dfs_visit::{depth_first_search, DfsEvent}; pub use dijkstra_visit::{dijkstra_search, DijkstraEvent}; diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 6d82d41f2..313d2831c 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -2382,3 +2382,37 @@ def _graph_node_link_json(graph, path=None, graph_attrs=None, node_attrs=None, e return graph_node_link_json( graph, path=path, graph_attrs=graph_attrs, node_attrs=node_attrs, edge_attrs=edge_attrs ) + + +@functools.singledispatch +def descendants_at_distance(graph, source, distance): + """Returns all nodes at a fixed distance from ``source`` in ``graph`` + + :param graph: The graph to find the descendants in + :param int source: The node index to find the descendants from + :param int distance: The distance from ``source`` + + :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``. 
+ :rtype: NodeIndices + + For example:: + + import rustworkx as rx + + graph = rx.generators.path_graph(5) + res = rx.descendants_at_distance(graph, 2, 2) + print(res) + + will return: ``[0, 4]`` + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@descendants_at_distance.register(PyDiGraph) +def _digraph_descendants_at_distance(graph, source, distance): + return digraph_descendants_at_distance(graph, source, distance) + + +@descendants_at_distance.register(PyGraph) +def _graph_descendants_at_distance(graph, source, distance): + return graph_descendants_at_distance(graph, source, distance) diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index dbfcf4702..bf6e0f2a3 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -29,6 +29,8 @@ use petgraph::graph::NodeIndex; use petgraph::prelude::*; use petgraph::visit::NodeCount; +use rustworkx_core::traversal::descendants_at_distance; + /// Find the longest path in a DAG /// /// :param PyDiGraph graph: The graph to find the longest path on. The input @@ -634,3 +636,50 @@ pub fn collect_bicolor_runs( Ok(block_list) } + +/// Return the transitive closure of a directed acyclic graph +/// +/// The transitive closure of :math:`G = (V, E)` is a graph :math:`G+ = (V, E+)` +/// such that for all pairs of :math:`v, w` in :math:`V` there is an edge +/// :math:`(v, w) in :math:`E+` if and only if there is a non-null path from +/// :math:`v` to :math:`w` in :math:`G`. +/// +/// :param PyDiGraph graph: The input DAG to compute the transitive closure of +/// :param list topological_order: An optional topological order for ``graph`` +/// which represents the order the graph will be traversed in computing +/// the transitive closure. If one is not provided (or it is explicitly +/// set to ``None``) a topological order will be computed by this function. 
+/// +/// :returns: The transitive closure of ``graph`` +/// :rtype: PyDiGraph +/// +/// :raises DAGHasCycle: If the input ``graph`` is not acyclic +#[pyfunction] +#[pyo3(text_signature = "(graph, / topological_order=None)")] +pub fn transitive_closure_dag( + py: Python, + graph: &digraph::PyDiGraph, + topological_order: Option>, +) -> PyResult { + let node_order: Vec = match topological_order { + Some(topo_order) => topo_order.into_iter().map(NodeIndex::new).collect(), + None => match algo::toposort(&graph.graph, None) { + Ok(nodes) => nodes, + Err(_err) => return Err(DAGHasCycle::new_err("Topological Sort encountered a cycle")), + }, + }; + let mut out_graph = graph.graph.clone(); + for node in node_order.into_iter().rev() { + for descendant in descendants_at_distance(&out_graph, node, 2) { + out_graph.add_edge(node, descendant, py.None()); + } + } + Ok(digraph::PyDiGraph { + graph: out_graph, + cycle_state: algo::DfsSpace::default(), + check_cycle: false, + node_removed: false, + multigraph: true, + attrs: py.None(), + }) +} diff --git a/src/lib.rs b/src/lib.rs index a9ad4f4dd..ad449de46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -477,6 +477,9 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(read_graphml))?; m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?; + m.add_wrapped(wrap_pyfunction!(transitive_closure_dag))?; + m.add_wrapped(wrap_pyfunction!(graph_descendants_at_distance))?; + m.add_wrapped(wrap_pyfunction!(digraph_descendants_at_distance))?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/traversal/mod.rs b/src/traversal/mod.rs index 8f27cad86..0709322bb 100644 --- a/src/traversal/mod.rs +++ b/src/traversal/mod.rs @@ -19,7 +19,7 @@ use dfs_visit::{dfs_handler, PyDfsVisitor}; use dijkstra_visit::{dijkstra_handler, PyDijkstraVisitor}; use rustworkx_core::traversal::{ - breadth_first_search, depth_first_search, dfs_edges, 
dijkstra_search, + breadth_first_search, depth_first_search, descendants_at_distance, dfs_edges, dijkstra_search, }; use super::{digraph, graph, iterators, CostFn}; @@ -707,3 +707,67 @@ pub fn graph_dijkstra_search( Ok(()) } + +/// Returns all nodes at a fixed distance from ``source`` in ``graph`` +/// +/// :param PyGraph graph: The graph to find the descendants in +/// :param int source: The node index to find the descendants from +/// :param int distance: The distance from ``source`` +/// +/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``. +/// :rtype: NodeIndices +/// For example:: +/// +/// import rustworkx as rx +/// +/// graph = rx.generators.path_graph(5) +/// res = rx.descendants_at_distance(graph, 2, 2) +/// print(res) +/// +/// will return: ``[0, 4]`` +#[pyfunction] +pub fn graph_descendants_at_distance( + graph: graph::PyGraph, + source: usize, + distance: usize, +) -> iterators::NodeIndices { + let source = NodeIndex::new(source); + iterators::NodeIndices { + nodes: descendants_at_distance(&graph.graph, source, distance) + .into_iter() + .map(|x| x.index()) + .collect(), + } +} + +/// Returns all nodes at a fixed distance from ``source`` in ``graph`` +/// +/// :param PyDiGraph graph: The graph to find the descendants in +/// :param int source: The node index to find the descendants from +/// :param int distance: The distance from ``source`` +/// +/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``. 
+/// :rtype: NodeIndices +/// For example:: +/// +/// import rustworkx as rx +/// +/// graph = rx.generators.directed_path_graph(5) +/// res = rx.descendants_at_distance(graph, 2, 2) +/// print(res) +/// +/// will return: ``[4]`` +#[pyfunction] +pub fn digraph_descendants_at_distance( + graph: digraph::PyDiGraph, + source: usize, + distance: usize, +) -> iterators::NodeIndices { + let source = NodeIndex::new(source); + iterators::NodeIndices { + nodes: descendants_at_distance(&graph.graph, source, distance) + .into_iter() + .map(|x| x.index()) + .collect(), + } +} diff --git a/tests/rustworkx_tests/digraph/test_transitive_closure.py b/tests/rustworkx_tests/digraph/test_transitive_closure.py new file mode 100644 index 000000000..cf707bf55 --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_transitive_closure.py @@ -0,0 +1,33 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import unittest + +import rustworkx as rx + + +class TestTransitivity(unittest.TestCase): + def test_path_graph(self): + graph = rx.generators.directed_path_graph(4) + transitive_closure = rx.transitive_closure_dag(graph) + expected_edge_list = [(0, 1), (1, 2), (2, 3), (1, 3), (0, 3), (0, 2)] + self.assertEqual(transitive_closure.edge_list(), expected_edge_list) + + def test_invalid_type(self): + with self.assertRaises(TypeError): + rx.transitive_closure_dag(rx.PyGraph()) + + def test_cycle_error(self): + graph = rx.PyDiGraph() + graph.extend_from_edge_list([(0, 1), (1, 0)]) + with self.assertRaises(rx.DAGHasCycle): + rx.transitive_closure_dag(graph) From c4c0ab73444c01681402115f2a561b5f320d3f60 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Wed, 26 Oct 2022 09:24:37 -0400 Subject: [PATCH 2/2] Add build_transitive_closure_dag() to rustworkx-core This commit adds a new function to rustworkx-core for building a transitive closure in place out of an input DAG. The function takes ownership of the input graph and will mutate it to add edges to make it a transitive closure before returning it. This is then used internally by the rustworkx python function transitive_closure_dag(). 
--- ...ansitive-closure-dag-3fb45113d552f007.yaml | 3 + rustworkx-core/src/traversal/mod.rs | 2 + .../src/traversal/transitive_closure.rs | 83 +++++++++++++++++++ src/dag_algo/mod.rs | 37 ++++----- 4 files changed, 104 insertions(+), 21 deletions(-) create mode 100644 rustworkx-core/src/traversal/transitive_closure.rs diff --git a/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml b/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml index cc5099212..230527e24 100644 --- a/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml +++ b/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml @@ -3,6 +3,9 @@ features: - | Added a new function ``descendants_at_distance`` to the rustworkx-core crate under the ``traversal`` module + - | + Added a new function ``build_transitive_closure_dag`` to the rustworkx-core + crate under the ``traversal`` module. - | Added a new function, :func:`~.transitive_closure_dag`, which provides an optimize method for computing the transitive closure of an input diff --git a/rustworkx-core/src/traversal/mod.rs b/rustworkx-core/src/traversal/mod.rs index b531075da..8e4437e80 100644 --- a/rustworkx-core/src/traversal/mod.rs +++ b/rustworkx-core/src/traversal/mod.rs @@ -17,12 +17,14 @@ mod descendants; mod dfs_edges; mod dfs_visit; mod dijkstra_visit; +mod transitive_closure; pub use bfs_visit::{breadth_first_search, BfsEvent}; pub use descendants::descendants_at_distance; pub use dfs_edges::dfs_edges; pub use dfs_visit::{depth_first_search, DfsEvent}; pub use dijkstra_visit::{dijkstra_search, DijkstraEvent}; +pub use transitive_closure::build_transitive_closure_dag; /// Return if the expression is a break value, execute the provided statement /// if it is a prune value. 
diff --git a/rustworkx-core/src/traversal/transitive_closure.rs b/rustworkx-core/src/traversal/transitive_closure.rs new file mode 100644 index 000000000..6d6fd196f --- /dev/null +++ b/rustworkx-core/src/traversal/transitive_closure.rs @@ -0,0 +1,83 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use petgraph::algo::{toposort, Cycle}; +use petgraph::data::Build; +use petgraph::visit::{ + GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable, +}; + +use crate::traversal::descendants_at_distance; + +/// Build a transitive closure out of a given DAG +/// +/// This function will mutate a given DAG object (which is typically moved to +/// this function) into a transitive closure of the graph and then return it. +/// If you'd like to preserve the input graph pass a clone of the original graph. +/// The transitive closure of :math:`G = (V, E)` is a graph :math:`G+ = (V, E+)` +/// such that for all pairs of :math:`v, w` in :math:`V` there is an edge +/// :math:`(v, w)` in :math:`E+` if and only if there is a non-null path from +/// :math:`v` to :math:`w` in :math:`G`. This function provides an optimized +/// path for computing the transitive closure of a DAG; if the input graph +/// contains cycles it will error. +/// +/// Arguments: +/// +/// - `graph`: A mutable graph object representing the DAG +/// - `topological_order`: An optional `Vec` of node identifiers representing +/// the topological order to traverse the DAG with. 
If not specified the +/// `petgraph::algo::toposort` function will be called to generate this +/// - `default_edge_weight`: A callable function that takes no arguments and +/// returns the `EdgeWeight` type object to use for each edge added to +/// `graph +/// +/// # Example +/// +/// ```rust +/// use rustworkx_core::traversal::build_transitive_closure_dag; +/// +/// let g = petgraph::graph::DiGraph::::from_edges(&[(0, 1, 0), (1, 2, 0), (2, 3, 0)]); +/// +/// let res = build_transitive_closure_dag(g, None, || -> i32 {0}); +/// let out_graph = res.unwrap(); +/// let out_edges: Vec<(usize, usize)> = out_graph +/// .edge_indices() +/// .map(|e| { +/// let endpoints = out_graph.edge_endpoints(e).unwrap(); +/// (endpoints.0.index(), endpoints.1.index()) +/// }) +/// .collect(); +/// assert_eq!(vec![(0, 1), (1, 2), (2, 3), (1, 3), (0, 3), (0, 2)], out_edges) +/// ``` +pub fn build_transitive_closure_dag<'a, G, F>( + mut graph: G, + topological_order: Option>, + default_edge_weight: F, +) -> Result> +where + G: NodeCount + Build + Clone, + for<'b> &'b G: + GraphBase + Visitable + IntoNeighborsDirected + IntoNodeIdentifiers, + G::NodeId: std::cmp::Eq + std::hash::Hash, + F: Fn() -> G::EdgeWeight, +{ + let node_order: Vec = match topological_order { + Some(topo_order) => topo_order, + None => toposort(&graph, None)?, + }; + for node in node_order.into_iter().rev() { + for descendant in descendants_at_distance(&graph, node, 2) { + graph.add_edge(node, descendant, default_edge_weight()); + } + } + Ok(graph) +} diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index bf6e0f2a3..23239d85b 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -29,7 +29,7 @@ use petgraph::graph::NodeIndex; use petgraph::prelude::*; use petgraph::visit::NodeCount; -use rustworkx_core::traversal::descendants_at_distance; +use rustworkx_core::traversal; /// Find the longest path in a DAG /// @@ -661,25 +661,20 @@ pub fn transitive_closure_dag( graph: &digraph::PyDiGraph, 
topological_order: Option>, ) -> PyResult { - let node_order: Vec = match topological_order { - Some(topo_order) => topo_order.into_iter().map(NodeIndex::new).collect(), - None => match algo::toposort(&graph.graph, None) { - Ok(nodes) => nodes, - Err(_err) => return Err(DAGHasCycle::new_err("Topological Sort encountered a cycle")), - }, - }; - let mut out_graph = graph.graph.clone(); - for node in node_order.into_iter().rev() { - for descendant in descendants_at_distance(&out_graph, node, 2) { - out_graph.add_edge(node, descendant, py.None()); - } + let default_weight = || -> PyObject { py.None() }; + match traversal::build_transitive_closure_dag( + graph.graph.clone(), + topological_order.map(|order| order.into_iter().map(NodeIndex::new).collect()), + default_weight, + ) { + Ok(out_graph) => Ok(digraph::PyDiGraph { + graph: out_graph, + cycle_state: algo::DfsSpace::default(), + check_cycle: false, + node_removed: false, + multigraph: true, + attrs: py.None(), + }), + Err(_err) => Err(DAGHasCycle::new_err("Topological Sort encountered a cycle")), } - Ok(digraph::PyDiGraph { - graph: out_graph, - cycle_state: algo::DfsSpace::default(), - check_cycle: false, - node_removed: false, - multigraph: true, - attrs: py.None(), - }) }