diff --git a/graph/planner/planner.py b/graph/planner/planner.py index 408c29f..899d084 100644 --- a/graph/planner/planner.py +++ b/graph/planner/planner.py @@ -1,8 +1,18 @@ -from typing import Union, Tuple, List +from dataclasses import dataclass, field +from typing import Union from direct_acyclic_graph import DirectAcyclicGraph, EdgeType +@dataclass +class NodeState: + user_selected: bool = True + primary_action: None | str = None + transitive_action: None | str = None + state: set[str] = field(default_factory=set) + transitive_state: set[str] = field(default_factory=set) + + class Planner: _state: DirectAcyclicGraph _target: DirectAcyclicGraph @@ -24,6 +34,88 @@ def __init__( self._unhealthy = unhealthy or [] self._selected = selected + def propagate(self) -> tuple[dict[str, NodeState], list[str]]: + nodes_state = self._state.nodes() + nodes_target = self._target.nodes() + + nodes_add = (nodes_target - nodes_state).intersection( + (nodes_target - nodes_state) + if self._selected == "*" + else {node for node in self._selected if not node.startswith("-")} + ) + nodes_delete = ( + (nodes_state - nodes_target) + if self._selected == "*" + else {node[1:] for node in self._selected if node.startswith("-")} + ) + nodes_selected = ( + nodes_state.union(nodes_target) + if self._selected == "*" + else { + node if not node.startswith("-") else node[1:] + for node in self._selected + } + ) + + nodes_state = { + node: NodeState( + user_selected=node in nodes_selected, + primary_action=( + "add" + if node in nodes_add + else "delete" + if node in nodes_delete + else None + ), + state=({"modified"} if node in self._modified else set()).union( + {"unhealthy"} if node in self._unhealthy else set() + ), + ) + for node in (nodes_state.union(nodes_target)) + } + + for node in nodes_delete: + childs_transitive = self._state.childs_transitive( + node, + accept=lambda adjacent: adjacent[1] != EdgeType.TOOL, + ) + for node_transitive_delete in childs_transitive: + # if nodes[node_transitive_delete].primary_action != "delete": + nodes_state[node_transitive_delete].transitive_action = "delete" + + unhealthy_adjacent_tool_edge_nodes = set() + for key, value in nodes_state.items(): + node = key + state = value.state + if "unhealthy" in state: + # propagate 'unhealthy' state to children != EdgeType.TOOL + childs_transitive = self._state.childs_transitive( + node, + accept=lambda adjacent: adjacent[1] != EdgeType.TOOL, + ) + for node_transitive_state in childs_transitive: + nodes_state[node_transitive_state].transitive_state = nodes_state[ + node_transitive_state + ].transitive_state.union({"unhealthy"}) + + # collect nodes ('unhealthy')' w/ EdgeType.TOOL in reverse graph ~> incoming edge(s) + state_inverse = self._state.inverse() + for edge in state_inverse[node]: + if edge[1] == EdgeType.TOOL: + unhealthy_adjacent_tool_edge_nodes.add(node) + break + + operations = [] + if len(unhealthy_adjacent_tool_edge_nodes) > 0: + for node in reversed(self._state.topological_order()): + if node in unhealthy_adjacent_tool_edge_nodes: + operations.append(f"-{node}") + for node in self._state.topological_order(): + if node in unhealthy_adjacent_tool_edge_nodes: + operations.append(f"+{node}") + + return nodes_state, operations + def apply(self) -> list[str]: nodes_state = self._state.nodes() nodes_target = self._target.nodes() diff --git a/graph/planner/planner_test.py b/graph/planner/planner_test.py index 19ae392..29f69f4 100644 --- a/graph/planner/planner_test.py +++ b/graph/planner/planner_test.py @@ -1,8 +1,7 @@ import unittest from direct_acyclic_graph import DirectAcyclicGraph, dependency, tool -from planner import Planner - +from planner import Planner, NodeState DAG1 = DirectAcyclicGraph( g={ @@ -45,6 +44,19 @@ def test_no_action(self): planner = Planner(state=DAG1, target=DAG1) self.assertEqual([], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(), + "b": NodeState(), + "c": NodeState(), + "d": NodeState(), + "e": NodeState(), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG1_modified_multiple(self): planner = Planner(state=DAG1, target=DAG1, modified=["b", "d"]) @@ -53,6 +65,19 @@ def test_apply_DAG1_modified_multiple(self): ["-a", "-b", "-c", "-d", "+d", "+c", "+b", "+a"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(state=set()), + "b": NodeState(state={"modified"}), + "c": NodeState(state=set()), + "d": NodeState(state={"modified"}), + "e": NodeState(state=set()), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG1_unhealthy_multiple(self): planner = Planner(state=DAG1, target=DAG1, unhealthy=["b", "d"]) @@ -61,6 +86,19 @@ def test_apply_DAG1_unhealthy_multiple(self): ["-a", "-b", "-c", "-d", "+d", "+c", "+b", "+a"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(transitive_state={"unhealthy"}), + "b": NodeState(state={"unhealthy"}, transitive_state={"unhealthy"}), + "c": NodeState(transitive_state={"unhealthy"}), + "d": NodeState(state={"unhealthy"}), + "e": NodeState(), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG1_unhealthy_and_modified(self): planner = Planner( @@ -71,6 +109,21 @@ def test_apply_DAG1_unhealthy_and_modified(self): ["-a", "-b", "-c", "-d", "+d", "+c", "+b", "+a"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(transitive_state={"unhealthy"}), + "b": NodeState( + state={"modified", "unhealthy"}, transitive_state={"unhealthy"} + ), + "c": NodeState(transitive_state={"unhealthy"}), + "d": NodeState(state={"modified", "unhealthy"}), + "e": NodeState(), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG1_unhealthy_or_modified(self): planner = Planner(state=DAG1, target=DAG1, modified=["b"], unhealthy=["d"]) @@ -79,6 +132,19 @@ def test_apply_DAG1_unhealthy_or_modified(self): ["-a", "-b", "-c", "-d", "+d", "+c", "+b", "+a"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(transitive_state={"unhealthy"}), + "b": NodeState(state={"modified"}, transitive_state={"unhealthy"}), + "c": NodeState(transitive_state={"unhealthy"}), + "d": NodeState(state={"unhealthy"}), + "e": NodeState(), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG1_modified_root(self): planner = Planner(state=DAG1, target=DAG1, modified=["e"]) @@ -87,6 +153,19 @@ def test_apply_DAG1_modified_root(self): ["-a", "-b", "-c", "-d", "-e", "+e", "+d", "+c", "+b", "+a"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(), + "b": NodeState(), + "c": NodeState(), + "d": NodeState(), + "e": NodeState(state={"modified"}), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG1_unhealthy_root(self): planner = Planner(state=DAG1, target=DAG1, unhealthy=["e"]) @@ -95,16 +174,53 @@ def test_apply_DAG1_unhealthy_root(self): ["-a", "-b", "-c", "-d", "-e", "+e", "+d", "+c", "+b", "+a"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(transitive_state={"unhealthy"}), + "b": NodeState(transitive_state={"unhealthy"}), + "c": NodeState(transitive_state={"unhealthy"}), + "d": NodeState(transitive_state={"unhealthy"}), + "e": NodeState(state={"unhealthy"}), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG3_modified_single(self): planner = Planner(state=DAG3, target=DAG3, modified=["d"]) self.assertEqual(["-d", "+d"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(), + "b": NodeState(), + "c": NodeState(), + "d": NodeState(state={"modified"}), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG3_modified_root(self): planner = Planner(state=DAG3, target=DAG3, modified=["a"]) self.assertEqual(["-d", "-b", "-a", "+a", "+b", "+d"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(state={"modified"}), + "b": NodeState(), + "c": NodeState(), + "d": NodeState(), + }, + [], + ), + planner.propagate(), + ) def test_apply_add_to_empty(self): planner = Planner( @@ -116,6 +232,16 @@ def test_apply_add_to_empty(self): ["+b", "+a"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(primary_action="add"), + "b": NodeState(primary_action="add"), + }, + [], + ), + planner.propagate(), + ) def test_apply_add_to_existing(self): planner = Planner( @@ -126,6 +252,17 @@ def test_apply_add_to_existing(self): ) self.assertEqual(["+c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(), + "b": NodeState(), + "c": NodeState(primary_action="add"), + }, + [], + ), + planner.propagate(), + ) def test_apply_remove_all(self): planner = Planner( @@ -137,6 +274,16 @@ def test_apply_remove_all(self): ["-a", "-b"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(primary_action="delete", transitive_action="delete"), + "b": NodeState(primary_action="delete"), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG3_remove_all(self): planner = Planner(state=DAG3, target=DirectAcyclicGraph(g={})) @@ -145,6 +292,18 @@ def test_apply_DAG3_remove_all(self): ["-d", "-c", "-b", "-a"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(primary_action="delete"), + "b": NodeState(primary_action="delete", transitive_action="delete"), + "c": NodeState(primary_action="delete"), + "d": NodeState(primary_action="delete", transitive_action="delete"), + }, + [], + ), + planner.propagate(), + ) def test_add_and_delete(self): planner = Planner( @@ -156,6 +315,17 @@ def test_add_and_delete(self): ["-b", "+c"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(), + "b": NodeState(primary_action="delete"), + "c": NodeState(primary_action="add"), + }, + [], + ), + planner.propagate(), + ) def test_modified_but_delete(self): planner = Planner( @@ -167,6 +337,17 @@ def test_modified_but_delete(self): ) self.assertEqual(["-c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(), + "b": NodeState(), + "c": NodeState(primary_action="delete", state={"modified"}), + }, + [], + ), + planner.propagate(), + ) def test_apply_delete_DAG2_temporarily(self): planner = Planner(state=DAG2, target=DAG2, selected=["-c1d1"]) @@ -175,6 +356,25 @@ def test_apply_delete_DAG2_temporarily(self): ["-a", "-b", "-d", "-d1", "-c", "-c1", "-c1d1"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False, transitive_action="delete"), + "b": NodeState(user_selected=False, transitive_action="delete"), + "c": NodeState(user_selected=False, transitive_action="delete"), + "c1": NodeState(user_selected=False, transitive_action="delete"), + "c1d1": NodeState(primary_action="delete"), + "c2": NodeState(user_selected=False), + "c2d2": NodeState(user_selected=False), + "d": NodeState(user_selected=False, transitive_action="delete"), + "d1": NodeState(user_selected=False, transitive_action="delete"), + "d2": NodeState(user_selected=False), + "e": NodeState(user_selected=False), + }, + [], + ), + planner.propagate(), + ) def test_apply_delete_DAG2_temporarily_with_target_changes(self): planner = Planner( @@ -198,6 +398,25 @@ def test_apply_delete_DAG2_temporarily_with_target_changes(self): ["-a", "-b", "-d", "-d1", "-c", "-c1", "-c1d1"], planner.apply(), ) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False, transitive_action="delete"), + "b": NodeState(user_selected=False, transitive_action="delete"), + "c": NodeState(user_selected=False, transitive_action="delete"), + "c1": NodeState(user_selected=False, transitive_action="delete"), + "c1d1": NodeState(primary_action="delete"), + "c2": NodeState(user_selected=False), + "c2d2": NodeState(user_selected=False), + "d": NodeState(user_selected=False, transitive_action="delete"), + "d1": NodeState(user_selected=False, transitive_action="delete"), + "d2": NodeState(user_selected=False), + "e": NodeState(user_selected=False), + }, + [], + ), + planner.propagate(), + ) def test_apply_delete_roots_DAG2_temporarily(self): planner = Planner(state=DAG2, target=DAG2, selected=["-c1d1", "-c2d2"]) @@ -208,6 +427,25 @@ def test_apply_delete_roots_DAG2_temporarily(self): actual, ) self.assertNotIn("e", actual) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False, transitive_action="delete"), + "b": NodeState(user_selected=False, transitive_action="delete"), + "c": NodeState(user_selected=False, transitive_action="delete"), + "c1": NodeState(user_selected=False, transitive_action="delete"), + "c1d1": NodeState(primary_action="delete"), + "c2": NodeState(user_selected=False, transitive_action="delete"), + "c2d2": NodeState(primary_action="delete"), + "d": NodeState(user_selected=False, transitive_action="delete"), + "d1": NodeState(user_selected=False, transitive_action="delete"), + "d2": NodeState(user_selected=False, transitive_action="delete"), + "e": NodeState(user_selected=False), + }, + [], + ), + planner.propagate(), + ) def test_apply_delete_with_modified_one_TOOL_edge(self): planner = Planner( @@ -219,6 +457,17 @@ def test_apply_delete_with_modified_one_TOOL_edge(self): selected=["-c"], ) self.assertEqual(["-c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b": NodeState(user_selected=False, state={"modified"}), + "c": NodeState(primary_action="delete"), + }, + [], + ), + planner.propagate(), + ) def test_apply_delete_with_unhealthy_one_TOOL_edge(self): planner = Planner( @@ -230,6 +479,17 @@ def test_apply_delete_with_unhealthy_one_TOOL_edge(self): selected=["-c"], ) self.assertEqual(["-b", "+b", "-c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b": NodeState(user_selected=False, state={"unhealthy"}), + "c": NodeState(primary_action="delete"), + }, + ["-b", "+b"], + ), + planner.propagate(), + ) def test_apply_delete_with_one_modified_two_TOOL_edge_in_chain(self): planner = Planner( @@ -243,6 +503,18 @@ def test_apply_delete_with_one_modified_two_TOOL_edge_in_chain(self): selected=["-c"], ) self.assertEqual(["-c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b1": NodeState(user_selected=False), + "b2": NodeState(user_selected=False, state={"modified"}), + "c": NodeState(primary_action="delete"), + }, + [], + ), + planner.propagate(), + ) def test_apply_delete_with_one_unhealthy_two_TOOL_edge_in_chain(self): planner = Planner( @@ -256,6 +528,18 @@ def test_apply_delete_with_one_unhealthy_two_TOOL_edge_in_chain(self): selected=["-c"], ) self.assertEqual(["-b2", "+b2", "-c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b1": NodeState(user_selected=False), + "b2": NodeState(user_selected=False, state={"unhealthy"}), + "c": NodeState(primary_action="delete"), + }, + ["-b2", "+b2"], + ), + planner.propagate(), + ) def test_apply_delete_with_two_modified_two_TOOL_edge_in_chain(self): planner = Planner( @@ -269,6 +553,18 @@ def test_apply_delete_with_two_modified_two_TOOL_edge_in_chain(self): selected=["-c"], ) self.assertEqual(["-c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b1": NodeState(user_selected=False, state={"modified"}), + "b2": NodeState(user_selected=False, state={"modified"}), + "c": NodeState(primary_action="delete"), + }, + [], + ), + planner.propagate(), + ) def test_apply_delete_with_two_unhealthy_two_TOOL_edge_in_chain(self): planner = Planner( @@ -282,6 +578,18 @@ def test_apply_delete_with_two_unhealthy_two_TOOL_edge_in_chain(self): selected=["-c"], ) self.assertEqual(["-b2", "-b1", "+b1", "+b2", "-c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b1": NodeState(user_selected=False, state={"unhealthy"}), + "b2": NodeState(user_selected=False, state={"unhealthy"}), + "c": NodeState(primary_action="delete"), + }, + ["-b1", "-b2", "+b2", "+b1"], + ), + planner.propagate(), + ) def test_apply_delete_with_one_modified_two_TOOL_edge_in_chain_skip(self): planner = Planner( @@ -295,16 +603,54 @@ def test_apply_delete_with_one_modified_two_TOOL_edge_in_chain_skip(self): selected=["-c"], ) self.assertEqual(["-c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b1": NodeState(user_selected=False, state={"modified"}), + "b2": NodeState(user_selected=False), + "c": NodeState(primary_action="delete"), + }, + [], + ), + planner.propagate(), + ) def test_no_action_apply_selected(self): planner = Planner(state=DAG1, target=DAG1, modified=[], selected=["c"]) self.assertEqual([], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b": NodeState(user_selected=False), + "c": NodeState(), + "d": NodeState(user_selected=False), + "e": NodeState(user_selected=False), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG1_modified_multiple_single_selected(self): planner = Planner(state=DAG1, target=DAG1, modified=["b", "d"], selected=["b"]) self.assertEqual(["-a", "-b", "+b", "+a"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b": NodeState(state={"modified"}), + "c": NodeState(user_selected=False), + "d": NodeState(user_selected=False, state={"modified"}), + "e": NodeState(user_selected=False), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG3_selected_target(self): planner = Planner( @@ -312,6 +658,18 @@ def test_apply_DAG3_selected_target(self): ) self.assertEqual(["+b", "+c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b": NodeState(user_selected=False), + "c": NodeState(primary_action="add"), + "d": NodeState(user_selected=False), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG3_selected_target_modified(self): planner = Planner( @@ -322,6 +680,18 @@ def test_apply_DAG3_selected_target_modified(self): ) self.assertEqual(["+b", "+c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False, state={"modified"}), + "b": NodeState(user_selected=False), + "c": NodeState(primary_action="add"), + "d": NodeState(user_selected=False), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG3_selected_target_unhealthy(self): planner = Planner( @@ -332,6 +702,18 @@ def test_apply_DAG3_selected_target_unhealthy(self): ) self.assertEqual(["-a", "+a", "+b", "+c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False, state={"unhealthy"}), + "b": NodeState(user_selected=False), + "c": NodeState(primary_action="add"), + "d": NodeState(user_selected=False), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG3_selected_target_modified_one_TOOL_edge(self): planner = Planner( @@ -344,6 +726,17 @@ def test_apply_DAG3_selected_target_modified_one_TOOL_edge(self): ) self.assertEqual(["+c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b": NodeState(user_selected=False, state={"modified"}), + "c": NodeState(primary_action="add"), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG3_selected_target_unhealthy_one_TOOL_edge(self): planner = Planner( @@ -356,6 +749,17 @@ def test_apply_DAG3_selected_target_unhealthy_one_TOOL_edge(self): ) self.assertEqual(["-b", "+b", "+c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b": NodeState(user_selected=False, state={"unhealthy"}), + "c": NodeState(primary_action="add"), + }, + [], + ), + planner.propagate(), + ) def test_apply_DAG3_selected_target_modified_one_TOOL_edge_chain_skip(self): planner = Planner( @@ -374,6 +778,19 @@ def test_apply_DAG3_selected_target_modified_one_TOOL_edge_chain_skip(self): ) self.assertEqual(["+d", "+e"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(user_selected=False), + "b": NodeState(user_selected=False, state={"modified"}), + "c": NodeState(user_selected=False), + "d": NodeState(user_selected=False), + "e": NodeState(primary_action="add"), + }, + [], + ), + planner.propagate(), + ) def test_apply_selected_TOOL_edge_modified(self): planner = Planner( @@ -384,6 +801,16 @@ def test_apply_selected_TOOL_edge_modified(self): ) self.assertEqual(["-a"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(primary_action="delete"), + "b": NodeState(user_selected=False, state={"modified"}), + }, + [], + ), + planner.propagate(), + ) def test_apply_selected_TOOL_edge_unhealthy(self): planner = Planner( @@ -394,3 +821,45 @@ def test_apply_selected_TOOL_edge_unhealthy(self): ) self.assertEqual(["-b", "+b", "-a"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState(primary_action="delete"), + "b": NodeState(user_selected=False, state={"unhealthy"}), + }, + ["-b", "+b"], + ), + planner.propagate(), + ) + + def test_apply_unhealthy_tool(self): + planner = Planner( + state=DirectAcyclicGraph( + g={"a": dependency("c") + tool("b"), "b": dependency("c"), "c": []} + ), + target=DirectAcyclicGraph( + g={"a": dependency("c") + tool("b"), "b": dependency("c"), "c": []} + ), + unhealthy=["b"], + selected=["-a", "-b", "-c"], + ) + + # self.assertEqual(["-b", "+b", "-a", "-b", "-c"], planner.apply()) + self.assertEqual( + ( + { + "a": NodeState( + primary_action="delete", + transitive_action="delete", + ), + "b": NodeState( + primary_action="delete", + transitive_action="delete", + state={"unhealthy"}, + ), + "c": NodeState(primary_action="delete"), + }, + ["-b", "+b"], + ), + planner.propagate(), + )