-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
751 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from enum import Enum | ||
from typing import Callable, Union | ||
|
||
|
||
class EdgeType(Enum): | ||
DEPENDENCY = 0 | ||
TOOL = 1 | ||
|
||
|
||
AdjacentEdge = tuple[str, EdgeType] | ||
|
||
|
||
def dependency(*nodes: str) -> list[AdjacentEdge]: | ||
return [(node, EdgeType.DEPENDENCY) for node in nodes] | ||
|
||
|
||
def tool(*nodes: str) -> list[AdjacentEdge]: | ||
return [(node, EdgeType.TOOL) for node in nodes] | ||
|
||
|
||
class CycleException(Exception): | ||
pass | ||
|
||
|
||
class NodeNotFoundException(Exception): | ||
pass | ||
|
||
|
||
class DirectAcyclicGraph: | ||
_g: dict[str, list[AdjacentEdge]] | ||
_g_inverse: dict[str, list[AdjacentEdge]] | ||
_topological_order: list[str] | ||
|
||
def __init__(self, g: dict[str, list[AdjacentEdge]]): | ||
self._g_inverse = {} | ||
for u in g.keys(): | ||
if self._g_inverse.get(u) is None: | ||
self._g_inverse[u] = [] | ||
|
||
for v in g[u]: | ||
if v[0] not in g: | ||
raise NodeNotFoundException(f"Node '{v[0]}' not found") | ||
|
||
if self._g_inverse.get(v[0]) is None: | ||
self._g_inverse[v[0]] = [(u, v[1])] | ||
else: | ||
self._g_inverse[v[0]].append((u, v[1])) | ||
|
||
self._g = g | ||
self._topsort() | ||
|
||
def _topsort(self): | ||
adjacent = { | ||
u: iter(sorted([v for (v, _) in value])) for (u, value) in self._g.items() | ||
} | ||
result = [] | ||
|
||
while len(adjacent): | ||
# print(f"edges: {edges}") | ||
dfs_stack = [list(adjacent.keys())[0]] | ||
while dfs_stack: | ||
# print("-" * 25) | ||
# print(f"dfs_stack: {dfs_stack}") | ||
# print(f"adjacent : {adjacent}") | ||
u = dfs_stack.pop() | ||
# print(f"node: {u}") | ||
if u not in result: | ||
# print(f"node: {u} not visited") | ||
try: | ||
v = next(adjacent[u]) | ||
if v in dfs_stack: # cycle | ||
raise CycleException(f"{dfs_stack} + {u} -> {v}") | ||
dfs_stack.append(u) # recursive return | ||
if v not in result: | ||
dfs_stack.append(v) # recursive call | ||
except StopIteration: | ||
# print(f"finished {u}") | ||
del adjacent[u] | ||
result.append(u) | ||
|
||
result.reverse() | ||
self._topological_order = result | ||
|
||
def nodes(self) -> set[str]: | ||
return {item[0] for item in self._g.items()} | ||
|
||
def topological_order(self) -> list[str]: | ||
return self._topological_order | ||
|
||
def inverse(self) -> dict[str, list[AdjacentEdge]]: | ||
return self._g_inverse | ||
|
||
def childs( | ||
self, | ||
node: str, | ||
accept: Callable[[tuple[str, EdgeType]], bool] = lambda _: True, | ||
) -> list[str]: | ||
return [adjacent[0] for adjacent in self._g_inverse[node] if accept(adjacent)] | ||
|
||
def childs_transitive( | ||
self, | ||
node: str, | ||
accept: Callable[[tuple[str, EdgeType]], bool] = lambda _: True, | ||
) -> set[str]: | ||
result = set() | ||
childs = self.childs(node, accept) | ||
while childs: | ||
child = childs.pop() | ||
childs.extend(self.childs(child, accept)) | ||
result.add(child) | ||
|
||
return result | ||
|
||
def parents( | ||
self, node: str, accept: Callable[[tuple[str, EdgeType]], bool] = lambda _: True | ||
) -> list[str]: | ||
return [adjacent[0] for adjacent in self._g.get(node) if accept(adjacent)] | ||
|
||
def parents_transitive( | ||
self, | ||
node: str, | ||
accept: Callable[[tuple[str, EdgeType]], bool] = lambda _: True, | ||
): | ||
result = set() | ||
parents = self.parents(node, accept) | ||
while parents: | ||
parent = parents.pop() | ||
parents.extend(self.parents(parent, accept)) | ||
result.add(parent) | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
import unittest | ||
|
||
from direct_acyclic_graph import ( | ||
dependency, | ||
CycleException, | ||
DirectAcyclicGraph, | ||
tool, | ||
NodeNotFoundException, | ||
EdgeType, | ||
) | ||
|
||
G1 = { | ||
"a": dependency("b"), | ||
"b": dependency("c"), | ||
"c": dependency("d"), | ||
"d": dependency("e"), | ||
"e": [], | ||
} | ||
|
||
G2 = { | ||
"a": dependency("b"), | ||
"b": dependency("c", "d", "e"), | ||
"c": dependency("c1", "c2"), | ||
"d": dependency("d1", "d2"), | ||
"e": [], | ||
"c1": dependency("c1d1"), | ||
"c2": dependency("c2d2"), | ||
"d1": dependency("c1d1"), | ||
"d2": dependency("c2d2"), | ||
"c1d1": [], | ||
"c2d2": [], | ||
} | ||
|
||
G3 = { | ||
"a": [], | ||
"b": dependency("a"), | ||
"c": tool("b"), | ||
"d": dependency("b", "c"), | ||
} | ||
|
||
|
||
class DirectAcyclicGraphTestCase(unittest.TestCase): | ||
def test_incomplete(self): | ||
g = {"a": dependency("b")} | ||
|
||
with self.assertRaises(NodeNotFoundException) as err: | ||
DirectAcyclicGraph(g) | ||
|
||
self.assertEqual("Node 'b' not found", str(err.exception)) | ||
|
||
def test_cycle(self): | ||
g = { | ||
"a": dependency("b"), | ||
"b": dependency("c"), | ||
"c": dependency("d"), | ||
"d": dependency("a"), | ||
} | ||
|
||
with self.assertRaises(CycleException) as err: | ||
DirectAcyclicGraph(g) | ||
|
||
self.assertEqual("['a', 'b', 'c'] + d -> a", str(err.exception)) | ||
|
||
def test_topsort_simple(self): | ||
dag = DirectAcyclicGraph(g=G1) | ||
|
||
self.assertEqual(["a", "b", "c", "d", "e"], dag.topological_order()) | ||
|
||
def test_topsort(self): | ||
dag = DirectAcyclicGraph(g=G2) | ||
|
||
self.assertEqual( | ||
["a", "b", "e", "d", "d2", "d1", "c", "c2", "c2d2", "c1", "c1d1"], | ||
dag.topological_order(), | ||
) | ||
|
||
def test_topsort_mixed(self): | ||
dag = DirectAcyclicGraph( | ||
g={ | ||
"a": [], | ||
"b": dependency("a"), | ||
"c": tool("b"), | ||
"d": dependency("b", "c"), | ||
} | ||
) | ||
|
||
self.assertEqual(["d", "c", "b", "a"], dag.topological_order()) | ||
|
||
def test_inverse_simple(self): | ||
dag = DirectAcyclicGraph(g=G1) | ||
|
||
self.assertEqual( | ||
{ | ||
"a": [], | ||
"b": dependency("a"), | ||
"c": dependency("b"), | ||
"d": dependency("c"), | ||
"e": dependency("d"), | ||
}, | ||
dag.inverse(), | ||
) | ||
|
||
def test_inverse(self): | ||
dag = DirectAcyclicGraph(g=G2) | ||
|
||
self.assertEqual( | ||
{ | ||
"a": [], | ||
"b": dependency("a"), | ||
"c": dependency("b"), | ||
"d": dependency("b"), | ||
"e": dependency("b"), | ||
"c1": dependency("c"), | ||
"c2": dependency("c"), | ||
"d1": dependency("d"), | ||
"d2": dependency("d"), | ||
"c1d1": dependency("c1", "d1"), | ||
"c2d2": dependency("c2", "d2"), | ||
}, | ||
dag.inverse(), | ||
) | ||
|
||
def test_inverse_mixed(self): | ||
dag = DirectAcyclicGraph(g=G3) | ||
|
||
self.assertEqual( | ||
{ | ||
"a": dependency("b"), | ||
"b": tool("c") + dependency("d"), | ||
"c": dependency("d"), | ||
"d": [], | ||
}, | ||
dag.inverse(), | ||
) | ||
|
||
def test_childs_accept_all(self): | ||
dag = DirectAcyclicGraph(G3) | ||
|
||
self.assertEqual(["c", "d"], dag.childs("b")) | ||
|
||
def test_childs_accept_specific(self): | ||
dag = DirectAcyclicGraph(G3) | ||
|
||
self.assertEqual(["d"], dag.childs("b", lambda node: node[1] != EdgeType.TOOL)) | ||
|
||
def test_childs_transitive_accept_all(self): | ||
dag = DirectAcyclicGraph(G3) | ||
|
||
self.assertEqual({"b", "c", "d"}, dag.childs_transitive("a")) | ||
|
||
def test_childs_transitive_accept_specific(self): | ||
dag = DirectAcyclicGraph(G3) | ||
|
||
self.assertEqual( | ||
{"b", "d"}, | ||
dag.childs_transitive("a", lambda node: node[1] != EdgeType.TOOL), | ||
) | ||
|
||
def test_childs_accept_all_G2(self): | ||
dag = DirectAcyclicGraph(G2) | ||
|
||
self.assertEqual(["c1", "d1"], dag.childs("c1d1")) | ||
|
||
def test_childs_transitive_accept_all_G2(self): | ||
dag = DirectAcyclicGraph(G2) | ||
|
||
self.assertEqual( | ||
{"c1", "d1", "c", "d", "b", "a"}, dag.childs_transitive("c1d1") | ||
) | ||
|
||
def test_none_parent_G2(self): | ||
dag = DirectAcyclicGraph(G2) | ||
|
||
self.assertEqual([], dag.parents("e")) | ||
|
||
def test_single_parent_accept_all_G2(self): | ||
dag = DirectAcyclicGraph(G2) | ||
|
||
self.assertEqual(["c", "d", "e"], dag.parents("b")) | ||
|
||
def test_single_parent_accept_custom_G2(self): | ||
dag = DirectAcyclicGraph(G2) | ||
|
||
self.assertEqual( | ||
["c", "e"], dag.parents("b", accept=lambda adjacent: adjacent[0] != "d") | ||
) | ||
|
||
def test_multiple_parent_accept_all_G2(self): | ||
dag = DirectAcyclicGraph(G2) | ||
|
||
self.assertEqual(["c", "d", "e"], dag.parents("b")) | ||
|
||
def test_parents_transitive_accept_all_G2(self): | ||
dag = DirectAcyclicGraph(G2) | ||
|
||
self.assertEqual({"c1", "c1d1", "c2", "c2d2"}, dag.parents_transitive("c")) | ||
|
||
def test_parents_transitive_accept_custom_G2(self): | ||
dag = DirectAcyclicGraph(G2) | ||
|
||
self.assertEqual( | ||
{"c1", "c1d1"}, | ||
dag.parents_transitive("c", accept=lambda adjacent: adjacent[0] != "c2"), | ||
) |
Oops, something went wrong.