diff --git a/covalent_dispatcher/_core/data_manager.py b/covalent_dispatcher/_core/data_manager.py index e9481f3c7..7e1967f0d 100644 --- a/covalent_dispatcher/_core/data_manager.py +++ b/covalent_dispatcher/_core/data_manager.py @@ -37,7 +37,7 @@ from .._dal.result import get_result_object as get_result_object_from_db from .._db.write_result_to_db import resolve_electron_id from . import dispatcher -from .data_modules import dispatch, electron, graph # nopycln: import +from .data_modules import dispatch, electron # nopycln: import from .data_modules import importer as manifest_importer from .data_modules.utils import run_in_executor diff --git a/covalent_dispatcher/_core/dispatcher.py b/covalent_dispatcher/_core/dispatcher.py index f01f172ea..7041dde88 100644 --- a/covalent_dispatcher/_core/dispatcher.py +++ b/covalent_dispatcher/_core/dispatcher.py @@ -32,6 +32,7 @@ from . import data_manager as datasvc from . import runner +from .data_modules import graph as tg_utils from .data_modules import job_manager as jbmgr from .dispatcher_modules.caches import _pending_parents, _sorted_task_groups, _unresolved_tasks from .runner_modules.cancel import cancel_tasks @@ -63,7 +64,7 @@ async def _get_abstract_task_inputs(dispatch_id: str, node_id: int, node_name: s abstract_task_input = {"args": [], "kwargs": {}} - for edge in await datasvc.graph.get_incoming_edges(dispatch_id, node_id): + for edge in await tg_utils.get_incoming_edges(dispatch_id, node_id): parent = edge["source"] d = edge["attrs"] @@ -89,7 +90,7 @@ async def _handle_completed_node(dispatch_id: str, node_id: int): parent_gid = (await datasvc.electron.get(dispatch_id, node_id, ["task_group_id"]))[ "task_group_id" ] - for child in await datasvc.graph.get_node_successors(dispatch_id, node_id): + for child in await tg_utils.get_node_successors(dispatch_id, node_id): node_id = child["node_id"] gid = child["task_group_id"] app_log.debug(f"dispatch {dispatch_id}: parent gid {parent_gid}, child gid {gid}") @@ -129,7 +130,7 @@ async def _get_initial_tasks_and_deps(dispatch_id: str) -> Tuple[int, int, Dict] # Number of pending predecessor nodes for each task group pending_parents = {} - g_node_link = await datasvc.graph.get_nodes_links(dispatch_id) + g_node_link = await tg_utils.get_nodes_links(dispatch_id) g = nx.readwrite.node_link_graph(g_node_link) # Topologically sort each task group @@ -344,7 +345,7 @@ async def cancel_dispatch(dispatch_id: str, task_ids: List[int] = None) -> None: if task_ids: app_log.debug(f"Cancelling tasks {task_ids} in dispatch {dispatch_id}") else: - task_ids = await datasvc.graph.get_nodes(dispatch_id) + task_ids = await tg_utils.get_nodes(dispatch_id) app_log.debug(f"Cancelling dispatch {dispatch_id}") @@ -547,7 +548,7 @@ async def _clear_caches(dispatch_id: str): """Clean up all keys in caches.""" await _unresolved_tasks.remove(dispatch_id) - g_node_link = await datasvc.graph.get_nodes_links(dispatch_id) + g_node_link = await tg_utils.get_nodes_links(dispatch_id) g = nx.readwrite.node_link_graph(g_node_link) task_groups = {g.nodes[i]["task_group_id"] for i in g.nodes} diff --git a/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py b/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py index ccb9ecfaa..5b5b79414 100644 --- a/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py +++ b/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py @@ -716,7 +716,7 @@ async def test_clear_caches(mocker): g.add_node(2, task_group_id=0) g.add_node(3, task_group_id=3) - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.graph.get_nodes_links") + mocker.patch("covalent_dispatcher._core.dispatcher.tg_utils.get_nodes_links") mocker.patch("networkx.readwrite.node_link_graph", return_value=g) mock_unresolved_remove = mocker.patch( "covalent_dispatcher._core.dispatcher._unresolved_tasks.remove" @@ -764,7 +764,7 @@ async def mock_get_nodes(dispatch_id): else: return list(sub_tg._graph.nodes) - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.graph.get_nodes", mock_get_nodes) + mocker.patch("covalent_dispatcher._core.dispatcher.tg_utils.get_nodes", mock_get_nodes) node_attrs = [ {"sub_dispatch_id": tg.get_node_value(i, "sub_dispatch_id")} for i in tg._graph.nodes @@ -819,7 +819,7 @@ async def mock_get_nodes(dispatch_id): else: return list(sub_tg._graph.nodes) - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.graph.get_nodes", mock_get_nodes) + mocker.patch("covalent_dispatcher._core.dispatcher.tg_utils.get_nodes", mock_get_nodes) node_attrs = [ {"sub_dispatch_id": tg.get_node_value(i, "sub_dispatch_id")} for i in tg._graph.nodes