Skip to content

Commit

Permalink
[WIP]
Browse files Browse the repository at this point in the history
  • Loading branch information
agebhar1 committed Sep 7, 2023
1 parent 248f761 commit e89017b
Show file tree
Hide file tree
Showing 5 changed files with 751 additions and 0 deletions.
131 changes: 131 additions & 0 deletions graph/planner/direct_acyclic_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from enum import Enum
from typing import Callable, Union


class EdgeType(Enum):
DEPENDENCY = 0
TOOL = 1


AdjacentEdge = tuple[str, EdgeType]


def dependency(*nodes: str) -> list[AdjacentEdge]:
return [(node, EdgeType.DEPENDENCY) for node in nodes]


def tool(*nodes: str) -> list[AdjacentEdge]:
return [(node, EdgeType.TOOL) for node in nodes]


class CycleException(Exception):
pass


class NodeNotFoundException(Exception):
pass


class DirectAcyclicGraph:
_g: dict[str, list[AdjacentEdge]]
_g_inverse: dict[str, list[AdjacentEdge]]
_topological_order: list[str]

def __init__(self, g: dict[str, list[AdjacentEdge]]):
self._g_inverse = {}
for u in g.keys():
if self._g_inverse.get(u) is None:
self._g_inverse[u] = []

for v in g[u]:
if v[0] not in g:
raise NodeNotFoundException(f"Node '{v[0]}' not found")

if self._g_inverse.get(v[0]) is None:
self._g_inverse[v[0]] = [(u, v[1])]
else:
self._g_inverse[v[0]].append((u, v[1]))

self._g = g
self._topsort()

def _topsort(self):
adjacent = {
u: iter(sorted([v for (v, _) in value])) for (u, value) in self._g.items()
}
result = []

while len(adjacent):
# print(f"edges: {edges}")
dfs_stack = [list(adjacent.keys())[0]]
while dfs_stack:
# print("-" * 25)
# print(f"dfs_stack: {dfs_stack}")
# print(f"adjacent : {adjacent}")
u = dfs_stack.pop()
# print(f"node: {u}")
if u not in result:
# print(f"node: {u} not visited")
try:
v = next(adjacent[u])
if v in dfs_stack: # cycle
raise CycleException(f"{dfs_stack} + {u} -> {v}")
dfs_stack.append(u) # recursive return
if v not in result:
dfs_stack.append(v) # recursive call
except StopIteration:
# print(f"finished {u}")
del adjacent[u]
result.append(u)

result.reverse()
self._topological_order = result

def nodes(self) -> set[str]:
return {item[0] for item in self._g.items()}

def topological_order(self) -> list[str]:
return self._topological_order

def inverse(self) -> dict[str, list[AdjacentEdge]]:
return self._g_inverse

def childs(
self,
node: str,
accept: Callable[[tuple[str, EdgeType]], bool] = lambda _: True,
) -> list[str]:
return [adjacent[0] for adjacent in self._g_inverse[node] if accept(adjacent)]

def childs_transitive(
self,
node: str,
accept: Callable[[tuple[str, EdgeType]], bool] = lambda _: True,
) -> set[str]:
result = set()
childs = self.childs(node, accept)
while childs:
child = childs.pop()
childs.extend(self.childs(child, accept))
result.add(child)

return result

def parents(
self, node: str, accept: Callable[[tuple[str, EdgeType]], bool] = lambda _: True
) -> list[str]:
return [adjacent[0] for adjacent in self._g.get(node) if accept(adjacent)]

def parents_transitive(
self,
node: str,
accept: Callable[[tuple[str, EdgeType]], bool] = lambda _: True,
):
result = set()
parents = self.parents(node, accept)
while parents:
parent = parents.pop()
parents.extend(self.parents(parent, accept))
result.add(parent)

return result
204 changes: 204 additions & 0 deletions graph/planner/direct_acyclic_graph_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import unittest

from direct_acyclic_graph import (
dependency,
CycleException,
DirectAcyclicGraph,
tool,
NodeNotFoundException,
EdgeType,
)

G1 = {
"a": dependency("b"),
"b": dependency("c"),
"c": dependency("d"),
"d": dependency("e"),
"e": [],
}

G2 = {
"a": dependency("b"),
"b": dependency("c", "d", "e"),
"c": dependency("c1", "c2"),
"d": dependency("d1", "d2"),
"e": [],
"c1": dependency("c1d1"),
"c2": dependency("c2d2"),
"d1": dependency("c1d1"),
"d2": dependency("c2d2"),
"c1d1": [],
"c2d2": [],
}

G3 = {
"a": [],
"b": dependency("a"),
"c": tool("b"),
"d": dependency("b", "c"),
}


