Skip to content

Commit

Permalink
tg utils import location changed
Browse files Browse the repository at this point in the history
  • Loading branch information
kessler-frost committed Oct 17, 2023
1 parent 7f6babc commit 49e5f3e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion covalent_dispatcher/_core/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .._dal.result import get_result_object as get_result_object_from_db
from .._db.write_result_to_db import resolve_electron_id
from . import dispatcher
from .data_modules import dispatch, electron, graph # nopycln: import
from .data_modules import dispatch, electron # nopycln: import
from .data_modules import importer as manifest_importer
from .data_modules.utils import run_in_executor

Expand Down
11 changes: 6 additions & 5 deletions covalent_dispatcher/_core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from . import data_manager as datasvc
from . import runner
from .data_modules import graph as tg_utils
from .data_modules import job_manager as jbmgr
from .dispatcher_modules.caches import _pending_parents, _sorted_task_groups, _unresolved_tasks
from .runner_modules.cancel import cancel_tasks
Expand Down Expand Up @@ -63,7 +64,7 @@ async def _get_abstract_task_inputs(dispatch_id: str, node_id: int, node_name: s

abstract_task_input = {"args": [], "kwargs": {}}

for edge in await datasvc.graph.get_incoming_edges(dispatch_id, node_id):
for edge in await tg_utils.get_incoming_edges(dispatch_id, node_id):
parent = edge["source"]

d = edge["attrs"]
Expand All @@ -89,7 +90,7 @@ async def _handle_completed_node(dispatch_id: str, node_id: int):
parent_gid = (await datasvc.electron.get(dispatch_id, node_id, ["task_group_id"]))[
"task_group_id"
]
for child in await datasvc.graph.get_node_successors(dispatch_id, node_id):
for child in await tg_utils.get_node_successors(dispatch_id, node_id):
node_id = child["node_id"]
gid = child["task_group_id"]
app_log.debug(f"dispatch {dispatch_id}: parent gid {parent_gid}, child gid {gid}")
Expand Down Expand Up @@ -129,7 +130,7 @@ async def _get_initial_tasks_and_deps(dispatch_id: str) -> Tuple[int, int, Dict]
# Number of pending predecessor nodes for each task group
pending_parents = {}

g_node_link = await datasvc.graph.get_nodes_links(dispatch_id)
g_node_link = await tg_utils.get_nodes_links(dispatch_id)
g = nx.readwrite.node_link_graph(g_node_link)

# Topologically sort each task group
Expand Down Expand Up @@ -344,7 +345,7 @@ async def cancel_dispatch(dispatch_id: str, task_ids: List[int] = None) -> None:
if task_ids:
app_log.debug(f"Cancelling tasks {task_ids} in dispatch {dispatch_id}")
else:
task_ids = await datasvc.graph.get_nodes(dispatch_id)
task_ids = await tg_utils.get_nodes(dispatch_id)

app_log.debug(f"Cancelling dispatch {dispatch_id}")

Expand Down Expand Up @@ -547,7 +548,7 @@ async def _clear_caches(dispatch_id: str):
"""Clean up all keys in caches."""
await _unresolved_tasks.remove(dispatch_id)

g_node_link = await datasvc.graph.get_nodes_links(dispatch_id)
g_node_link = await tg_utils.get_nodes_links(dispatch_id)
g = nx.readwrite.node_link_graph(g_node_link)

task_groups = {g.nodes[i]["task_group_id"] for i in g.nodes}
Expand Down
6 changes: 3 additions & 3 deletions tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ async def test_clear_caches(mocker):
g.add_node(2, task_group_id=0)
g.add_node(3, task_group_id=3)

mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.graph.get_nodes_links")
mocker.patch("covalent_dispatcher._core.dispatcher.tg_utils.get_nodes_links")
mocker.patch("networkx.readwrite.node_link_graph", return_value=g)
mock_unresolved_remove = mocker.patch(
"covalent_dispatcher._core.dispatcher._unresolved_tasks.remove"
Expand Down Expand Up @@ -764,7 +764,7 @@ async def mock_get_nodes(dispatch_id):
else:
return list(sub_tg._graph.nodes)

mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.graph.get_nodes", mock_get_nodes)
mocker.patch("covalent_dispatcher._core.dispatcher.tg_utils.get_nodes", mock_get_nodes)

node_attrs = [
{"sub_dispatch_id": tg.get_node_value(i, "sub_dispatch_id")} for i in tg._graph.nodes
Expand Down Expand Up @@ -819,7 +819,7 @@ async def mock_get_nodes(dispatch_id):
else:
return list(sub_tg._graph.nodes)

mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.graph.get_nodes", mock_get_nodes)
mocker.patch("covalent_dispatcher._core.dispatcher.tg_utils.get_nodes", mock_get_nodes)

node_attrs = [
{"sub_dispatch_id": tg.get_node_value(i, "sub_dispatch_id")} for i in tg._graph.nodes
Expand Down

0 comments on commit 49e5f3e

Please sign in to comment.