diff --git a/covalent_dispatcher/_core/dispatcher.py b/covalent_dispatcher/_core/dispatcher.py index e159a8a237..ec20553419 100644 --- a/covalent_dispatcher/_core/dispatcher.py +++ b/covalent_dispatcher/_core/dispatcher.py @@ -35,9 +35,7 @@ from covalent._shared_files.util_classes import RESULT_STATUS from . import data_manager as datasvc -from . import runner - -# from . import runner_ng +from . import runner_ng from .data_modules import job_manager as jbmgr from .dispatcher_modules.caches import _pending_parents, _sorted_task_groups, _unresolved_tasks @@ -226,28 +224,13 @@ async def _submit_task_group(dispatch_id: str, sorted_nodes: List[int], task_gro app_log.debug(f"Using new runner for task group {task_group_id}") known_nodes = list(set(known_nodes)) - - task_spec = task_specs[0] - abstract_inputs = {"args": task_spec["args_ids"], "kwargs": task_spec["kwargs_ids"]} - - # Temporarily redirect to in-memory runner (this is incompatible with task packing) - if len(task_specs) > 1: - raise RuntimeError("Task packing is not supported yet.") - - coro = runner.run_abstract_task( + coro = runner_ng.run_abstract_task_group( dispatch_id=dispatch_id, - node_id=task_group_id, - node_name=node_name, - abstract_inputs=abstract_inputs, + task_group_id=task_group_id, + task_seq=task_specs, + known_nodes=known_nodes, selected_executor=[selected_executor, selected_executor_data], ) - # coro = runner_ng.run_abstract_task_group( - # dispatch_id=dispatch_id, - # task_group_id=task_group_id, - # task_seq=task_specs, - # known_nodes=known_nodes, - # selected_executor=[selected_executor, selected_executor_data], - # ) asyncio.create_task(coro) else: @@ -354,8 +337,7 @@ async def cancel_dispatch(dispatch_id: str, task_ids: List[int] = []) -> None: app_log.debug(f"Cancelling dispatch {dispatch_id}") await jbmgr.set_cancel_requested(dispatch_id, task_ids) - await runner.cancel_tasks(dispatch_id, task_ids) - # await runner_ng.cancel_tasks(dispatch_id, task_ids) + await runner_ng.cancel_tasks(dispatch_id, task_ids) # Recursively cancel running sublattice dispatches attrs = await datasvc.electron.get_bulk(dispatch_id, task_ids, ["sub_dispatch_id"]) diff --git a/covalent_dispatcher/_core/runner.py b/covalent_dispatcher/_core/runner.py index 8cf65c69f3..615df4de47 100644 --- a/covalent_dispatcher/_core/runner.py +++ b/covalent_dispatcher/_core/runner.py @@ -32,7 +32,7 @@ from covalent._shared_files.config import get_config from covalent._shared_files.util_classes import RESULT_STATUS from covalent._workflow import DepsBash, DepsCall, DepsPip -from covalent.executor.base import wrapper_fn +from covalent.executor.utils.wrappers import wrapper_fn from . import data_manager as datasvc from .runner_modules import executor_proxy diff --git a/covalent_dispatcher/_service/app.py b/covalent_dispatcher/_service/app.py index a02f3e2f2c..434eafc601 100644 --- a/covalent_dispatcher/_service/app.py +++ b/covalent_dispatcher/_service/app.py @@ -35,6 +35,7 @@ from covalent._shared_files.schemas.result import ResultSchema from covalent._shared_files.util_classes import RESULT_STATUS from covalent_dispatcher._core import dispatcher as core_dispatcher +from covalent_dispatcher._core import runner_ng as core_runner from .._dal.exporters.result import export_result_manifest from .._dal.result import Result, get_result_object @@ -43,9 +44,6 @@ from .heartbeat import Heartbeat from .models import ExportResponseSchema -# from covalent_dispatcher._core import runner_ng as core_runner - - app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -63,9 +61,9 @@ async def lifespan(app: FastAPI): _background_tasks.add(fut) fut.add_done_callback(_background_tasks.discard) - # # Runner event queue and listener - # core_runner._job_events = asyncio.Queue() - # core_runner._job_event_listener = asyncio.create_task(core_runner._listen_for_job_events()) + # Runner event queue and listener + core_runner._job_events = asyncio.Queue() + core_runner._job_event_listener = asyncio.create_task(core_runner._listen_for_job_events()) # Dispatcher event queue and listener core_dispatcher._global_status_queue = asyncio.Queue() @@ -83,7 +81,7 @@ async def lifespan(app: FastAPI): await cancel_all_with_status(status) core_dispatcher._global_event_listener.cancel() - # core_runner._job_event_listener.cancel() + core_runner._job_event_listener.cancel() Heartbeat.stop() diff --git a/covalent_ui/api/v1/routes/routes.py b/covalent_ui/api/v1/routes/routes.py index 9ccbde8f90..d9b4de8f96 100644 --- a/covalent_ui/api/v1/routes/routes.py +++ b/covalent_ui/api/v1/routes/routes.py @@ -22,7 +22,7 @@ from fastapi import APIRouter -from covalent_dispatcher._service import app, assets +from covalent_dispatcher._service import app, assets, runnersvc from covalent_dispatcher._triggers_app.app import router as tr_router from covalent_ui.api.v1.routes.end_points import ( electron_routes, @@ -46,5 +46,4 @@ routes.include_router(tr_router, prefix="/api", tags=["Triggers"]) routes.include_router(app.router, prefix="/api/v1", tags=["Dispatcher"]) routes.include_router(assets.router, prefix="/api/v1", tags=["Assets"]) -# This will be enabled in the next patch -# routes.include_router(runnersvc.router, prefix="/api/v1", tags=["Runner"]) +routes.include_router(runnersvc.router, prefix="/api/v1", tags=["Runner"]) diff --git a/tests/covalent_dispatcher_tests/_core/dispatcher_test.py b/tests/covalent_dispatcher_tests/_core/dispatcher_test.py index e1d190a860..634fd59321 100644 --- a/tests/covalent_dispatcher_tests/_core/dispatcher_test.py +++ b/tests/covalent_dispatcher_tests/_core/dispatcher_test.py @@ -514,11 +514,10 @@ async def test_submit_initial_tasks(mocker): @pytest.mark.asyncio -async def test_submit_task_group_single(mocker): - """Test submitting a singleton task groups""" +async def test_submit_task_group(mocker): dispatch_id = "dispatch_1" gid = 2 - nodes = [2] + nodes = [4, 3, 2] mock_get_abs_input = mocker.patch( "covalent_dispatcher._core.dispatcher._get_abstract_task_inputs", @@ -555,75 +554,15 @@ async def get_electron_attrs(dispatch_id, node_id, keys): "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", ) - # This will be removed in the next patch mock_run_abs_task = mocker.patch( - "covalent_dispatcher._core.dispatcher.runner.run_abstract_task", + "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", ) - # mock_run_abs_task = mocker.patch( - # "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", - # ) await _submit_task_group(dispatch_id, nodes, gid) mock_run_abs_task.assert_called() assert mock_get_abs_input.await_count == len(nodes) -# Temporary only because the current runner does not support -# nontrivial task groups. -@pytest.mark.asyncio -async def test_submit_task_group_multiple(mocker): - """Check that submitting multiple tasks errors out""" - dispatch_id = "dispatch_1" - gid = 2 - nodes = [4, 3, 2] - - mock_get_abs_input = mocker.patch( - "covalent_dispatcher._core.dispatcher._get_abstract_task_inputs", - return_value={"args": [], "kwargs": {}}, - ) - - mock_attrs = { - "name": "task", - "value": 5, - "executor": "local", - "executor_data": {}, - } - - mock_statuses = [ - {"status": Result.NEW_OBJ}, - {"status": Result.NEW_OBJ}, - {"status": Result.NEW_OBJ}, - ] - - async def get_electron_attrs(dispatch_id, node_id, keys): - return {key: mock_attrs[key] for key in keys} - - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.electron.get", - get_electron_attrs, - ) - - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", - return_value=mock_statuses, - ) - - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", - ) - - # This will be removed in the next patch - mock_run_abs_task = mocker.patch( - "covalent_dispatcher._core.dispatcher.runner.run_abstract_task", - ) - # mock_run_abs_task = mocker.patch( - # "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", - # ) - - with pytest.raises(RuntimeError): - await _submit_task_group(dispatch_id, nodes, gid) - - @pytest.mark.asyncio async def test_submit_task_group_skips_reusable(mocker): """Check that submit_task_group skips reusable groups""" @@ -666,13 +605,9 @@ async def get_electron_attrs(dispatch_id, node_id, keys): "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", ) - # Will be removed next patch mock_run_abs_task = mocker.patch( - "covalent_dispatcher._core.dispatcher.runner.run_abstract_task", + "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", ) - # mock_run_abs_task = mocker.patch( - # "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", - # ) await _submit_task_group(dispatch_id, nodes, gid) mock_run_abs_task.assert_not_called() @@ -706,14 +641,9 @@ async def get_electron_attrs(dispatch_id, node_id, keys): "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", ) - # Will be removed next patch mock_run_abs_task = mocker.patch( - "covalent_dispatcher._core.dispatcher.runner.run_abstract_task", + "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", ) - # mock_run_abs_task = mocker.patch( - # "covalent_dispatcher._core.dispatcher.runner_ng.run_abstract_task_group", - # ) - await _submit_task_group(dispatch_id, [node_id], node_id) mock_run_abs_task.assert_not_called() @@ -763,8 +693,7 @@ async def test_cancel_dispatch(mocker): "covalent_dispatcher._core.dispatcher.jbmgr.set_cancel_requested" ) - mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner") - # mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner_ng") + mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner_ng") mock_runner.cancel_tasks = AsyncMock() res._initialize_nodes() @@ -830,8 +759,7 @@ async def test_cancel_dispatch_with_task_ids(mocker): "covalent_dispatcher._core.dispatcher.jbmgr.set_cancel_requested" ) - mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner") - # mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner_ng") + mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner_ng") mock_runner.cancel_tasks = AsyncMock() async def mock_get_nodes(dispatch_id):