class DirectAcyclicGraphTestCase(unittest.TestCase):
def test_incomplete(self):
g = {"a": dependency("b")}

with self.assertRaises(NodeNotFoundException) as err:
DirectAcyclicGraph(g)

self.assertEqual("Node 'b' not found", str(err.exception))

def test_cycle(self):
g = {
"a": dependency("b"),
"b": dependency("c"),
"c": dependency("d"),
"d": dependency("a"),
}

with self.assertRaises(CycleException) as err:
DirectAcyclicGraph(g)

self.assertEqual("['a', 'b', 'c'] + d -> a", str(err.exception))

def test_topsort_simple(self):
dag = DirectAcyclicGraph(g=G1)

self.assertEqual(["a", "b", "c", "d", "e"], dag.topological_order())

def test_topsort(self):
dag = DirectAcyclicGraph(g=G2)

self.assertEqual(
["a", "b", "e", "d", "d2", "d1", "c", "c2", "c2d2", "c1", "c1d1"],
dag.topological_order(),
)

def test_topsort_mixed(self):
dag = DirectAcyclicGraph(
g={
"a": [],
"b": dependency("a"),
"c": tool("b"),
"d": dependency("b", "c"),
}
)

self.assertEqual(["d", "c", "b", "a"], dag.topological_order())

def test_inverse_simple(self):
dag = DirectAcyclicGraph(g=G1)

self.assertEqual(
{
"a": [],
"b": dependency("a"),
"c": dependency("b"),
"d": dependency("c"),
"e": dependency("d"),
},
dag.inverse(),
)

def test_inverse(self):
dag = DirectAcyclicGraph(g=G2)

self.assertEqual(
{
"a": [],
"b": dependency("a"),
"c": dependency("b"),
"d": dependency("b"),
"e": dependency("b"),
"c1": dependency("c"),
"c2": dependency("c"),
"d1": dependency("d"),
"d2": dependency("d"),
"c1d1": dependency("c1", "d1"),
"c2d2": dependency("c2", "d2"),
},
dag.inverse(),
)

def test_inverse_mixed(self):
dag = DirectAcyclicGraph(g=G3)

self.assertEqual(
{
"a": dependency("b"),
"b": tool("c") + dependency("d"),
"c": dependency("d"),
"d": [],
},
dag.inverse(),
)

def test_childs_accept_all(self):
dag = DirectAcyclicGraph(G3)

self.assertEqual(["c", "d"], dag.childs("b"))

def test_childs_accept_specific(self):
dag = DirectAcyclicGraph(G3)

self.assertEqual(["d"], dag.childs("b", lambda node: node[1] != EdgeType.TOOL))

def test_childs_transitive_accept_all(self):
dag = DirectAcyclicGraph(G3)

self.assertEqual({"b", "c", "d"}, dag.childs_transitive("a"))

def test_childs_transitive_accept_specific(self):
dag = DirectAcyclicGraph(G3)

self.assertEqual(
{"b", "d"},
dag.childs_transitive("a", lambda node: node[1] != EdgeType.TOOL),
)

def test_childs_accept_all_G2(self):
dag = DirectAcyclicGraph(G2)

self.assertEqual(["c1", "d1"], dag.childs("c1d1"))

def test_childs_transitive_accept_all_G2(self):
dag = DirectAcyclicGraph(G2)

self.assertEqual(
{"c1", "d1", "c", "d", "b", "a"}, dag.childs_transitive("c1d1")
)

def test_none_parent_G2(self):
dag = DirectAcyclicGraph(G2)

self.assertEqual([], dag.parents("e"))

def test_single_parent_accept_all_G2(self):
dag = DirectAcyclicGraph(G2)

self.assertEqual(["c", "d", "e"], dag.parents("b"))

def test_single_parent_accept_custom_G2(self):
dag = DirectAcyclicGraph(G2)

self.assertEqual(
["c", "e"], dag.parents("b", accept=lambda adjacent: adjacent[0] != "d")
)

def test_multiple_parent_accept_all_G2(self):
dag = DirectAcyclicGraph(G2)

self.assertEqual(["c", "d", "e"], dag.parents("b"))

def test_parents_transitive_accept_all_G2(self):
dag = DirectAcyclicGraph(G2)

self.assertEqual({"c1", "c1d1", "c2", "c2d2"}, dag.parents_transitive("c"))

def test_parents_transitive_accept_custom_G2(self):
dag = DirectAcyclicGraph(G2)

self.assertEqual(
{"c1", "c1d1"},
dag.parents_transitive("c", accept=lambda adjacent: adjacent[0] != "c2"),
)
Loading

0 comments on commit e89017b

Please sign in to comment.