diff --git a/covalent/_shared_files/util_classes.py b/covalent/_shared_files/util_classes.py index 92917fdfb..e174dbd36 100644 --- a/covalent/_shared_files/util_classes.py +++ b/covalent/_shared_files/util_classes.py @@ -61,6 +61,7 @@ class RESULT_STATUS: RUNNING = Status("RUNNING") CANCELLED = Status("CANCELLED") DISPATCHING = Status("DISPATCHING") + DISPATCHING_SUBLATTICE = Status("DISPATCHING") @staticmethod def is_terminal(status): diff --git a/covalent_dispatcher/_core/data_manager.py b/covalent_dispatcher/_core/data_manager.py index 0f015de00..f859c56a6 100644 --- a/covalent_dispatcher/_core/data_manager.py +++ b/covalent_dispatcher/_core/data_manager.py @@ -35,7 +35,7 @@ from covalent._workflow.lattice import Lattice from covalent._workflow.transport_graph_ops import TransportGraphOps -from .._db import load, update, upsert +from .._db import load, update from .._db.write_result_to_db import resolve_electron_id app_log = logger.app_log @@ -359,7 +359,7 @@ def get_status_queue(dispatch_id: str): async def persist_result(dispatch_id: str): result_object = get_result_object(dispatch_id) - update.persist(result_object) + upsert_lattice_data(result_object.dispatch_id) await _update_parent_electron(result_object) @@ -387,4 +387,6 @@ async def _update_parent_electron(result_object: Result): def upsert_lattice_data(dispatch_id: str): result_object = get_result_object(dispatch_id) - upsert.lattice_data(result_object) + # Redirect to new DAL -- this is a temporary fix as + # upsert_lattice_data will be obsoleted by the next patch. 
+ update.lattice_data(result_object) diff --git a/covalent_dispatcher/_core/dispatcher.py b/covalent_dispatcher/_core/dispatcher.py index 0d853c458..f6b8a30a6 100644 --- a/covalent_dispatcher/_core/dispatcher.py +++ b/covalent_dispatcher/_core/dispatcher.py @@ -29,7 +29,7 @@ from covalent._results_manager import Result from covalent._shared_files import logger -from covalent._shared_files.defaults import parameter_prefix +from covalent._shared_files.defaults import WAIT_EDGE_NAME, parameter_prefix from covalent._shared_files.util_classes import RESULT_STATUS from covalent_ui import result_webhook @@ -72,7 +72,7 @@ def _get_abstract_task_inputs(node_id: int, node_name: str, result_object: Resul edge_data = result_object.lattice.transport_graph.get_edge_data(parent, node_id) for _, d in edge_data.items(): - if not d.get("wait_for"): + if d["edge_name"] != WAIT_EDGE_NAME: if d["param_type"] == "arg": abstract_task_input["args"].append((parent, d["arg_index"])) elif d["param_type"] == "kwarg": @@ -248,7 +248,9 @@ async def _run_planned_workflow(result_object: Result, status_queue: asyncio.Que while unresolved_tasks > 0: app_log.debug(f"{tasks_left} tasks left to complete.") - app_log.debug(f"Waiting to hear from {unresolved_tasks} tasks.") + app_log.debug( + f"{result_object.dispatch_id}: Waiting to hear from {unresolved_tasks} tasks." + ) node_id, node_status, detail = await status_queue.get() diff --git a/covalent_dispatcher/_db/update.py b/covalent_dispatcher/_db/update.py index 7a1b29c81..55db44663 100644 --- a/covalent_dispatcher/_db/update.py +++ b/covalent_dispatcher/_db/update.py @@ -19,15 +19,19 @@ # Relief from the License may be granted by purchasing a commercial license. 
import os +from datetime import datetime from pathlib import Path -from typing import Union +from typing import Any, Union from covalent._results_manager import Result from covalent._shared_files import logger from covalent._shared_files.config import get_config +from covalent._shared_files.defaults import postprocess_prefix +from covalent._shared_files.util_classes import Status from covalent._workflow.lattice import Lattice from covalent._workflow.transport import _TransportGraph +from .._dal.result import get_result_object from . import upsert app_log = logger.app_log @@ -57,3 +61,90 @@ def _initialize_results_dir(result): f"{result.dispatch_id}", ) Path(result_folder_path).mkdir(parents=True, exist_ok=True) + + +# Temporary implementation using new DAL. Will be removed in the next +# patch which transitions core covalent to the new DAL. +def _node( + result, + node_id: int, + node_name: str = None, + start_time: datetime = None, + end_time: datetime = None, + status: "Status" = None, + output: Any = None, + error: Exception = None, + stdout: str = None, + stderr: str = None, + sub_dispatch_id=None, + sublattice_result=None, +) -> bool: + """ + Update the node result in the transport graph. + Called after any change in node's execution state. + + Args: + node_id: The node id. + node_name: The name of the node. + start_time: The start time of the node execution. + end_time: The end time of the node execution. + status: The status of the node execution. + output: The output of the node unless error occurred in which case None. + error: The error of the node if occurred else None. + stdout: The stdout of the node execution. + stderr: The stderr of the node execution. 
+ + Returns: + True/False indicating whether the update succeeded + """ + + # Update the in-memory result object + result._update_node( + node_id=node_id, + node_name=node_name, + start_time=start_time, + end_time=end_time, + status=status, + output=output, + error=error, + stdout=stdout, + stderr=stderr, + sub_dispatch_id=sub_dispatch_id, + sublattice_result=sublattice_result, + ) + + # Write out update to persistent storage + srvres = get_result_object(result.dispatch_id, bare=True) + srvres._update_node( + node_id=node_id, + node_name=node_name, + start_time=start_time, + end_time=end_time, + status=status, + output=output, + error=error, + stdout=stdout, + stderr=stderr, + ) + + if node_name.startswith(postprocess_prefix) and end_time is not None: + app_log.warning( + f"Persisting postprocess result {output.get_deserialized()}, node_name: {node_name}" + ) + result._result = output + result._status = status + result._end_time = end_time + lattice_data(result) + + +# Temporary implementation of upsert.lattice_data using the new DAL. +# Will be removed in the next patch which transitions core covalent to +# the new DAL. 
+def lattice_data(result_object: Result) -> None: + srv_res = get_result_object(result_object.dispatch_id, bare=True) + srv_res._update_dispatch( + result_object.start_time, + result_object.end_time, + result_object.status, + result_object.error, + ) diff --git a/tests/covalent_dispatcher_tests/_core/tmp_data_manager_test.py b/tests/covalent_dispatcher_tests/_core/tmp_data_manager_test.py index 0f0e02c1a..9df2e16f6 100644 --- a/tests/covalent_dispatcher_tests/_core/tmp_data_manager_test.py +++ b/tests/covalent_dispatcher_tests/_core/tmp_data_manager_test.py @@ -473,11 +473,13 @@ async def test_persist_result(mocker): mock_update_parent = mocker.patch( "covalent_dispatcher._core.data_manager._update_parent_electron" ) - mock_persist = mocker.patch("covalent_dispatcher._core.data_manager.update.persist") + mock_update_lattice = mocker.patch( + "covalent_dispatcher._core.data_manager.update.lattice_data" + ) await persist_result(result_object.dispatch_id) mock_update_parent.assert_awaited_with(result_object) - mock_persist.assert_called_with(result_object) + mock_update_lattice.assert_called_with(result_object) @pytest.mark.parametrize( diff --git a/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py b/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py index 79b3c77c5..ec7016c4b 100644 --- a/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py +++ b/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py @@ -305,13 +305,16 @@ async def test_run_workflow_normal(mocker): mocker.patch( "covalent_dispatcher._core.dispatcher._run_planned_workflow", return_value=result_object ) - mock_persist = mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") + mock_get_result_object = mocker.patch( + "covalent_dispatcher._core.data_manager.get_result_object", return_value=result_object + ) + mock_upsert = mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") mock_unregister = mocker.patch( 
"covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" ) await run_workflow(result_object) - mock_persist.assert_called_with(result_object) + mock_upsert.assert_called_with(result_object.dispatch_id) mock_unregister.assert_called_with(result_object.dispatch_id) @@ -366,12 +369,15 @@ async def test_run_workflow_exception(mocker): return_value=result_object, side_effect=RuntimeError("Error"), ) - mock_persist = mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") + mock_get_result_object = mocker.patch( + "covalent_dispatcher._core.data_manager.get_result_object", return_value=result_object + ) + mock_upsert = mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") result = await run_workflow(result_object) assert result.status == Result.FAILED - mock_persist.assert_called_with(result_object) + mock_upsert.assert_called_with(result_object.dispatch_id) mock_unregister.assert_called_with(result_object.dispatch_id) diff --git a/tests/covalent_dispatcher_tests/_core/tmp_execution_test.py b/tests/covalent_dispatcher_tests/_core/tmp_execution_test.py index 7f7c6f458..d8d376c0f 100644 --- a/tests/covalent_dispatcher_tests/_core/tmp_execution_test.py +++ b/tests/covalent_dispatcher_tests/_core/tmp_execution_test.py @@ -201,10 +201,10 @@ def workflow(x): result_object._root_dispatch_id = dispatch_id result_object._initialize_nodes() - # patch all methods that reference a DB - mocker.patch("covalent_dispatcher._db.upsert._lattice_data") - mocker.patch("covalent_dispatcher._db.upsert._electron_data") - mocker.patch("covalent_dispatcher._db.update.persist") + mocker.patch("covalent_dispatcher._db.datastore.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + mocker.patch( "covalent._results_manager.result.Result._get_node_name", return_value="failing_task" ) @@ -296,7 +296,7 @@ def workflow(x): 
@pytest.mark.asyncio -async def test_run_workflow_does_not_deserialize(mocker): +async def test_run_workflow_does_not_deserialize(test_db, mocker): """Check that dispatcher does not deserialize user data when using out-of-process `workflow_executor`""" @@ -319,9 +319,10 @@ def workflow(x): result_object = Result(lattice, dispatch_id=dispatch_id) result_object._initialize_nodes() - mocker.patch("covalent_dispatcher._db.upsert._lattice_data") - mocker.patch("covalent_dispatcher._db.upsert._electron_data") - mocker.patch("covalent_dispatcher._db.update.persist") + mocker.patch("covalent_dispatcher._db.datastore.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + mock_unregister = mocker.patch( "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" )